import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class AttentionPool(nn.Module):
    """Attention-based pooling layer."""

    def __init__(self, hidden_size):
        super(AttentionPool, self).__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
        nn.init.xavier_uniform_(
            self.attention_weights
        )  # https://pytorch.org/docs/stable/nn.init.html

    def forward(self, hidden_states):
        # hidden_states: (batch_size, seq_len, hidden_size)
        attention_scores = torch.matmul(hidden_states, self.attention_weights)  # (batch_size, seq_len, 1)
        attention_scores = torch.softmax(attention_scores, dim=1)  # normalize over the sequence dimension
        pooled_output = torch.sum(hidden_states * attention_scores, dim=1)  # (batch_size, hidden_size)
        return pooled_output


class GeneformerMultiTask(nn.Module):
    """BERT-based multi-task model with one classification head per task."""

    def __init__(
        self,
        pretrained_path,
        num_labels_list,
        dropout_rate=0.1,
        use_task_weights=False,
        task_weights=None,
        max_layers_to_freeze=0,
        use_attention_pooling=False,
    ):
        super(GeneformerMultiTask, self).__init__()
        self.config = BertConfig.from_pretrained(pretrained_path)
        self.bert = BertModel(self.config)
        self.num_labels_list = num_labels_list
        self.use_task_weights = use_task_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.use_attention_pooling = use_attention_pooling

        if use_task_weights and (
            task_weights is None or len(task_weights) != len(num_labels_list)
        ):
            raise ValueError(
                "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
            )
        self.task_weights = (
            task_weights if use_task_weights else [1.0] * len(num_labels_list)
        )

        # Freeze the first `max_layers_to_freeze` encoder layers
        for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
            for param in layer.parameters():
                param.requires_grad = False

        self.attention_pool = (
            AttentionPool(self.config.hidden_size) if use_attention_pooling else None
        )

        self.classification_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, num_labels)
                for num_labels in num_labels_list
            ]
        )
        # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
        for head in self.classification_heads:
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        try:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            raise RuntimeError(f"Error during BERT forward pass: {e}")

        sequence_output = outputs.last_hidden_state

        try:
            # Attention-pool over the full sequence, or fall back to the [CLS] token embedding
            pooled_output = (
                self.attention_pool(sequence_output)
                if self.use_attention_pooling
                else sequence_output[:, 0, :]
            )
            pooled_output = self.dropout(pooled_output)
        except Exception as e:
            raise RuntimeError(f"Error during pooling and dropout: {e}")

        total_loss = 0
        logits = []
        losses = []

        # One classification head (and optionally weighted loss) per task
        for task_id, (head, num_labels) in enumerate(
            zip(self.classification_heads, self.num_labels_list)
        ):
            try:
                task_logits = head(pooled_output)
            except Exception as e:
                raise RuntimeError(
                    f"Error during forward pass of classification head {task_id}: {e}"
                )

            logits.append(task_logits)

            if labels is not None:
                try:
                    loss_fct = nn.CrossEntropyLoss()
                    task_loss = loss_fct(
                        task_logits.view(-1, num_labels), labels[task_id].view(-1)
                    )
                    if self.use_task_weights:
                        task_loss *= self.task_weights[task_id]
                    total_loss += task_loss
                    losses.append(task_loss.item())
                except Exception as e:
                    raise RuntimeError(
                        f"Error during loss computation for task {task_id}: {e}"
                    )

        # Parenthesized so the whole (loss, logits, losses) tuple is conditional on labels
        return (total_loss, logits, losses) if labels is not None else logits
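

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # Assumptions: "path/to/geneformer_checkpoint" is a placeholder for a local
    # directory containing a BERT-style config.json, and the two tasks with
    # 3 and 5 classes are hypothetical.
    model = GeneformerMultiTask(
        pretrained_path="path/to/geneformer_checkpoint",
        num_labels_list=[3, 5],
        use_attention_pooling=True,
    )
    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))  # batch of 2, seq length 16
    attention_mask = torch.ones_like(input_ids)
    labels = [torch.randint(0, 3, (2,)), torch.randint(0, 5, (2,))]  # one label tensor per task
    loss, logits, losses = model(input_ids, attention_mask, labels)
    print(loss.item(), [task_logits.shape for task_logits in logits], losses)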