polodealvarado commited on
Commit
e68ed8f
1 Parent(s): 0a6eab5

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +41 -1
README.md CHANGED
@@ -36,7 +36,47 @@ It achieves the following results on the evaluation set:
36
  - Loss : 0.1900
37
  - Wer : 0.146
38
 
39
- ## Usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  ## Model description
42
 
 
36
  - Loss : 0.1900
37
  - Wer : 0.146
38
 
39
+ ## Usage with 5-gram.
40
+
41
+ The model can be used with n-gram included in the processor as follows.
42
+
43
+ ```python
44
+ import re
45
+ from transformers import AutoModelForCTC,Wav2Vec2ProcessorWithLM
46
+ import torch
47
+
48
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("polodealvarado/xls-r-300m-es")
49
+ model = AutoModelForCTC.from_pretrained("polodealvarado/xls-r-300m-es")
50
+
51
+ # Cleaning characters
52
+ def remove_extra_chars(batch):
53
+ chars_to_ignore_regex = '[^a-záéíóúñ ]'
54
+ text = batch["translation"][target_lang]
55
+ batch["text"] = re.sub(chars_to_ignore_regex, "", text.lower())
56
+ return batch
57
+
58
+ # Preparing dataset
59
+ def prepare_dataset(batch):
60
+ audio = batch["audio"]
61
+ batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
62
+ with processor.as_target_processor():
63
+ batch["labels"] = processor(batch["sentence"]).input_ids
64
+ return batch
65
+
66
+
67
+ common_voice_test = load_dataset("mozilla-foundation/common_voice_8_0", "es", split="test",use_auth_token=True)
68
+ common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
69
+ common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
70
+ common_voice_test = common_voice_test.map(remove_extra_chars, remove_columns=dataset.column_names)
71
+ common_voice_test = common_voice_test.map(prepare_dataset)
72
+
73
+
74
+ with torch.no_grad():
75
+ logits = model(**inputs).logits
76
+
77
+ pred_ids = torch.argmax(logits, dim=-1)
78
+ text = processor.batch_decode(logits.numpy()).text
79
+ ```
80
 
81
  ## Model description
82