asahi417 committed on
Commit
12b8205
•
1 Parent(s): e328088

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -27,6 +27,7 @@ pipe = pipeline(
27
  task="automatic-speech-recognition",
28
  model=MODEL_NAME,
29
  chunk_length_s=CHUNK_LENGTH_S,
 
30
  torch_dtype=torch_dtype,
31
  device=device,
32
  model_kwargs=model_kwargs
@@ -40,7 +41,7 @@ def transcribe(inputs, prompt):
40
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
41
  prompt = "。" if not prompt else prompt
42
  generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
43
- result = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
44
  return result['text'][1 + len(prompt) + 1:]
45
 
46
  def _return_yt_html_embed(yt_url):
@@ -85,7 +86,7 @@ def yt_transcribe(yt_url, prompt, max_filesize=75.0):
85
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
86
  prompt = "。" if not prompt else prompt
87
  generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
88
- result = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs=generate_kwargs)["text"]
89
  return html_embed_str, result['text'][1 + len(prompt) + 1:]
90
 
91
 
 
27
  task="automatic-speech-recognition",
28
  model=MODEL_NAME,
29
  chunk_length_s=CHUNK_LENGTH_S,
30
+ batch_size=BATCH_SIZE,
31
  torch_dtype=torch_dtype,
32
  device=device,
33
  model_kwargs=model_kwargs
 
41
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
42
  prompt = "。" if not prompt else prompt
43
  generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
44
+ result = pipe(inputs, generate_kwargs=generate_kwargs)["text"]
45
  return result['text'][1 + len(prompt) + 1:]
46
 
47
  def _return_yt_html_embed(yt_url):
 
86
  generate_kwargs = {"language": "japanese", "task": "transcribe"}
87
  prompt = "。" if not prompt else prompt
88
  generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
89
+ result = pipe(inputs, generate_kwargs=generate_kwargs)["text"]
90
  return html_embed_str, result['text'][1 + len(prompt) + 1:]
91
 
92