Use parent directory for token dictionary
#404
by
madhavanvenkatesh
- opened
- geneformer/mtl/collators.py +11 -9
geneformer/mtl/collators.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1 |
# imports
|
2 |
import torch
|
|
|
3 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
4 |
-
from
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
11 |
class_type = "cell"
|
12 |
|
13 |
def __init__(self, *args, **kwargs) -> None:
|
14 |
-
# Use the loaded token dictionary
|
15 |
super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
|
16 |
|
17 |
def _prepare_batch(self, features):
|
@@ -29,7 +34,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
29 |
if "label" in features[0]:
|
30 |
# Initialize labels dictionary for all tasks
|
31 |
labels = {task: [] for task in features[0]["label"].keys()}
|
32 |
-
|
33 |
# Populate labels for each task
|
34 |
for feature in features:
|
35 |
for task, label in feature["label"].items():
|
@@ -57,7 +61,6 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
57 |
|
58 |
def __call__(self, features):
|
59 |
batch = self._prepare_batch(features)
|
60 |
-
|
61 |
for k, v in batch.items():
|
62 |
if torch.is_tensor(v):
|
63 |
batch[k] = v.clone().detach()
|
@@ -69,5 +72,4 @@ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassificati
|
|
69 |
}
|
70 |
else:
|
71 |
batch[k] = torch.tensor(v, dtype=torch.int64)
|
72 |
-
|
73 |
-
return batch
|
|
|
1 |
# imports
|
2 |
import torch
|
3 |
+
import pickle
|
4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
5 |
+
from .. import TOKEN_DICTIONARY_FILE
|
6 |
|
7 |
+
def load_token_dictionary():
|
8 |
+
with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
|
9 |
+
return pickle.load(f)
|
10 |
+
|
11 |
+
TOKEN_DICTIONARY = load_token_dictionary()
|
12 |
+
|
13 |
+
"""Geneformer collator for multi-task cell classification."""
|
14 |
|
15 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
16 |
class_type = "cell"
|
17 |
|
18 |
def __init__(self, *args, **kwargs) -> None:
|
19 |
+
# Use the loaded token dictionary
|
20 |
super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
|
21 |
|
22 |
def _prepare_batch(self, features):
|
|
|
34 |
if "label" in features[0]:
|
35 |
# Initialize labels dictionary for all tasks
|
36 |
labels = {task: [] for task in features[0]["label"].keys()}
|
|
|
37 |
# Populate labels for each task
|
38 |
for feature in features:
|
39 |
for task, label in feature["label"].items():
|
|
|
61 |
|
62 |
def __call__(self, features):
|
63 |
batch = self._prepare_batch(features)
|
|
|
64 |
for k, v in batch.items():
|
65 |
if torch.is_tensor(v):
|
66 |
batch[k] = v.clone().detach()
|
|
|
72 |
}
|
73 |
else:
|
74 |
batch[k] = torch.tensor(v, dtype=torch.int64)
|
75 |
+
return batch
|
|