dkounadis committed
Commit e2d1e1b
Parent(s): 0b8d088

Files changed (1)
  1. README.md +34 -20
README.md CHANGED
````diff
@@ -44,8 +44,10 @@ Florian Eyben, Felix Burkhardt, Björn Schuller.
 # Usage
 ```python
 from transformers import AutoModelForAudioClassification
-from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
-    Wav2Vec2PreTrainedModel)
+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2Model,
+    Wav2Vec2PreTrainedModel
+)
 import torch
 import types
 import torch.nn as nn
@@ -53,32 +55,45 @@ import torch.nn as nn
 signal = torch.rand((1, 16000)) # audio signal 16 KHz
 device = 'cpu'
 
-class ADV(nn.Module):
+class RegressionHead(nn.Module):
+    r"""A/D/V"""
+
+    def __init__(self, config):
 
-    def __init__(self):
         super().__init__()
-        self.dense = nn.Linear(1024, 1024)
-        self.out_proj = nn.Linear(1024, 3)
+
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
 
     def forward(self, x):
-        x = self.dense(x).tanh()
+
+        x = self.dense(x)
+        x = torch.tanh(x)
+
         return self.out_proj(x)
 
+
 class Dawn(Wav2Vec2PreTrainedModel):
+    r"""https://arxiv.org/abs/2203.07378"""
 
     def __init__(self, config):
+
         super().__init__(config)
+
         self.wav2vec2 = Wav2Vec2Model(config)
-        self.classifier = ADV()
+        self.classifier = RegressionHead(config)
 
     def forward(self, x):
+        '''x: (batch, audio-samples-16KHz)'''
         x = x - x.mean(1, keepdim=True)
         variance = (x * x).mean(1, keepdim=True) + 1e-7
-        x = self.wav2vec2(x / variance.sqrt())[0]
-        return self.classifier(x.mean(1)).clip(0, 1)
+        out = self.wav2vec2(x / variance.sqrt())
+        return self.classifier(out[0].mean(1)).clip(0, 1)
+
 
-def _fast(self, x):
-    x = (x + self.config.mean) / self.config.std # sign
+def _infer(self, x):
+    '''x: (batch, audio-samples-16KHz)'''
+    x = (x + self.config.mean) / self.config.std # plus
     x = self.ssl_model(x, attention_mask=None).last_hidden_state
     # pool
     h = self.pool_model.sap_linear(x).tanh()
@@ -96,24 +111,23 @@ def _fast(self, x):
 # WavLM
 
 base = AutoModelForAudioClassification.from_pretrained(
-    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
-    trust_remote_code=True).to(device).eval()
-base.forward = types.MethodType(_fast, base)
+    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
+    trust_remote_code=True).to(device).eval()
+base.forward = types.MethodType(_infer, base)
 
-# Wav2Vec2.0
+# Wav2Vec2
 
 dawn = Dawn.from_pretrained(
-    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
+    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
 ).to(device).eval()
 
 
 def wav2small(x):
-    '''x: (batch, audio-samples-16KHz)'''
     return .5 * dawn(x) + .5 * base(x)
 
 
 with torch.no_grad():
     pred = wav2small(signal.to(device))
-    print(f'arousal={pred[0, 0]} dominance={pred[0, 1]}',
-          f'valence={pred[0, 2]}')
+    print(f'\nArousal = {pred[0, 0]} Dominance = {pred[0, 1]}',
+          f' Valence = {pred[0, 2]}')
 ```
````
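
The updated snippet still drives `wav2small` with a random tensor. A minimal sketch of feeding it a real recording instead, assuming `torchaudio` is installed; `speech.wav` is a placeholder path, and `wav2small` / `device` are the objects defined in the README code above:

```python
import torch
import torchaudio

# Load an audio file and downmix to mono: (1, samples).
waveform, sr = torchaudio.load('speech.wav')  # placeholder path
waveform = waveform.mean(0, keepdim=True)

# Both backbones expect 16 kHz input, so resample if needed.
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)

with torch.no_grad():
    pred = wav2small(waveform.to(device))  # wav2small/device from the snippet above

print(f'Arousal = {pred[0, 0]} Dominance = {pred[0, 1]} Valence = {pred[0, 2]}')
```

The output layout follows the README's print statement: index 0 is arousal, 1 dominance, 2 valence.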