amiriparian
commited on
Commit
•
d995ff8
1
Parent(s):
1a7cb05
Update README.md
Browse files
README.md
CHANGED
@@ -55,15 +55,14 @@ Further details are available in the corresponding [**paper**](https://arxiv.org
|
|
55 |
```python
|
56 |
import torch
|
57 |
import torch.nn as nn
|
58 |
-
from transformers import
|
59 |
|
60 |
|
61 |
|
62 |
# CONFIG and MODEL SETUP
|
63 |
model_name = 'amiriparian/HuBERT-EmoSet'
|
64 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
|
65 |
-
|
66 |
-
model.classifier = nn.Linear(in_features=256,out_features=6)
|
67 |
|
68 |
sampling_rate=16000
|
69 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
55 |
```python
|
56 |
import torch
|
57 |
import torch.nn as nn
|
58 |
+
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
|
59 |
|
60 |
|
61 |
|
62 |
# CONFIG and MODEL SETUP
|
63 |
model_name = 'amiriparian/HuBERT-EmoSet'
|
64 |
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-base-ls960")
|
65 |
+
AutoModelForAudioClassification.from_pretrained(model_name, trust_remote_code=True)
|
|
|
66 |
|
67 |
sampling_rate=16000
|
68 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|