dkounadis committed on
Commit 9f61606
1 Parent(s): 6a3daac

do not use clip() for distillation
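
In context, the message refers to how this model's raw arousal/dominance/valence outputs are consumed when it serves as the Wav2Small teacher. A minimal sketch of the idea, with a hypothetical `student` model and MSE as a stand-in for the actual distillation loss; the commented-out `clip()` is what the message advises against:

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, signal):
    '''One hypothetical distillation step on a batch of 16 kHz audio.'''
    with torch.no_grad():
        targets = teacher(signal)      # raw A/D/V teacher outputs
    # targets = targets.clip(0, 1)     # per this commit: do NOT clip the targets
    return F.mse_loss(student(signal), targets)
```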

Files changed (1): README.md +12 -15
README.md CHANGED
@@ -19,7 +19,7 @@ tags:
 # Arousal - Dominance - Valence
 
 Dimensional Speech Emotion Recognition model of simultaneous use of [WavLM](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) / [Wav2Vec2.0](https://hf.rst.im/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim).
-Achieves `0.6760566` valence CCC on [MSP Podcast Test 1](https://paperswithcode.com/sota/speech-emotion-recognition-on-msp-podcast). Used as teacher for [wav2small](https://arxiv.org/abs/2408.13920).
+Achieves `0.6760566` valence CCC on [MSP Podcast Test 1](https://paperswithcode.com/sota/speech-emotion-recognition-on-msp-podcast). Used as teacher for [Wav2Small](https://arxiv.org/abs/2408.13920).
 
 
 
@@ -48,9 +48,8 @@ import torch
 import types
 import torch.nn as nn
 from transformers import AutoModelForAudioClassification
-from transformers.models.wav2vec2.modeling_wav2vec2 import (
-    Wav2Vec2Model,
-    Wav2Vec2PreTrainedModel)
+from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
+                                                            Wav2Vec2PreTrainedModel)
 
 
 signal = torch.from_numpy(
@@ -85,22 +84,19 @@ class Dawn(Wav2Vec2PreTrainedModel):
         self.classifier = ADV(config)
 
     def forward(self, x):
-        '''x: (batch, audio-samples-16KHz)'''
         x -= x.mean(1, keepdim=True)
         variance = (x * x).mean(1, keepdim=True) + 1e-7
-        x = self.wav2vec2(x / variance.sqrt()
-                          ).last_hidden_state
-        return self.classifier(x.mean(1))
+        x = self.wav2vec2(x / variance.sqrt())
+        return self.classifier(x.last_hidden_state.mean(1))
 
 
-def _infer(self, x):
+def _forward(self, x):
     '''x: (batch, audio-samples-16KHz)'''
-    x = (x + self.config.mean) / self.config.std  # plus
+    x = (x + self.config.mean) / self.config.std  # sgn
     x = self.ssl_model(x, attention_mask=None).last_hidden_state
     # pool
     h = self.pool_model.sap_linear(x).tanh()
-    w = torch.matmul(h, self.pool_model.attention)
-    w = w.softmax(1)
+    w = torch.matmul(h, self.pool_model.attention).softmax(1)
     mu = (x * w).sum(1)
     x = torch.cat(
         [
@@ -115,7 +111,7 @@ def _infer(self, x):
 base = AutoModelForAudioClassification.from_pretrained(
     '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
     trust_remote_code=True).to(device).eval()
-base.forward = types.MethodType(_infer, base)
+base.forward = types.MethodType(_forward, base)
 
 # Wav2Vec2
 
@@ -128,6 +124,7 @@ def wav2small(x):
     return .5 * dawn(x) + .5 * base(x)
 
 pred = wav2small(signal.to(device))
-print(f'\nArousal = {pred[:, 0]} Dominance = {pred[:, 1]}',
-      f' Valence = {pred[:, 2]}')
+print(f'Arousal={pred[0, 0]} '
+      f'Dominance={pred[0, 1]} '
+      f'Valence={pred[0, 2]}')
 ```
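
For reference, the patched `_forward` performs attentive statistics pooling over the WavLM hidden states before the A/D/V head. A self-contained sketch of that computation on dummy tensors; the tensor sizes and the standard-deviation half of the final concatenation (the `torch.cat` is truncated by the hunk above) are assumptions:

```python
import torch
import torch.nn as nn

B, T, D = 2, 100, 1024         # batch, frames, hidden size (assumed)
x = torch.randn(B, T, D)       # stand-in for ssl_model(...).last_hidden_state

sap_linear = nn.Linear(D, D)   # stand-in for pool_model.sap_linear
attention = torch.randn(D, 1)  # stand-in for pool_model.attention

h = sap_linear(x).tanh()                     # (B, T, D)
w = torch.matmul(h, attention).softmax(1)    # (B, T, 1) per-frame weights
mu = (x * w).sum(1)                          # attention-weighted mean, (B, D)
sd = ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()  # weighted std (assumed)
pooled = torch.cat([mu, sd], dim=1)          # (B, 2 * D)
```

Concatenating the weighted mean with a weighted standard deviation is the usual attentive-statistics formulation; the exact tail of `_forward` falls outside the hunk shown.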