|
import torch
import torch.nn as nn
from transformers import PretrainedConfig, XLMRobertaConfig, XLMRobertaForSequenceClassification
|
|
|
class CustomConfig(XLMRobertaConfig):
    """Configuration for the XLM-RoBERTa model with an extra emotion head.

    Subclasses ``XLMRobertaConfig`` (itself a ``PretrainedConfig``) rather than
    bare ``PretrainedConfig`` so that every attribute the XLM-RoBERTa backbone
    and its classification head read (``hidden_size``, ``hidden_dropout_prob``,
    ``initializer_range``, ``classifier_dropout``, ...) has a default, even when
    the config is built from scratch or loaded from a checkpoint whose
    config.json does not spell all of them out.

    Args:
        num_emotion_labels: Number of outputs of the multi-label emotion head.
        **kwargs: Forwarded unchanged to ``XLMRobertaConfig``.
    """

    model_type = "custom_model"

    def __init__(self, num_emotion_labels: int = 18, **kwargs):
        super().__init__(**kwargs)
        self.num_emotion_labels = num_emotion_labels
|
|
|
class CustomModel(XLMRobertaForSequenceClassification):
    """XLM-RoBERTa with an additional multi-label emotion head.

    The inherited ``self.classifier`` head (sized by ``config.num_labels``)
    produces sentiment logits. A new MLP, ``self.emotion_classifier``, produces
    ``config.num_emotion_labels`` emotion logits from the first-token hidden
    state. Only the emotion head contributes to the returned loss.
    """

    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_emotion_labels = config.num_emotion_labels

        self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_emotion_labels),
        )
        # The emotion head is created after the parent constructor has run, so
        # initialise it explicitly. `apply` visits every Linear layer in the head
        # without relying on hard-coded Sequential indices.
        self.emotion_classifier.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise ``module`` using the parent XLM-RoBERTa scheme.

        Hugging Face's weight-initialisation pass routes every submodule through
        this hook. Delegating to the parent keeps Linear initialisation the same
        (normal with std ``config.initializer_range``, zero bias). It also
        initialises Embedding and LayerNorm modules, which a Linear-only override
        would silently skip.
        """
        super()._init_weights(module)

    def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
        """Encode the inputs and run both the emotion and sentiment heads.

        Args:
            input_ids: Token ids passed to ``self.roberta``.
            attention_mask: Attention mask passed to ``self.roberta``.
            sentiment: Accepted for caller compatibility; not used.
            labels: Optional multi-hot emotion targets with the same shape as
                ``emotion_logits`` (presumably ``(batch, num_emotion_labels)``;
                confirm against the data pipeline). They are cast to the
                logits' dtype and device, so integer 0/1 labels are accepted.

        Returns:
            A dict with ``"emotion_logits"`` and ``"sentiment_logits"``. When
            ``labels`` is given, it also contains ``"loss"``, the
            BCE-with-logits loss of the emotion head.

        Raises:
            ValueError: If the encoder output is not 3-dimensional.
        """
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        if sequence_output.dim() != 3:
            raise ValueError(f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}")

        # Emotion head: hidden state at position 0 -> dropout -> MLP.
        cls_hidden_states = self.dropout_emotion(sequence_output[:, 0, :])
        emotion_logits = self.emotion_classifier(cls_hidden_states)

        # Sentiment head: the inherited classification head selects token 0
        # itself (the parent's forward also passes it the full sequence output).
        # It runs without gradients, so this loss never trains it.
        # NOTE(review): squeeze(1) is a no-op unless config.num_labels == 1,
        # in which case the result has shape (batch,). This is kept for
        # compatibility with the previous output shape.
        with torch.no_grad():
            sentiment_logits = self.classifier(sequence_output).squeeze(1)

        if labels is None:
            return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}

        # BCEWithLogitsLoss requires floating-point targets.
        targets = labels.to(device=emotion_logits.device, dtype=emotion_logits.dtype)
        # An all-ones pos_weight is numerically neutral. It is kept as the place
        # to plug in per-class positive weighting.
        class_weights = torch.ones(
            self.num_emotion_labels,
            dtype=emotion_logits.dtype,
            device=emotion_logits.device,
        )
        loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
        loss = loss_fct(emotion_logits, targets)
        return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
|
|