from typing import Optional
import torch
import torch.nn as nn
from transformers.activations import get_activation
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)
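
# Index of `hidden_states` in Wav2Vec2Model's tuple output
# (last_hidden_state, extract_features, hidden_states, attentions).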
_HIDDEN_STATES_START_POSITION = 2
class ClassificationHead(nn.Module):
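    """Projection head: dense -> layer norm -> activation -> dropout -> logits."""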
    def __init__(self, config):
        super().__init__()
        print(f"classifier_proj_size: {config.classifier_proj_size}")
        self.dense = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.layer_norm = nn.LayerNorm(config.classifier_proj_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.classifier_proj_size, config.num_labels)
        print(f"Head activation: {config.head_activation}")
        self.activation = get_activation(config.head_activation)

    def forward(self, features, **kwargs):
        x = features
        x = self.dense(x)
        x = self.layer_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class EmotionModel(Wav2Vec2PreTrainedModel):
    """Speech emotion classifier."""

    def __init__(self, config, counts: Optional[dict[int, int]] = None):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ClassificationHead(config)
        num_layers = (
            config.num_hidden_layers + 1
        )  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.init_weights()
        # If class counts are provided, compute inverse-frequency class weights
        if counts is not None:
            print(f"Using class counts for weighting: {counts}")
            counts_list = [counts[i] for i in range(config.num_labels)]
            # NOTE: the weight tensor is created directly on "cuda:0"; adjust the
            # device if training on CPU or a different GPU.
            counts_tensor = torch.tensor(
                counts_list, dtype=torch.float, device="cuda:0"
            )
            total_samples = counts_tensor.sum()
            class_weights = total_samples / (config.num_labels * counts_tensor)
            # Normalize the weights so they sum to num_labels (optional)
            class_weights = class_weights / class_weights.sum() * config.num_labels
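            # Illustration (hypothetical counts): counts={0: 100, 1: 50, 2: 10} with
            # num_labels=3 gives 160 / (3 * [100, 50, 10]) ≈ [0.53, 1.07, 5.33],
            # renormalized to ≈ [0.23, 0.46, 2.31], so the rarest class weighs most.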
            self.class_weights = class_weights
        else:
            self.class_weights = None  # no counts given -> unweighted loss

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            True if self.config.use_weighted_layer_sum else output_hidden_states
        )
        # print(f"output_hidden_states: {output_hidden_states}")
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
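        # Two ways to obtain frame-level features: a learned softmax-weighted sum
        # over all hidden layers (use_weighted_layer_sum), or just the last layer.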
        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]
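        # Pool over the time axis: a plain mean when no attention mask is given,
        # otherwise a masked mean that zeroes padded frames and divides by the
        # true frame count per example.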
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(
                -1, 1
            )
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Apply the class weights to CrossEntropyLoss (works even when
            # class_weights is None)
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def freeze_base_model(self):
        r"""Freeze all parameters of the wav2vec2 base model."""
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def freeze_feature_encoder(self):
        r"""Freeze the convolutional feature encoder."""
        self.wav2vec2.freeze_feature_encoder()
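

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code).
# Assumptions: "facebook/wav2vec2-base" as the backbone checkpoint, 4 emotion
# labels, and the custom `head_activation` field read by ClassificationHead
# being set on the config by hand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor

    config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
    config.num_labels = 4                 # hypothetical number of emotion classes
    config.use_weighted_layer_sum = True  # pool over all transformer layers
    config.head_activation = "gelu"       # custom field consumed by ClassificationHead

    # counts omitted -> class_weights is None and the loss is unweighted
    model = EmotionModel.from_pretrained("facebook/wav2vec2-base", config=config)
    model.freeze_feature_encoder()
    model.eval()

    extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    dummy_audio = torch.zeros(16000)      # one second of silence at 16 kHz
    inputs = extractor(dummy_audio.numpy(), sampling_rate=16000, return_tensors="pt")

    with torch.no_grad():
        out = model(**inputs)
    print(out.logits.shape)               # expected: torch.Size([1, 4])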