Commit
•
7470753
1
Parent(s):
beb62a4
pointing dictionaries from the mtl module's init (#397)
Browse files- pointing dictionaries from the mtl module's init (5539d14469f84e2f0b13a7cb3f6054b2b0cbf1f3)
Co-authored-by: Madhavan Venkatesh <[email protected]>
geneformer/mtl/collators.py
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
# imports
|
2 |
import torch
|
3 |
-
|
4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
|
|
5 |
|
6 |
"""
|
7 |
Geneformer collator for multi-task cell classification.
|
8 |
"""
|
9 |
|
10 |
-
|
11 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
12 |
class_type = "cell"
|
13 |
|
14 |
def __init__(self, *args, **kwargs) -> None:
|
15 |
-
|
|
|
16 |
|
17 |
def _prepare_batch(self, features):
|
18 |
# Process inputs as usual
|
|
|
1 |
# imports
|
2 |
import torch
|
|
|
3 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
4 |
+
from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
|
5 |
|
6 |
"""
|
7 |
Geneformer collator for multi-task cell classification.
|
8 |
"""
|
9 |
|
|
|
10 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
11 |
class_type = "cell"
|
12 |
|
13 |
def __init__(self, *args, **kwargs) -> None:
|
14 |
+
# Use the loaded token dictionary from the mtl module's init
|
15 |
+
super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
|
16 |
|
17 |
def _prepare_batch(self, features):
|
18 |
# Process inputs as usual
|