File size: 2,513 Bytes
f07bfd7
933ca80
 
 
 
 
 
 
 
f07bfd7
933ca80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07bfd7
 
 
 
933ca80
 
 
 
 
 
 
 
 
 
 
f07bfd7
 
 
 
933ca80
 
 
f07bfd7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# imports
import torch

from ..collator_for_classification import DataCollatorForGeneClassification

"""
Geneformer collator for multi-task cell classification.
"""


class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
    """Batch collator for multi-task cell classification.

    Each feature may carry a ``"label"`` mapping of task name -> label. The
    collator pads the inputs through ``self.tokenizer.pad`` and gathers the
    labels per task under ``batch["labels"]``.

    NOTE(review): ``tokenizer``, ``padding``, ``max_length`` and
    ``pad_to_multiple_of`` are read from ``self`` and are presumably set by
    the parent constructor -- confirm against
    ``DataCollatorForGeneClassification``.
    """

    class_type = "cell"

    def __init__(self, *args, **kwargs) -> None:
        """Forward every argument unchanged to the parent collator."""
        super().__init__(*args, **kwargs)

    @staticmethod
    def _as_int64_tensor(values):
        """Return ``values`` as an int64 tensor.

        Existing tensors are copied with ``clone().detach()`` and cast, rather
        than being passed through ``torch.tensor`` again, which would emit
        PyTorch's copy-construct ``UserWarning``. Anything else is handed to
        ``torch.tensor`` with ``dtype=torch.int64``.
        """
        if torch.is_tensor(values):
            return values.clone().detach().to(torch.int64)
        return torch.tensor(values, dtype=torch.int64)

    def _prepare_batch(self, features):
        """Pad ``features`` and attach per-task labels.

        Args:
            features: Non-empty sequence of per-example mappings. When the
                first one contains a ``"label"`` key, every example is
                expected to provide a ``"label"`` mapping with the same task
                names as the first.

        Returns:
            The padded batch returned by ``self.tokenizer.pad`` with an extra
            ``"labels"`` entry: a dict of task name -> collected labels. Tasks
            whose labels are lists or tensors are converted to ``torch.long``
            tensors here; other label types stay as plain Python lists.
            Without labels, ``"labels"`` is an empty dict unless
            ``input_ids`` itself exposes ``keys()``, in which case each key
            maps to an empty ``torch.long`` tensor.

        Raises:
            KeyError: If a later example has a task name the first lacks.
        """
        batch = self.tokenizer.pad(
            features,
            class_type=self.class_type,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        first = features[0]
        if "label" in first:
            # The first example defines the set of task names.
            labels = {task: [] for task in first["label"]}

            for feature in features:
                for task, label in feature["label"].items():
                    labels[task].append(label)

            for task, values in labels.items():
                if isinstance(values[0], (list, torch.Tensor)):
                    labels[task] = torch.tensor(values, dtype=torch.long)
                # Dict-valued and scalar labels are deliberately left as
                # lists here; __call__ performs the final tensor conversion.
                # NOTE(review): dict-valued labels have no conversion and will
                # presumably fail in __call__ -- confirm whether they occur.

            batch["labels"] = labels
        else:
            # Without labels the task names are unknown. ``input_ids`` is
            # normally a token sequence with no ``keys()``, so calling it
            # unconditionally raised AttributeError; only mirror its keys
            # when it really is mapping-like.
            input_ids = first["input_ids"]
            task_names = input_ids.keys() if hasattr(input_ids, "keys") else ()
            batch["labels"] = {
                task: torch.tensor([], dtype=torch.long) for task in task_names
            }

        return batch

    def __call__(self, features):
        """Collate ``features`` into a batch of tensors.

        Top-level tensors are copied with ``clone().detach()``; dict entries
        (the per-task labels) have each value converted to an int64 tensor;
        any other value is converted with ``torch.tensor(..., int64)``.

        Args:
            features: Sequence of per-example mappings, as accepted by
                ``_prepare_batch``.

        Returns:
            The batch produced by ``_prepare_batch`` with its values replaced
            by tensors (or dicts of tensors).
        """
        batch = self._prepare_batch(features)

        # Snapshot the items so values can be reassigned while looping.
        for key, value in list(batch.items()):
            if torch.is_tensor(value):
                batch[key] = value.clone().detach()
            elif isinstance(value, dict):
                batch[key] = {
                    task: self._as_int64_tensor(task_labels)
                    for task, task_labels in value.items()
                }
            else:
                batch[key] = torch.tensor(value, dtype=torch.int64)

        return batch