File size: 2,370 Bytes
6388076 9c243c1 7f88877 9c243c1 6388076 9c243c1 6388076 9c243c1 6388076 9c243c1 6388076 9c243c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from transformers import PretrainedConfig, XLMRobertaForSequenceClassification
import torch.nn as nn
import torch
class CustomConfig(PretrainedConfig):
    """Configuration for the ``custom_model`` type.

    Extends ``PretrainedConfig`` with a single extra field,
    ``num_emotion_labels``. Every other keyword argument is forwarded
    unchanged to ``PretrainedConfig.__init__``.

    NOTE(review): the backbone fields that ``CustomModel`` reads
    (``hidden_size``, ``hidden_dropout_prob``, ``initializer_range``, ...)
    are not declared with defaults in this class. They presumably have to
    arrive through ``kwargs`` (e.g. when loading an XLM-R ``config.json``).
    Confirm before constructing this config from scratch.
    """

    model_type = "custom_model"

    def __init__(self, num_emotion_labels=18, **kwargs):
        super().__init__(**kwargs)
        # Width of the output layer of CustomModel's emotion head.
        self.num_emotion_labels = num_emotion_labels
class CustomModel(XLMRobertaForSequenceClassification):
    """XLM-RoBERTa with an additional emotion head.

    On top of the backbone (``self.roberta``) this model produces two outputs:

    * ``emotion_logits``: ``config.num_emotion_labels`` logits from a new MLP
      (``emotion_classifier``) applied to the first-token hidden state. These
      are trained with ``BCEWithLogitsLoss`` when ``labels`` are given.
    * ``sentiment_logits``: the output of the inherited ``self.classifier``
      head. It is evaluated under ``torch.no_grad()`` and therefore never
      receives gradients from this model's loss.
    """

    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_emotion_labels = config.num_emotion_labels
        self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_emotion_labels),
        )
        # Any weight initialisation the parent performs inside its __init__
        # runs before this head exists, so initialise the new Linear layers
        # explicitly (in layer order, as before).
        for layer in self.emotion_classifier:
            if isinstance(layer, nn.Linear):
                self._init_weights(layer)

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        ``nn.Linear`` layers get weights drawn from
        N(0, ``config.initializer_range``) and a zero bias. Every other module
        type is handed to the parent class's ``_init_weights``.

        The previous override ignored non-Linear modules entirely. Backbone
        modules (presumably embeddings and layer norms; see the parent
        implementation) were then left without the parent's initialisation
        whenever this method was called on them.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        else:
            super()._init_weights(module)

    def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
        """Run the backbone and both heads.

        Args:
            input_ids: token ids forwarded to ``self.roberta``.
            attention_mask: attention mask forwarded to ``self.roberta``.
            sentiment: accepted but not used by this method.
                NOTE(review): presumably kept so that batches carrying a
                ``sentiment`` field can be passed straight through. Confirm.
            labels: optional emotion targets, compared element-wise against
                ``emotion_logits`` by ``BCEWithLogitsLoss``. They are cast to
                the logits' dtype first, so integer 0/1 targets are accepted.

        Returns:
            A dict with ``emotion_logits`` and ``sentiment_logits``. When
            ``labels`` is given, a leading ``loss`` entry is included as well.

        Raises:
            ValueError: if the backbone's first output is not 3-dimensional.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        if sequence_output.dim() != 3:
            raise ValueError(f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}")

        # Emotion head: dropout, then the MLP, on the first-token hidden state.
        cls_hidden_states = self.dropout_emotion(sequence_output[:, 0, :])
        emotion_logits = self.emotion_classifier(cls_hidden_states)

        # Sentiment head: the inherited classifier, run without autograd so
        # that this branch contributes no gradients.
        with torch.no_grad():
            # NOTE(review): the inherited head presumably slices [:, 0, :]
            # itself, hence the unsqueeze to a length-1 sequence. Confirm.
            cls_token_state = sequence_output[:, 0, :].unsqueeze(1)
            sentiment_logits = self.classifier(cls_token_state).squeeze(1)

        if labels is None:
            return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}

        # BCEWithLogitsLoss needs floating-point targets whose dtype matches
        # the logits. Integer labels would otherwise raise.
        targets = labels.to(emotion_logits.dtype)
        # An all-ones pos_weight is numerically identical to unweighted BCE.
        # It is kept as the hook for per-label positive weighting.
        pos_weight = torch.ones(
            self.num_emotion_labels,
            dtype=emotion_logits.dtype,
            device=emotion_logits.device,
        )
        loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = loss_fct(emotion_logits, targets)
        return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
|