do not use clip() for distillation
README.md CHANGED
````diff
@@ -19,7 +19,7 @@ tags:
 # Arousal - Dominance - Valence
 
 Dimensional Speech Emotion Recognition model of simultaneous use of [WavLM](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) / [Wav2Vec2.0](https://hf.rst.im/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim).
-Achieves `0.6760566` valence CCC on [MSP Podcast Test 1](https://paperswithcode.com/sota/speech-emotion-recognition-on-msp-podcast). Used as teacher for [
+Achieves `0.6760566` valence CCC on [MSP Podcast Test 1](https://paperswithcode.com/sota/speech-emotion-recognition-on-msp-podcast). Used as teacher for [Wav2Small](https://arxiv.org/abs/2408.13920).
 
 
 
````
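For context on the score above: CCC (Concordance Correlation Coefficient) is the standard metric for dimensional speech emotion recognition. A minimal NumPy sketch of the formula, with illustrative variable names (not code from this repo):

```python
import numpy as np

def ccc(pred, gold):
    '''Concordance Correlation Coefficient over per-utterance scores.'''
    mu_p, mu_g = pred.mean(), gold.mean()
    cov = ((pred - mu_p) * (gold - mu_g)).mean()
    return 2 * cov / (pred.var() + gold.var() + (mu_p - mu_g) ** 2)
```

Unlike plain Pearson correlation, CCC also penalises scale and mean shifts, so a teacher has to match the label distribution, not just its ordering.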
````diff
@@ -48,9 +48,8 @@ import torch
 import types
 import torch.nn as nn
 from transformers import AutoModelForAudioClassification
-from transformers.models.wav2vec2.modeling_wav2vec2 import (
-    Wav2Vec2Model,
-    Wav2Vec2PreTrainedModel)
+from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
+                                                            Wav2Vec2PreTrainedModel)
 
 
 signal = torch.from_numpy(
````
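The rewritten import pulls in two distinct things: `Wav2Vec2PreTrainedModel` supplies `from_pretrained()` and weight initialisation for a custom class, while `Wav2Vec2Model` is the backbone that class wraps. A sketch of the `Dawn` skeleton these imports support, assuming the README's `ADV` head (its definition falls outside the hunks shown here):

```python
class Dawn(Wav2Vec2PreTrainedModel):
    '''Wav2Vec2 backbone with an arousal/dominance/valence regression head.'''
    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)  # pretrained speech encoder
        self.classifier = ADV(config)          # A/D/V head from the README
```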
````diff
@@ -85,22 +84,19 @@ class Dawn(Wav2Vec2PreTrainedModel):
         self.classifier = ADV(config)
 
     def forward(self, x):
-        '''x: (batch, audio-samples-16KHz)'''
         x -= x.mean(1, keepdim=True)
         variance = (x * x).mean(1, keepdim=True) + 1e-7
-        x = self.wav2vec2(x / variance.sqrt()
-                          ).last_hidden_state
-        return self.classifier(x.mean(1))
+        x = self.wav2vec2(x / variance.sqrt())
+        return self.classifier(x.last_hidden_state.mean(1))
 
 
-def _infer(self, x):
+def _forward(self, x):
     '''x: (batch, audio-samples-16KHz)'''
-    x = (x + self.config.mean) / self.config.std  #
+    x = (x + self.config.mean) / self.config.std  # sgn
     x = self.ssl_model(x, attention_mask=None).last_hidden_state
     # pool
     h = self.pool_model.sap_linear(x).tanh()
-    w = torch.matmul(h, self.pool_model.attention)
-    w = w.softmax(1)
+    w = torch.matmul(h, self.pool_model.attention).softmax(1)
     mu = (x * w).sum(1)
     x = torch.cat(
         [
````
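The patched `_forward` replaces the WavLM checkpoint's stock head with explicit self-attentive pooling: a linear layer plus `tanh` projects each frame, a learned attention vector scores it, and the softmax-normalised scores form a weighted mean over time (the `torch.cat` at the end of the hunk then stacks this with further statistics). A self-contained sketch of that pooling step, with a hypothetical hidden size:

```python
import torch
import torch.nn as nn

class SelfAttentivePool(nn.Module):
    '''Attention-weighted mean over time frames, as in pool_model above.'''
    def __init__(self, hidden=1024):
        super().__init__()
        self.sap_linear = nn.Linear(hidden, hidden)
        self.attention = nn.Parameter(torch.randn(hidden, 1))

    def forward(self, x):                               # x: (batch, frames, hidden)
        h = self.sap_linear(x).tanh()                   # frame-wise projection
        w = torch.matmul(h, self.attention).softmax(1)  # (batch, frames, 1) weights
        return (x * w).sum(1)                           # (batch, hidden) pooled mean

pool = SelfAttentivePool()
print(pool(torch.randn(2, 99, 1024)).shape)             # torch.Size([2, 1024])
```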
````diff
@@ -115,7 +111,7 @@ def _infer(self, x):
 base = AutoModelForAudioClassification.from_pretrained(
     '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
     trust_remote_code=True).to(device).eval()
-base.forward = types.MethodType(_infer, base)
+base.forward = types.MethodType(_forward, base)
 
 # Wav2Vec2
 
````
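`types.MethodType` is what makes the patch take effect: it binds the free function `_forward` to this one model instance, and because PyTorch's `Module.__call__` looks up `self.forward`, the instance attribute shadows the class method on the next `base(x)` call. A toy illustration of the mechanism:

```python
import types

class Toy:
    scale = 10

def doubled(self, x):          # plain function; 'self' is bound below
    return self.scale * x

t = Toy()
t.forward = types.MethodType(doubled, t)  # patch this instance only
print(t.forward(3))            # 30; other Toy instances are unaffected
```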
````diff
@@ -128,6 +124,7 @@ def wav2small(x):
     return .5 * dawn(x) + .5 * base(x)
 
 pred = wav2small(signal.to(device))
-print(f'
-      f'
+print(f'Arousal={pred[0, 0]} '
+      f'Dominance={pred[0, 1]} ',
+      f'Valence={pred[0, 2]}')
 ```
````
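The commit title suggests the rationale, though the removed lines are cut off in this view: `clip()` is flat outside its range, so a clipped value carries no gradient (if student predictions are clipped inside a distillation loss) and no ranking information near the boundaries (if teacher targets are clipped). A small illustration of the gradient side, with made-up values:

```python
import torch

x = torch.tensor([1.3], requires_grad=True)
x.clip(0, 1).backward()  # saturates at 1.0
print(x.grad)            # tensor([0.]) -- no gradient flows through the clip
```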