cahya commited on
Commit
37c396e
·
1 Parent(s): dbebbd9
Files changed (3) hide show
  1. 5gram.bin +3 -0
  2. app.py +28 -3
  3. requirements.txt +3 -1
5gram.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46e982596dbb0c7c225dd9b88ef89c733ba6d718befc3c3b833b1daddc60816a
3
+ size 11939611
app.py CHANGED
@@ -1,11 +1,35 @@
1
  import soundfile as sf
2
  import torch
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
4
  import gradio as gr
5
  import sox
6
  import os
 
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def convert(inputfile, outfile):
10
  sox_tfm = sox.Transformer()
11
  sox_tfm.set_output_format(
@@ -18,6 +42,7 @@ api_token = os.getenv("API_TOKEN")
18
  model_name = "indonesian-nlp/wav2vec2-luganda"
19
  processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=api_token)
20
  model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=api_token)
 
21
 
22
 
23
  def parse_transcription(wav_file):
@@ -25,9 +50,9 @@ def parse_transcription(wav_file):
25
  convert(wav_file.name, filename + "16k.wav")
26
  speech, _ = sf.read(filename + "16k.wav")
27
  input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
28
- logits = model(input_values).logits
29
- predicted_ids = torch.argmax(logits, dim=-1)
30
- transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
31
  return transcription
32
 
33
 
 
1
  import soundfile as sf
2
  import torch
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
4
+ from pyctcdecode import build_ctcdecoder
5
  import gradio as gr
6
  import sox
7
  import os
8
+ from multiprocessing import Pool
9
 
10
 
11
+ class KenLM:
12
+ def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
13
+ self.num_workers = num_workers
14
+ self.beam_width = beam_width
15
+ vocab_dict = tokenizer.get_vocab()
16
+ self.vocabulary = [x[0] for x in sorted(vocab_dict.items(), key=lambda x: x[1], reverse=False)]
17
+ # Workaround for wrong number of vocabularies:
18
+ self.vocabulary = self.vocabulary[:-2]
19
+ self.decoder = build_ctcdecoder(self.vocabulary, model_name)
20
+
21
+ @staticmethod
22
+ def lm_postprocess(text):
23
+ return ' '.join([x if len(x) > 1 else "" for x in text.split()]).strip()
24
+
25
+ def decode(self, logits):
26
+ probs = logits.cpu().numpy()
27
+ # probs = logits.numpy()
28
+ with Pool(self.num_workers) as pool:
29
+ text = self.decoder.decode_batch(pool, probs)
30
+ text = [KenLM.lm_postprocess(x) for x in text]
31
+ return text
32
+
33
  def convert(inputfile, outfile):
34
  sox_tfm = sox.Transformer()
35
  sox_tfm.set_output_format(
 
42
  model_name = "indonesian-nlp/wav2vec2-luganda"
43
  processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=api_token)
44
  model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=api_token)
45
+ kenlm = KenLM(processor.tokenizer, "5gram.bin")
46
 
47
 
48
  def parse_transcription(wav_file):
 
50
  convert(wav_file.name, filename + "16k.wav")
51
  speech, _ = sf.read(filename + "16k.wav")
52
  input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
53
+ with torch.no_grad():
54
+ logits = model(input_values).logits
55
+ transcription = kenlm.decode(logits)[0]
56
  return transcription
57
 
58
 
requirements.txt CHANGED
@@ -3,4 +3,6 @@ soundfile
3
  torch
4
  transformers
5
  sox
6
- sentencepiece
 
 
 
3
  torch
4
  transformers
5
  sox
6
+ sentencepiece
7
+ pyctcdecode==0.3.0
8
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip