Commit
•
85f295e
1
Parent(s):
7eca269
dictionaries from parent dir (#405)
Browse files- dictionaries form parent dir (713240f3833d3547237ba53fe2c7e09bfc04e5f0)
Co-authored-by: Madhavan Venkatesh <[email protected]>
- geneformer/mtl/collators.py +13 -10
geneformer/mtl/collators.py
CHANGED
@@ -1,18 +1,24 @@
|
|
1 |
# imports
|
2 |
import torch
|
|
|
3 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
4 |
-
from
|
5 |
|
6 |
-
"""
|
7 |
-
Geneformer collator for multi-task cell classification.
|
8 |
-
"""
|
9 |
|
10 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
11 |
class_type = "cell"
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
def __init__(self, *args, **kwargs) -> None:
|
14 |
-
#
|
15 |
-
|
|
|
|
|
16 |
|
17 |
def _prepare_batch(self, features):
|
18 |
# Process inputs as usual
|
@@ -29,7 +35,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
29 |
if "label" in features[0]:
|
30 |
# Initialize labels dictionary for all tasks
|
31 |
labels = {task: [] for task in features[0]["label"].keys()}
|
32 |
-
|
33 |
# Populate labels for each task
|
34 |
for feature in features:
|
35 |
for task, label in feature["label"].items():
|
@@ -57,7 +62,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
57 |
|
58 |
def __call__(self, features):
|
59 |
batch = self._prepare_batch(features)
|
60 |
-
|
61 |
for k, v in batch.items():
|
62 |
if torch.is_tensor(v):
|
63 |
batch[k] = v.clone().detach()
|
@@ -69,5 +73,4 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
69 |
}
|
70 |
else:
|
71 |
batch[k] = torch.tensor(v, dtype=torch.int64)
|
72 |
-
|
73 |
-
return batch
|
|
|
1 |
# imports
|
2 |
import torch
|
3 |
+
import pickle
|
4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
5 |
+
from .. import TOKEN_DICTIONARY_FILE
|
6 |
|
7 |
+
"""Geneformer collator for multi-task cell classification."""
|
|
|
|
|
8 |
|
9 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
10 |
class_type = "cell"
|
11 |
|
12 |
+
@staticmethod
|
13 |
+
def load_token_dictionary():
|
14 |
+
with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
|
15 |
+
return pickle.load(f)
|
16 |
+
|
17 |
def __init__(self, *args, **kwargs) -> None:
|
18 |
+
# Load the token dictionary
|
19 |
+
token_dictionary = self.load_token_dictionary()
|
20 |
+
# Use the loaded token dictionary
|
21 |
+
super().__init__(token_dictionary=token_dictionary, *args, **kwargs)
|
22 |
|
23 |
def _prepare_batch(self, features):
|
24 |
# Process inputs as usual
|
|
|
35 |
if "label" in features[0]:
|
36 |
# Initialize labels dictionary for all tasks
|
37 |
labels = {task: [] for task in features[0]["label"].keys()}
|
|
|
38 |
# Populate labels for each task
|
39 |
for feature in features:
|
40 |
for task, label in feature["label"].items():
|
|
|
62 |
|
63 |
def __call__(self, features):
|
64 |
batch = self._prepare_batch(features)
|
|
|
65 |
for k, v in batch.items():
|
66 |
if torch.is_tensor(v):
|
67 |
batch[k] = v.clone().detach()
|
|
|
73 |
}
|
74 |
else:
|
75 |
batch[k] = torch.tensor(v, dtype=torch.int64)
|
76 |
+
return batch
|
|