# Author: litagin
# Commit: 3e40110 ("init")
# Original file size: 5.01 kB
from typing import Optional
import torch
import torch.nn as nn
from transformers.activations import get_activation
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
# Index into the output of `self.wav2vec2(...)` at which the per-layer hidden
# states are found; EmotionModel.forward stacks `outputs[...]` at this index to
# build the weighted layer sum. Presumably the entries before it are the last
# hidden state and the extracted features -- verify against the installed
# transformers version if the model output layout changes.
_HIDDEN_STATES_START_POSITION = 2
class ClassificationHead(nn.Module):
    """Projection + classification head applied to pooled encoder features.

    Stages, in order: Linear(hidden_size -> classifier_proj_size), LayerNorm,
    the activation named by ``config.head_activation``,
    Dropout(``config.final_dropout``), Linear(classifier_proj_size -> num_labels).
    """

    def __init__(self, config):
        super().__init__()
        print(f"classifier_proj_size: {config.classifier_proj_size}")
        proj_size = config.classifier_proj_size
        # Submodules are created in the same order as before so random
        # initialisation draws from the RNG identically.
        self.dense = nn.Linear(config.hidden_size, proj_size)
        self.layer_norm = nn.LayerNorm(proj_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(proj_size, config.num_labels)
        print(f"Head activation: {config.head_activation}")
        self.activation = get_activation(config.head_activation)

    def forward(self, features, **kwargs):
        """Map ``features`` (last dim = hidden_size) to logits (last dim = num_labels).

        Extra keyword arguments are accepted and ignored.
        """
        pipeline = (
            self.dense,
            self.layer_norm,
            self.activation,
            self.dropout,
            self.out_proj,
        )
        out = features
        for stage in pipeline:
            out = stage(out)
        return out
class EmotionModel(Wav2Vec2PreTrainedModel):
    """Speech emotion classifier on top of a wav2vec 2.0 encoder.

    Frame-level encoder states (optionally a softmax-weighted sum over every
    layer's hidden states) are mean-pooled over the non-padded frames and fed
    to ``ClassificationHead``.

    Args:
        config: Wav2Vec2 config. Besides the encoder fields, this model reads
            ``use_weighted_layer_sum``, ``num_hidden_layers``, ``num_labels``
            and ``use_return_dict`` (plus the fields used by the head).
        counts: Optional mapping from class index (``0 .. num_labels - 1``) to
            its number of training samples. When given, the training loss is
            weighted inversely to class frequency.

    Raises:
        KeyError: If ``counts`` lacks an entry for some class index.
        ValueError: If any class count is not positive (its weight would be
            infinite and the loss NaN).
    """

    def __init__(self, config, counts: Optional[dict[int, int]] = None):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ClassificationHead(config)
        # One weight per transformer layer plus one for the input embeddings.
        num_layers = config.num_hidden_layers + 1
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.init_weights()

        # Per-class loss weights (float32 tensor of shape (num_labels,)) or None.
        # Kept as a plain attribute (not a buffer) so the state_dict is unchanged;
        # it is built on CPU and moved to the logits' device/dtype in `forward`,
        # so the model works on CPU and on any GPU, not only "cuda:0".
        if counts is not None:
            print(f"Using class weights: {counts}")
            counts_list = [counts[i] for i in range(config.num_labels)]
            counts_tensor = torch.tensor(counts_list, dtype=torch.float)
            if bool((counts_tensor <= 0).any()):
                raise ValueError(
                    "counts must be positive for every class "
                    f"0..{config.num_labels - 1}, got {counts}"
                )
            total_samples = counts_tensor.sum()
            # Inverse-frequency weighting.
            class_weights = total_samples / (config.num_labels * counts_tensor)
            # Normalise so the weights sum to num_labels (optional rescaling).
            class_weights = class_weights / class_weights.sum() * config.num_labels
            self.class_weights = class_weights
        else:
            # No counts: unweighted cross-entropy.
            self.class_weights = None

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """Classify a batch of raw waveforms.

        Args:
            input_values: Input passed straight to the wav2vec2 encoder.
            attention_mask: Optional mask over ``input_values``; when given,
                only non-padded encoder frames contribute to the mean pooling.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder (forced to True when
                ``config.use_weighted_layer_sum`` is set).
            return_dict: Return a ``SequenceClassifierOutput`` when true, a
                plain tuple otherwise; defaults to ``config.use_return_dict``.
            labels: Optional class indices; when given, a (class-weighted)
                cross-entropy loss is computed.

        Returns:
            ``SequenceClassifierOutput`` if ``return_dict``; otherwise the tuple
            ``(loss, logits, *encoder_extras)`` when ``labels`` is given, or
            ``(logits, *encoder_extras)`` without labels, where
            ``encoder_extras`` are the encoder outputs from
            ``_HIDDEN_STATES_START_POSITION`` on.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # The weighted layer sum needs every layer's hidden states.
        output_hidden_states = (
            True if self.config.use_weighted_layer_sum else output_hidden_states
        )
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            # Zero padded frames out-of-place: writing into `hidden_states`
            # in place would also modify the encoder output handed back to the
            # caller and can break autograd.
            masked = hidden_states.masked_fill(~padding_mask.unsqueeze(-1), 0.0)
            pooled_output = masked.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            weight = None
            if self.class_weights is not None:
                # Follow the logits so CPU / multi-GPU / half precision all work.
                weight = self.class_weights.to(
                    device=logits.device, dtype=logits.dtype
                )
            loss_fct = nn.CrossEntropyLoss(weight=weight)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            # Encoder returned a plain tuple; attribute access would fail here.
            output = (logits,) + tuple(outputs[_HIDDEN_STATES_START_POSITION:])
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def freeze_base_model(self):
        r"""Freeze base model (disable gradients for all wav2vec2 parameters)."""
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def freeze_feature_encoder(self):
        r"""Freeze feature extractor (delegates to the wav2vec2 model)."""
        self.wav2vec2.freeze_feature_encoder()