Spaces:
Runtime error
Runtime error
Pavankalyan
committed on
Commit
·
a93df0f
1
Parent(s):
9f02ed6
Upload model.py with huggingface_hub
Browse files
model.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2ForCTC, HubertModel, HubertForCTC
|
4 |
+
#import whisper
|
5 |
+
|
6 |
+
class WhisperModel(nn.Module):
    """Intent classifier built on the audio encoder of a Whisper checkpoint.

    The encoder output is averaged over dim 1 and fed to a single linear
    layer that produces ``n_class`` logits.

    Args:
        model_type: Checkpoint name passed to ``whisper.load_model``.
        n_class: Number of output classes of the linear head.

    Raises:
        ImportError: If the ``whisper`` package is not installed.
    """

    def __init__(self, model_type="small.en", n_class=14):
        super().__init__()

        # Imported lazily on purpose: the module-level ``import whisper`` is
        # commented out, so referencing ``whisper`` here used to fail with a
        # NameError. A local import keeps the dependency optional for users
        # who only need the other models in this file.
        import whisper

        model = whisper.load_model(model_type)
        self.encoder = model.encoder

        # The whole encoder is fine-tuned together with the head.
        for param in self.encoder.parameters():
            param.requires_grad = True

        # Read the encoder width from the loaded checkpoint rather than
        # hard-coding 768, which only matches the default ``small.en``.
        feature_dim = model.dims.n_audio_state

        self.intent_classifier = nn.Sequential(
            nn.Linear(feature_dim, n_class)
        )

    def forward(self, x):
        """Return intent logits for ``x``.

        ``x`` is passed unchanged to the Whisper encoder.
        NOTE(review): presumably a batch of log-mel spectrograms, with the
        encoder returning (batch, frames, feature_dim) -- confirm with caller.
        """
        x = self.encoder(x)
        # Average over dim 1 to get one vector per example.
        x = torch.mean(x, dim=1)
        intent = self.intent_classifier(x)
        return intent
|
27 |
+
|
28 |
+
class Wav2VecModel(nn.Module):
    """Intent classifier on top of ``facebook/wav2vec2-large-960h``.

    All encoder weights are frozen first, then the ``encoder`` sub-module of
    the wav2vec 2.0 model is unfrozen again, so every other part of the
    pretrained model stays fixed. The last hidden state is averaged over
    dim 1 and passed through a single linear layer.

    Args:
        n_class: Number of output classes of the linear head (default 14,
            the value that used to be hard-coded).
    """

    def __init__(self, n_class=14):
        super().__init__()

        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h")

        # Freeze everything ...
        for param in self.encoder.parameters():
            param.requires_grad = False

        # ... then re-enable training for the ``encoder`` sub-module only.
        for param in self.encoder.encoder.parameters():
            param.requires_grad = True

        # Take the head's input width from the checkpoint config (1024 for
        # this checkpoint) instead of a literal.
        self.intent_classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, n_class),
        )

    def forward(self, x):
        """Return intent logits for the raw audio ``x`` (sampled at 16 kHz)."""
        # Follow the device the encoder lives on instead of assuming "cuda",
        # which crashed on CPU-only hosts.
        device = next(self.encoder.parameters()).device

        x = self.processor(x, sampling_rate=16000, return_tensors="pt")["input_values"]
        # NOTE(review): presumably the processor adds an extra leading dim
        # around an already batched tensor, which squeeze(0) removes -- confirm.
        x = x.squeeze(0).to(device)
        x = self.encoder(x).last_hidden_state
        # Average over dim 1 to get one vector per example.
        x = torch.mean(x, dim=1)
        logits = self.intent_classifier(x)
        return logits
|
51 |
+
|
52 |
+
class HubertSSLModel(nn.Module):
    """Intent classifier on top of ``facebook/hubert-large-ll60k``.

    Audio is prepared with the ``facebook/wav2vec2-base-960h`` processor.
    NOTE(review): presumably because the HuBERT checkpoint ships without its
    own processor -- confirm.

    All encoder weights are frozen first, then the ``encoder`` sub-module of
    the HuBERT model is unfrozen again. The last hidden state is averaged
    over dim 1 and passed through a single linear layer.

    Args:
        n_class: Number of output classes of the linear head (default 14,
            the value that used to be hard-coded).
    """

    def __init__(self, n_class=14):
        super().__init__()

        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.encoder = HubertModel.from_pretrained("facebook/hubert-large-ll60k")

        # Freeze everything ...
        for param in self.encoder.parameters():
            param.requires_grad = False

        # ... then re-enable training for the ``encoder`` sub-module only.
        for param in self.encoder.encoder.parameters():
            param.requires_grad = True

        # Take the head's input width from the checkpoint config (1024 for
        # this checkpoint) instead of a literal.
        self.intent_classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, n_class),
        )

    def forward(self, x):
        """Return intent logits for the raw audio ``x`` (sampled at 16 kHz)."""
        # Follow the device the encoder lives on instead of assuming "cuda",
        # which crashed on CPU-only hosts.
        device = next(self.encoder.parameters()).device

        x = self.processor(x, sampling_rate=16000, return_tensors="pt")["input_values"]
        # NOTE(review): presumably the processor adds an extra leading dim
        # around an already batched tensor, which squeeze(0) removes -- confirm.
        x = x.squeeze(0).to(device)
        x = self.encoder(x).last_hidden_state
        # Average over dim 1 to get one vector per example.
        x = torch.mean(x, dim=1)
        logits = self.intent_classifier(x)
        return logits
|