polodealvarado
commited on
Commit
•
e68ed8f
1
Parent(s):
0a6eab5
Update README.md
Browse files
README.md
CHANGED
@@ -36,7 +36,47 @@ It achieves the following results on the evaluation set:
|
|
36 |
- Loss : 0.1900
|
37 |
- Wer : 0.146
|
38 |
|
39 |
-
## Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
## Model description
|
42 |
|
|
|
36 |
- Loss : 0.1900
|
37 |
- Wer : 0.146
|
38 |
|
39 |
+
## Usage with 5-gram.
|
40 |
+
|
41 |
+
The model can be used with n-gram included in the processor as follows.
|
42 |
+
|
43 |
+
```python
|
44 |
+
import re
|
45 |
+
from transformers import AutoModelForCTC,Wav2Vec2ProcessorWithLM
|
46 |
+
import torch
|
47 |
+
|
48 |
+
processor = Wav2Vec2ProcessorWithLM.from_pretrained("polodealvarado/xls-r-300m-es")
|
49 |
+
model = AutoModelForCTC.from_pretrained("polodealvarado/xls-r-300m-es")
|
50 |
+
|
51 |
+
# Cleaning characters
|
52 |
+
def remove_extra_chars(batch):
|
53 |
+
chars_to_ignore_regex = '[^a-záéíóúñ ]'
|
54 |
+
text = batch["translation"][target_lang]
|
55 |
+
batch["text"] = re.sub(chars_to_ignore_regex, "", text.lower())
|
56 |
+
return batch
|
57 |
+
|
58 |
+
# Preparing dataset
|
59 |
+
def prepare_dataset(batch):
|
60 |
+
audio = batch["audio"]
|
61 |
+
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
|
62 |
+
with processor.as_target_processor():
|
63 |
+
batch["labels"] = processor(batch["sentence"]).input_ids
|
64 |
+
return batch
|
65 |
+
|
66 |
+
|
67 |
+
common_voice_test = load_dataset("mozilla-foundation/common_voice_8_0", "es", split="test",use_auth_token=True)
|
68 |
+
common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
|
69 |
+
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
|
70 |
+
common_voice_test = common_voice_test.map(remove_extra_chars, remove_columns=dataset.column_names)
|
71 |
+
common_voice_test = common_voice_test.map(prepare_dataset)
|
72 |
+
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
logits = model(**inputs).logits
|
76 |
+
|
77 |
+
pred_ids = torch.argmax(logits, dim=-1)
|
78 |
+
text = processor.batch_decode(logits.numpy()).text
|
79 |
+
```
|
80 |
|
81 |
## Model description
|
82 |
|