Ar4ikov commited on
Commit
3113179
·
1 Parent(s): d2c3409

Create wav2vec2speechclassification.py

Browse files
Files changed (1) hide show
  1. wav2vec2speechclassification.py +127 -0
wav2vec2speechclassification.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+ import torch
4
+ from transformers.file_utils import ModelOutput
5
+
6
+
7
+ @dataclass
8
+ class SpeechClassifierOutput(ModelOutput):
9
+ loss: Optional[torch.FloatTensor] = None
10
+ logits: torch.FloatTensor = None
11
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
12
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
18
+
19
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
20
+ Wav2Vec2PreTrainedModel,
21
+ Wav2Vec2Model
22
+ )
23
+
24
+
25
+ class Wav2Vec2ClassificationHead(nn.Module):
26
+ """Head for wav2vec classification task."""
27
+
28
+ def __init__(self, config):
29
+ super().__init__()
30
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
31
+ self.dropout = nn.Dropout(config.final_dropout)
32
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
33
+
34
+ def forward(self, features, **kwargs):
35
+ x = features
36
+ x = self.dropout(x)
37
+ x = self.dense(x)
38
+ x = torch.tanh(x)
39
+ x = self.dropout(x)
40
+ x = self.out_proj(x)
41
+ return x
42
+
43
+
44
+ class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
45
+ def __init__(self, config):
46
+ super().__init__(config)
47
+ self.num_labels = config.num_labels
48
+ self.pooling_mode = config.pooling_mode
49
+ self.config = config
50
+
51
+ self.wav2vec2 = Wav2Vec2Model(config)
52
+ self.classifier = Wav2Vec2ClassificationHead(config)
53
+
54
+ self.init_weights()
55
+
56
+ def freeze_feature_extractor(self):
57
+ self.wav2vec2.feature_extractor._freeze_parameters()
58
+
59
+ def merged_strategy(
60
+ self,
61
+ hidden_states,
62
+ mode="mean"
63
+ ):
64
+ if mode == "mean":
65
+ outputs = torch.mean(hidden_states, dim=1)
66
+ elif mode == "sum":
67
+ outputs = torch.sum(hidden_states, dim=1)
68
+ elif mode == "max":
69
+ outputs = torch.max(hidden_states, dim=1)[0]
70
+ else:
71
+ raise Exception(
72
+ "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
73
+
74
+ return outputs
75
+
76
+ def forward(
77
+ self,
78
+ input_values,
79
+ attention_mask=None,
80
+ output_attentions=None,
81
+ output_hidden_states=None,
82
+ return_dict=None,
83
+ labels=None,
84
+ ):
85
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
86
+ outputs = self.wav2vec2(
87
+ input_values,
88
+ attention_mask=attention_mask,
89
+ output_attentions=output_attentions,
90
+ output_hidden_states=output_hidden_states,
91
+ return_dict=return_dict,
92
+ )
93
+ hidden_states = outputs[0]
94
+ hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
95
+ logits = self.classifier(hidden_states)
96
+
97
+ loss = None
98
+ if labels is not None:
99
+ if self.config.problem_type is None:
100
+ if self.num_labels == 1:
101
+ self.config.problem_type = "regression"
102
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
103
+ self.config.problem_type = "single_label_classification"
104
+ else:
105
+ self.config.problem_type = "multi_label_classification"
106
+
107
+ if self.config.problem_type == "regression":
108
+ loss_fct = MSELoss()
109
+ loss = loss_fct(logits.view(-1, self.num_labels), labels)
110
+ elif self.config.problem_type == "single_label_classification":
111
+ loss_fct = CrossEntropyLoss()
112
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
113
+ elif self.config.problem_type == "multi_label_classification":
114
+ loss_fct = BCEWithLogitsLoss()
115
+ loss = loss_fct(logits, labels)
116
+
117
+ if not return_dict:
118
+ output = (logits,) + outputs[2:]
119
+ return ((loss,) + output) if loss is not None else output
120
+
121
+ return SpeechClassifierOutput(
122
+ loss=loss,
123
+ logits=logits,
124
+ hidden_states=outputs.hidden_states,
125
+ attentions=outputs.attentions,
126
+ )
127
+