from typing import Optional

import torch
import torch.nn as nn
from transformers.activations import get_activation
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

_HIDDEN_STATES_START_POSITION = 2


class ClassificationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        print(f"classifier_proj_size: {config.classifier_proj_size}")
        self.dense = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.layer_norm = nn.LayerNorm(config.classifier_proj_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.classifier_proj_size, config.num_labels)
        print(f"Head activation: {config.head_activation}")
        self.activation = get_activation(config.head_activation)

    def forward(self, features, **kwargs):
        x = features
        x = self.dense(x)
        x = self.layer_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class EmotionModel(Wav2Vec2PreTrainedModel):
    """Speech emotion classifier."""

    def __init__(self, config, counts: Optional[dict[int, int]] = None):
        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ClassificationHead(config)
        num_layers = (
            config.num_hidden_layers + 1
        )  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.init_weights()

        # If class counts are provided, compute inverse-frequency class weights.
        if counts is not None:
            print(f"Using class weights: {counts}")
            counts_list = [counts[i] for i in range(config.num_labels)]
            # NOTE: the weight tensor is placed on "cuda:0" to match the training device.
            counts_tensor = torch.tensor(
                counts_list, dtype=torch.float, device="cuda:0"
            )
            total_samples = counts_tensor.sum()
            class_weights = total_samples / (config.num_labels * counts_tensor)
            # Normalize so the weights sum to num_labels (optional).
            class_weights = class_weights / class_weights.sum() * config.num_labels
            self.class_weights = class_weights
        else:
            self.class_weights = None  # no counts given, so the loss stays unweighted

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            True if self.config.use_weighted_layer_sum else output_hidden_states
        )
        # print(f"output_hidden_states: {output_hidden_states}")

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

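        # Optionally pool hidden states from all transformer layers (plus the
        # input embeddings) with learned, softmax-normalised weights; otherwise
        # use the last layer's output.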
        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

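        # Mean-pool over the time axis; with an attention mask, padded frames
        # are zeroed out and excluded from the average.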
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(
                -1, 1
            )

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Apply class weights in CrossEntropyLoss (also works when class_weights is None).
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def freeze_base_model(self):
        r"""Freeze base model."""
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def freeze_feature_encoder(self):
        r"""Freeze feature extractor."""
        self.wav2vec2.freeze_feature_encoder()
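

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the training pipeline).
# It builds a randomly initialised config; in practice the model would be
# loaded from a fine-tuned checkpoint via EmotionModel.from_pretrained(...).
# `head_activation` is a custom attribute consumed by ClassificationHead, so
# it must be supplied on the config explicitly; the label count of 4 is an
# assumed example.
if __name__ == "__main__":
    from transformers import Wav2Vec2Config

    config = Wav2Vec2Config(
        num_labels=4,                # e.g. neutral / happy / sad / angry
        classifier_proj_size=256,
        final_dropout=0.1,
        use_weighted_layer_sum=True,
        head_activation="gelu",      # read by ClassificationHead
    )
    # counts=None -> self.class_weights stays None and the loss is unweighted,
    # so no GPU is needed here (the counts path places weights on "cuda:0").
    model = EmotionModel(config)
    model.eval()

    # One second of dummy 16 kHz mono audio, batch size 1.
    waveform = torch.randn(1, 16000)
    with torch.no_grad():
        output = model(input_values=waveform)
    print(output.logits.shape)       # torch.Size([1, 4])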