smajumdar committed
Commit 14db2b1 · 1 Parent(s): 0dfaf95

Finalize HF demo

Signed-off-by: smajumdar <[email protected]>

Files changed (1):
  1. app.py (+149 -22)
app.py CHANGED
@@ -4,23 +4,37 @@ import uuid
 import tempfile
 import subprocess
 import re
+import time
 
 import gradio as gr
 import pytube as pt
 
 import nemo.collections.asr as nemo_asr
+import torch
+
 import speech_to_text_buffered_infer_ctc as buffered_ctc
 import speech_to_text_buffered_infer_rnnt as buffered_rnnt
+from nemo.utils import logging
 
 # Set NeMo cache dir as /tmp
 from nemo import constants
-os.environ[constants.NEMO_ENV_CACHE_DIR] = "/tmp/nemo"
 
+os.environ[constants.NEMO_ENV_CACHE_DIR] = "/tmp/nemo/"
+
+
+SAMPLE_RATE = 16000  # Default sample rate for ASR
+BUFFERED_INFERENCE_DURATION_THRESHOLD = 60.0  # 60 second and above will require chunked inference.
 
-SAMPLE_RATE = 16000
 TITLE = "NeMo ASR Inference on Hugging Face"
 DESCRIPTION = "Demo of all languages supported by NeMo ASR"
 DEFAULT_EN_MODEL = "nvidia/stt_en_conformer_transducer_xlarge"
+DEFAULT_BUFFERED_EN_MODEL = "nvidia/stt_en_conformer_transducer_large"
+
+# Pre-download and cache the model in disk space
+logging.setLevel(logging.ERROR)
+tmp_model = nemo_asr.models.ASRModel.from_pretrained(DEFAULT_BUFFERED_EN_MODEL, map_location='cpu')
+del tmp_model
+logging.setLevel(logging.INFO)
 
 MARKDOWN = f"""
 # {TITLE}
@@ -32,6 +46,13 @@ CSS = """
 p.big {
     font-size: 20px;
 }
+
+/* From https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/blob/main/app.py */
+
+.result {display:flex;flex-direction:column}
+.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%;font-size:20px;}
+.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
+.result_item_error {background-color:#ff7070;color:white;align-self:start}
 """
 
 ARTICLE = """
@@ -58,6 +79,9 @@ for info in hf_infos:
 
 SUPPORTED_MODEL_NAMES = sorted(list(SUPPORTED_MODEL_NAMES))
 
+# DEBUG FILTER
+SUPPORTED_MODEL_NAMES = list(filter(lambda x: "en" in x and "conformer_transducer_large" in x, SUPPORTED_MODEL_NAMES))
+
 model_dict = {model_name: gr.Interface.load(f'models/{model_name}') for model_name in SUPPORTED_MODEL_NAMES}
 
 SUPPORTED_LANG_MODEL_DICT = {}
@@ -77,6 +101,14 @@ for lang in SUPPORTED_LANG_MODEL_DICT.keys():
     SUPPORTED_LANG_MODEL_DICT[lang] = model_ids
 
 
+def get_device():
+    gpu_available = torch.cuda.is_available()
+    if gpu_available:
+        return torch.cuda.get_device_name()
+    else:
+        return "CPU"
+
+
 def parse_duration(audio_file):
     """
     FFMPEG to calculate durations. Libraries can do it too, but filetypes cause different libraries to behave differently.
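
The parse_duration docstring above only names the approach (shell out to FFMPEG rather than trust format-specific libraries). For illustration, duration probing via ffprobe might look like the sketch below; probe_duration is a hypothetical helper, not code from this commit, and it assumes ffprobe is on PATH alongside FFMPEG.

```python
import json
import subprocess


def probe_duration(audio_file: str) -> float:
    """Return the duration of audio_file in seconds, as reported by ffprobe (illustrative sketch)."""
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "json",
        audio_file,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    return float(json.loads(result.stdout)["format"]["duration"])
```
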
@@ -108,7 +140,7 @@ def resolve_model_type(model_name: str) -> str:
         return 'ctc'
 
     # Model specific maps
-    elif 'jasper' in model_name:
+    if 'jasper' in model_name:
         return 'ctc'
     elif 'quartznet' in model_name:
         return 'ctc'
@@ -116,9 +148,8 @@ def resolve_model_type(model_name: str) -> str:
         return 'ctc'
     elif 'contextnet' in model_name:
         return 'ctc'
-    else:
-        # Unknown model type
-        return None
+
+    return None
 
 
 def resolve_model_stride(model_name) -> int:
@@ -185,6 +216,16 @@ def extract_result_from_manifest(filepath, model_name) -> (bool, str):
     return False, f"Could not perform inference on model with name : {model_name}"
 
 
+def build_html_output(s: str, style: str = "result_item_success"):
+    return f"""
+    <div class='result'>
+        <div class='result_item {style}'>
+            {s}
+        </div>
+    </div>
+    """
+
+
 def infer_audio(model_name: str, audio_file: str) -> str:
     """
     Main method that switches from HF inference for small audio files to Buffered CTC/RNNT mode for long audio files.
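
The build_html_output helper added above simply wraps a message in the result / result_item CSS classes introduced earlier in this commit. A brief usage sketch (illustrative only; the exact whitespace of the returned HTML follows the f-string above):

```python
ok_html = build_html_output("Transcription Time : 1.234 s")
# roughly: <div class='result'><div class='result_item result_item_success'> ... </div></div>

err_html = build_html_output(
    "Error:- Failed to convert audio file to wav.", style="result_item_error"
)
```
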
@@ -195,17 +236,18 @@ def infer_audio(model_name: str, audio_file: str) -> str:
 
     Returns:
         str which is the transcription if successful.
+        str which is HTML output of logs.
     """
     # Parse the duration of the audio file
     duration = parse_duration(audio_file)
 
-    if duration > 60.0:  # Longer than one minute; use buffered mode
+    if duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:  # Longer than one minute; use buffered mode
         # Process audio to be of wav type (possible youtube audio)
         audio_file = convert_audio(audio_file)
 
         # If audio file transcoding failed, let user know
         if audio_file is None:
-            return "Failed to convert audio file to wav."
+            return "Error:- Failed to convert audio file to wav."
 
         # Extract audio dir from resolved audio filepath
         audio_dir = os.path.split(audio_file)[0]
@@ -214,7 +256,7 @@ def infer_audio(model_name: str, audio_file: str) -> str:
         model_stride = resolve_model_stride(model_name)
 
         if model_stride < 0:
-            return f"Failed to compute the model stride for model with name : {model_name}"
+            return f"Error:- Failed to compute the model stride for model with name : {model_name}"
 
         # Process model type (CTC/RNNT/Hybrid)
         model_type = resolve_model_type(model_name)
@@ -266,7 +308,7 @@ def infer_audio(model_name: str, audio_file: str) -> str:
                 pass
 
             if RESULT is None:
-                return f"Could not parse model type; failed to perform inference with model {model_name}!"
+                return f"Error:- Could not parse model type; failed to perform inference with model {model_name}!"
 
         elif model_type == 'ctc':
 
@@ -303,9 +345,10 @@ def infer_audio(model_name: str, audio_file: str) -> str:
             return extract_result_from_manifest('output.json', model_name)[-1]
 
         else:
-            return f"Could not parse model type; failed to perform inference with model {model_name}!"
+            return f"Error:- Could not parse model type; failed to perform inference with model {model_name}!"
 
     else:
+        # Obtain Gradio Model function from cache of models
         if model_name in model_dict:
             model = model_dict[model_name]
         else:
@@ -317,7 +360,7 @@ def infer_audio(model_name: str, audio_file: str) -> str:
             return transcriptions
         else:
             error = (
-                f"Could not find model {model_name} in list of available models : "
+                f"Error:- Could not find model {model_name} in list of available models : "
                 f"{list([k for k in model_dict.keys()])}"
             )
             return error
@@ -334,30 +377,60 @@ def transcribe(microphone, audio_file, model_name):
         audio_data = microphone
 
     elif (microphone is None) and (audio_file is None):
-        return "ERROR: You have to either use the microphone or upload an audio file"
+        warn_output = "ERROR: You have to either use the microphone or upload an audio file"
 
     elif microphone is not None:
         audio_data = microphone
     else:
         audio_data = audio_file
 
+    time_diff = None
     try:
         # Use HF API for transcription
+        start = time.time()
         transcriptions = infer_audio(model_name, audio_data)
+        end = time.time()
+        time_diff = end - start
 
     except Exception as e:
         transcriptions = ""
-        warn_output = warn_output + "\n\n"
+        warn_output = warn_output
+
+        if warn_output != "":
+            warn_output += "<br><br>"
+
         warn_output += (
             f"The model `{model_name}` is currently loading and cannot be used "
-            f"for transcription.\n"
+            f"for transcription.<br>"
             f"Please try another model or wait a few minutes."
         )
 
-    return warn_output + transcriptions
+    # Built HTML output
+    if warn_output != "":
+        html_output = build_html_output(warn_output, style="result_item_error")
+    else:
+        if transcriptions.startswith("Error:-"):
+            html_output = build_html_output(transcriptions, style="result_item_error")
+        else:
+            audio_duration = parse_duration(audio_data)
+
+            output = f"Successfully transcribed on {get_device()} ! <br>" f"Transcription Time : {time_diff: 0.3f} s"
+
+            if audio_duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:
+                output += f""" <br><br>
+                Note: Audio duration was {audio_duration: 0.3f} s, so model had to be downloaded, initialized, and then
+                buffered inference was used. <br>
+
+                Please rerun again in order to measure the time taken for just inference with pre-downloaded model. <br>
+                """
+
+            html_output = build_html_output(output)
+
+    return transcriptions, html_output
 
 
 def _return_yt_html_embed(yt_url):
+    """ Obtained from https://huggingface.co/spaces/whisper-event/whisper-demo """
     video_id = yt_url.split("?v=")[-1]
     HTML_str = (
         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
@@ -367,6 +440,7 @@ def _return_yt_html_embed(yt_url):
 
 
 def yt_transcribe(yt_url, model_name):
+    """ Modified from https://huggingface.co/spaces/whisper-event/whisper-demo """
     yt = pt.YouTube(yt_url)
     html_embed_str = _return_yt_html_embed(yt_url)
 
@@ -374,15 +448,57 @@ def yt_transcribe(yt_url, model_name):
     file_uuid = str(uuid.uuid4().hex)
     file_uuid = f"{tempdir}/{file_uuid}.mp3"
 
+    # Download YT Audio temporarily
+    download_time_start = time.time()
+
     stream = yt.streams.filter(only_audio=True)[0]
     stream.download(filename=file_uuid)
 
+    download_time_end = time.time()
+
+    # Get audio duration
+    audio_duration = parse_duration(file_uuid)
+
+    # Perform transcription
+    infer_time_start = time.time()
+
     text = infer_audio(model_name, file_uuid)
 
-    return html_embed_str, text
+    infer_time_end = time.time()
+
+    if text.startswith("Error:-"):
+        html_output = build_html_output(text, style='result_item_error')
+    else:
+        html_output = f"""
+        Successfully transcribed on {get_device()} ! <br>
+        Audio Download Time : {download_time_end - download_time_start: 0.3f} s <br>
+        Transcription Time : {infer_time_end - infer_time_start: 0.3f} s <br>
+        """
+
+        if audio_duration > BUFFERED_INFERENCE_DURATION_THRESHOLD:
+            html_output += f""" <br>
+            Note: Audio duration was {audio_duration: 0.3f} s, so model had to be downloaded, initialized, and then
+            buffered inference was used. <br>
+
+            Please rerun again in order to measure the time taken for just inference with pre-downloaded model. <br>
+            """
+
+        html_output = build_html_output(html_output)
+
+    return text, html_embed_str, html_output
 
 
 def create_lang_selector_component(default_en_model=DEFAULT_EN_MODEL):
+    """
+    Utility function to select a langauge from a dropdown menu, and simultanously update another dropdown
+    containing the corresponding model checkpoints for that language.
+
+    Args:
+        default_en_model: str name of a default english model that should be the set default.
+
+    Returns:
+        Gradio components for lang_selector (Dropdown menu) and models_in_lang (Dropdown menu)
+    """
     lang_selector = gr.components.Dropdown(
         choices=sorted(list(SUPPORTED_LANGUAGES)), value="en", type="value", label="Languages", interactive=True,
     )
@@ -406,6 +522,9 @@ def create_lang_selector_component(default_en_model=DEFAULT_EN_MODEL):
     return lang_selector, models_in_lang
 
 
+"""
+Define the GUI
+"""
 demo = gr.Blocks(title=TITLE, css=CSS)
 
 with demo:
@@ -419,9 +538,12 @@ with demo:
         lang_selector, models_in_lang = create_lang_selector_component()
 
         transcript = gr.components.Label(label='Transcript')
+        audio_html_output = gr.components.HTML()
 
         run = gr.components.Button('Transcribe')
-        run.click(transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript])
+        run.click(
+            transcribe, inputs=[microphone, file_upload, models_in_lang], outputs=[transcript, audio_html_output]
+        )
 
     with gr.Tab("Transcribe Youtube"):
         yt_url = gr.components.Textbox(
@@ -429,14 +551,19 @@ with demo:
         )
 
         lang_selector_yt, models_in_lang_yt = create_lang_selector_component(
-            default_en_model='nvidia/stt_en_conformer_transducer_large'
+            default_en_model=DEFAULT_BUFFERED_EN_MODEL
         )
 
-        embedded_video = gr.components.HTML()
+        with gr.Row():
+            run = gr.components.Button('Transcribe YouTube')
+            embedded_video = gr.components.HTML()
+
         transcript = gr.components.Label(label='Transcript')
+        yt_html_output = gr.components.HTML()
 
-        run = gr.components.Button('Transcribe YouTube')
-        run.click(yt_transcribe, inputs=[yt_url, models_in_lang_yt], outputs=[embedded_video, transcript])
+        run.click(
+            yt_transcribe, inputs=[yt_url, models_in_lang_yt], outputs=[transcript, embedded_video, yt_html_output]
+        )
 
     gr.components.HTML(ARTICLE)
 
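
The diff does not touch the app's entry point, which presumably lives in an unchanged part of app.py. For context, a Gradio Blocks demo such as this one is typically started with a pattern like the following; this is an assumption for illustration, not code from this commit.

```python
if __name__ == "__main__":
    demo.queue()   # serialize requests so long transcriptions do not overlap
    demo.launch()  # start the Gradio server for the Space
```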