theodotus commited on
Commit
1b56b73
·
1 Parent(s): 2f1c221

Used Nemo streaming logic

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import librosa
4
  import torch
5
 
6
- from math import floor,ceil
7
  import nemo.collections.asr as nemo_asr
8
 
9
 
@@ -17,11 +17,17 @@ asr_model.encoder.freeze()
17
  asr_model.decoder.freeze()
18
 
19
 
20
- total_buffer = asr_model.cfg["sample_rate"] * 60 // 10
21
- overhead_len = total_buffer // 2
22
- model_stride = 4
 
 
23
 
24
 
 
 
 
 
25
 
26
  def resample(audio):
27
  audio_16k, sr = librosa.load(audio, sr = asr_model.cfg["sample_rate"],
@@ -38,19 +44,13 @@ def model(audio_16k):
38
 
39
 
40
  def decode_predictions(logits_list):
41
- # calc overhead
42
- logits_overhead = logits_list[0].shape[1] * overhead_len / total_buffer / 2
43
- if (logits_overhead * 2 != int(logits_overhead * 2)) and (len(logits_list) != 1):# if first chunk
44
- raise ValueError("Wrong total_buffer")
45
-
46
  # cut overhead
47
  cutted_logits = []
48
  for idx in range(len(logits_list)):
49
- start_cut = 0 if (idx==0) else floor(logits_overhead)
50
- end_cut = 1 if (idx==len(logits_list)-1) else ceil(logits_overhead)
51
- if (logits_overhead == int(logits_overhead)) and (end_cut != 1):
52
- end_cut +=1
53
- logits = logits_list[idx][:, start_cut:-end_cut]
54
  cutted_logits.append(logits)
55
 
56
  # join
 
3
  import librosa
4
  import torch
5
 
6
+ from math import ceil
7
  import nemo.collections.asr as nemo_asr
8
 
9
 
 
17
  asr_model.decoder.freeze()
18
 
19
 
20
+ buffer_len = 8.0
21
+ chunk_len = 4.8
22
+ total_buffer = round(buffer_len * asr_model.cfg.sample_rate)
23
+ overhead_len = round((buffer_len - chunk_len) * asr_model.cfg.sample_rate)
24
+ model_stride = 8
25
 
26
 
27
+ model_stride_in_secs = asr_model.cfg.preprocessor.window_stride * model_stride
28
+ tokens_per_chunk = ceil(chunk_len / model_stride_in_secs)
29
+ mid_delay = ceil((chunk_len + (buffer_len - chunk_len) / 2) / model_stride_in_secs)
30
+
31
 
32
  def resample(audio):
33
  audio_16k, sr = librosa.load(audio, sr = asr_model.cfg["sample_rate"],
 
44
 
45
 
46
  def decode_predictions(logits_list):
47
+ logits_len = logits_list[0].shape[1]
 
 
 
 
48
  # cut overhead
49
  cutted_logits = []
50
  for idx in range(len(logits_list)):
51
+ start_cut = 0 if (idx==0) else logits_len - 1 - mid_delay
52
+ end_cut = -1 if (idx==len(logits_list)-1) else logits_len - 1 - mid_delay + tokens_per_chunk
53
+ logits = logits_list[idx][:, start_cut:end_cut]
 
 
54
  cutted_logits.append(logits)
55
 
56
  # join