Update README.md

## Training and evaluation data

The model was trained on version 7 of the Luganda subset of the Mozilla Common Voice dataset. We used the train and validation splits for training and the test split for validation.
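
A minimal sketch of how these splits might be loaded and combined; the Hub dataset ID `mozilla-foundation/common_voice_7_0` is assumed from the stated version number and is not taken from the original training script:

```python
from datasets import concatenate_datasets, load_dataset

cv = "mozilla-foundation/common_voice_7_0"  # assumed Hub ID for version 7; "lg" is the Luganda config
train = load_dataset(cv, "lg", split="train")
validation = load_dataset(cv, "lg", split="validation")
test = load_dataset(cv, "lg", split="test")

# Train on train+validation combined; hold out the test split for validation.
train_data = concatenate_datasets([train, validation])
eval_data = test
```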

## Training procedure

We trained the model on a 32 GB V100 GPU for 10 epochs with a learning rate of 5e-05, using the AdamW optimizer.
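
A sketch of how this setup might be expressed with Hugging Face `TrainingArguments`; only the learning rate, epoch count, and (default AdamW) optimizer come from this card, and every other value is an illustrative placeholder:

```python
from transformers import TrainingArguments

# Hypothetical configuration: only learning_rate, num_train_epochs, and the
# default AdamW optimizer are stated in this card.
training_args = TrainingArguments(
    output_dir="w2v-bert-luganda",    # placeholder
    learning_rate=5e-5,               # from the card
    num_train_epochs=10,              # from the card
    per_device_train_batch_size=8,    # placeholder
    fp16=True,                        # assumed; common on a V100
    evaluation_strategy="epoch",      # placeholder
)
```
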
### Training hyperparameters

The following hyperparameters were used during training:

- learning_rate: 5e-05
- optimizer: AdamW
- num_epochs: 10

### Framework versions

- Pytorch 2.2.1+cu121
- Datasets 2.17.0
- Tokenizers 0.15.2

### Usage

```python
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# The card states version 7 of Common Voice was used, so load the 7.0 release
# (the plain "common_voice" Hub dataset corresponds to an older release).
test_dataset = load_dataset("mozilla-foundation/common_voice_7_0", "lg", split="test[:10]")

# NOTE: the checkpoint name suggests a W2v-BERT 2.0 backbone; if the Wav2Vec2
# classes below fail to load it, try Wav2Vec2BertForCTC and Wav2Vec2BertProcessor
# (which expect `input_features` rather than `input_values`).
processor = Wav2Vec2Processor.from_pretrained("dmusingu/w2v-bert-2.0-luganda-CV-train-validation-7.0")
model = Wav2Vec2ForCTC.from_pretrained("dmusingu/w2v-bert-2.0-luganda-CV-train-validation-7.0")

# Common Voice audio is 48 kHz; the model expects 16 kHz input.
resampler = torchaudio.transforms.Resample(48_000, 16_000)

# Read each audio file and resample it to a 16 kHz numpy array.
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

# Greedy CTC decoding: pick the most likely token at every frame.
predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
```
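
To turn the printed transcripts into a score, word error rate can be computed over the same examples; a sketch using the `evaluate` library, not part of the original card:

```python
import evaluate

# Compare the greedy transcripts against the reference sentences above.
wer_metric = evaluate.load("wer")
predictions = processor.batch_decode(predicted_ids)
references = test_dataset["sentence"][:2]

# In practice, normalize case and punctuation before scoring.
print(f"WER: {wer_metric.compute(predictions=predictions, references=references):.2%}")
```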