File size: 2,513 Bytes
f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 933ca80 f07bfd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
# imports
import torch
from ..collator_for_classification import DataCollatorForGeneClassification
"""
Geneformer collator for multi-task cell classification.
"""
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
    """Collate tokenized cells into a batch with per-task classification labels.

    Each feature may carry a ``"label"`` entry mapping task name to that
    cell's label for the task (this is how ``_prepare_batch`` reads it; confirm
    against the dataset preprocessing). The collated batch stores the labels
    under ``batch["labels"]`` as ``{task: int64 tensor}``.
    """

    # Tells the padding precollator that labels are per-cell, not per-gene.
    class_type = "cell"

    def __init__(self, *args, **kwargs) -> None:
        # All configuration (tokenizer, padding, max_length, ...) is handled
        # by the parent gene-classification collator.
        super().__init__(*args, **kwargs)

    @staticmethod
    def _labels_to_tensor(task, values):
        """Convert one task's per-example labels into a single int64 tensor.

        Args:
            task: Task name; used only in error messages.
            values: Non-empty list of per-example labels for ``task``. These may
                be Python scalars, equal-length nested lists, or tensors.

        Returns:
            A ``torch.long`` tensor whose first dimension is the batch size.

        Raises:
            TypeError: If the labels are dicts, which this collator cannot
                flatten into a tensor.
        """
        first = values[0]
        if isinstance(first, dict):
            raise TypeError(
                f"Labels for task {task!r} are dicts; nested label structures "
                "are not supported by this collator."
            )
        if isinstance(first, torch.Tensor):
            # torch.tensor() cannot build from a list of multi-element
            # tensors; stacking handles both 0-d and N-d (same-shape) labels.
            return torch.stack([torch.as_tensor(v) for v in values]).long()
        # Python scalars or equal-length nested lists.
        return torch.tensor(values, dtype=torch.long)

    def _prepare_batch(self, features):
        """Pad the inputs and gather per-task labels into tensors.

        Args:
            features: Non-empty list of feature dicts (one per cell) accepted
                by ``self.tokenizer.pad``. Each may include a ``"label"`` dict
                of ``{task: label}``.

        Returns:
            The padded batch returned by ``self.tokenizer.pad`` with a
            ``"labels"`` entry added. The entry is ``{task: int64 tensor}``
            when the first feature has labels, otherwise an empty dict.

        Raises:
            KeyError: If a feature lacks ``"label"`` or has a task that the
                first feature does not have.
            TypeError: If a task's labels are dicts.
        """
        batch = self.tokenizer.pad(
            features,
            class_type=self.class_type,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if "label" not in features[0]:
            # Unlabeled batch. Task names are only discoverable from the labels
            # themselves; "input_ids" is a token sequence, not a per-task dict,
            # so it cannot supply them. Emit an empty mapping instead.
            batch["labels"] = {}
            return batch

        # Task names come from the first feature; all features must share them.
        labels = {task: [] for task in features[0]["label"]}
        for feature in features:
            for task, label in feature["label"].items():
                labels[task].append(label)

        batch["labels"] = {
            task: self._labels_to_tensor(task, values)
            for task, values in labels.items()
        }
        return batch

    def __call__(self, features):
        """Collate ``features`` into a model-ready batch.

        Args:
            features: List of feature dicts, as accepted by ``_prepare_batch``.

        Returns:
            The batch from ``_prepare_batch``, transformed as follows:
            - every top-level tensor is copied and detached;
            - every nested dict (the per-task labels) is mapped to int64
              tensors;
            - any other value is converted to an int64 tensor.
        """
        batch = self._prepare_batch(features)
        # Snapshot the items since values are reassigned inside the loop.
        for key, value in list(batch.items()):
            if torch.is_tensor(value):
                batch[key] = value.clone().detach()
            elif isinstance(value, dict):
                # as_tensor accepts lists or tensors. Unlike torch.tensor(), it
                # does not warn about (or copy) inputs that are already int64
                # tensors.
                batch[key] = {
                    task: torch.as_tensor(task_labels, dtype=torch.int64)
                    for task, task_labels in value.items()
                }
            else:
                batch[key] = torch.tensor(value, dtype=torch.int64)
        return batch
|