parent directory for dictionary

#404
Files changed (1) hide show
  1. geneformer/mtl/collators.py +11 -9
geneformer/mtl/collators.py CHANGED
@@ -1,17 +1,22 @@
1
  # imports
2
  import torch
 
3
  from ..collator_for_classification import DataCollatorForGeneClassification
4
- from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
5
 
6
- """
7
- Geneformer collator for multi-task cell classification.
8
- """
 
 
 
 
9
 
10
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
11
  class_type = "cell"
12
 
13
  def __init__(self, *args, **kwargs) -> None:
14
- # Use the loaded token dictionary from the mtl module's init
15
  super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
16
 
17
  def _prepare_batch(self, features):
@@ -29,7 +34,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
29
  if "label" in features[0]:
30
  # Initialize labels dictionary for all tasks
31
  labels = {task: [] for task in features[0]["label"].keys()}
32
-
33
  # Populate labels for each task
34
  for feature in features:
35
  for task, label in feature["label"].items():
@@ -57,7 +61,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
57
 
58
  def __call__(self, features):
59
  batch = self._prepare_batch(features)
60
-
61
  for k, v in batch.items():
62
  if torch.is_tensor(v):
63
  batch[k] = v.clone().detach()
@@ -69,5 +72,4 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
69
  }
70
  else:
71
  batch[k] = torch.tensor(v, dtype=torch.int64)
72
-
73
- return batch
 
1
  # imports
2
  import torch
3
+ import pickle
4
  from ..collator_for_classification import DataCollatorForGeneClassification
5
+ from .. import TOKEN_DICTIONARY_FILE
6
 
7
+ def load_token_dictionary():
8
+ with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
9
+ return pickle.load(f)
10
+
11
+ TOKEN_DICTIONARY = load_token_dictionary()
12
+
13
+ """Geneformer collator for multi-task cell classification."""
14
 
15
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
16
  class_type = "cell"
17
 
18
  def __init__(self, *args, **kwargs) -> None:
19
+ # Use the loaded token dictionary
20
  super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
21
 
22
  def _prepare_batch(self, features):
 
34
  if "label" in features[0]:
35
  # Initialize labels dictionary for all tasks
36
  labels = {task: [] for task in features[0]["label"].keys()}
 
37
  # Populate labels for each task
38
  for feature in features:
39
  for task, label in feature["label"].items():
 
61
 
62
  def __call__(self, features):
63
  batch = self._prepare_batch(features)
 
64
  for k, v in batch.items():
65
  if torch.is_tensor(v):
66
  batch[k] = v.clone().detach()
 
72
  }
73
  else:
74
  batch[k] = torch.tensor(v, dtype=torch.int64)
75
+ return batch