cahya committed
Commit 1b6ec9d · Parent(s): ce254f5

add model dropdown

Files changed (1): app.py (+31, -13)
app.py CHANGED
@@ -11,6 +11,20 @@ from gpuinfo import GPUInfo
 
 
 MODEL_NAME = "cahya/whisper-medium-id" # this always needs to stay in line 8 :D sorry for the hackiness
+whisper_models = {
+    "Indonesian Whisper Tiny": {
+        "name": "cahya/whisper-tiny-id",
+        "pipe": None,
+    },
+    "Indonesian Whisper Small": {
+        "name": "cahya/whisper-small-id",
+        "pipe": None,
+    },
+    "Indonesian Whisper Medium": {
+        "name": "cahya/whisper-medium-id",
+        "pipe": None,
+    },
+}
 lang = "id"
 title = "Indonesian Whisperer"
 description = "Cross Language Speech to Speech (Indonesian/English to 25 other languages) using OpenAI Whisper and Coqui TTS"
@@ -46,17 +60,18 @@ languages = {
 
 device = 0 if torch.cuda.is_available() else "cpu"
 
-pipe = pipeline(
-    task="automatic-speech-recognition",
-    model=MODEL_NAME,
-    chunk_length_s=30,
-    device=device,
-)
-
-pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
+for model in whisper_models:
+    whisper_models[model]["pipe"] = pipeline(
+        task="automatic-speech-recognition",
+        model=whisper_models[model]["name"],
+        chunk_length_s=30,
+        device=device,
+    )
+    whisper_models[model]["pipe"].model.config.forced_decoder_ids = \
+        whisper_models[model]["pipe"].tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
 
 
-def transcribe(microphone, file_upload):
+def transcribe(pipe, microphone, file_upload):
     warn_output = ""
     if (microphone is not None) and (file_upload is not None):
         warn_output = (
@@ -80,11 +95,12 @@ default_lang = "en"
 coquiTTS = CoquiTTS()
 
 
-def tts(language: str, audio_microphone: str, audio_file: str):
+def process(language: str, model: str, audio_microphone: str, audio_file: str):
     language = languages[language]
+    pipe = whisper_models[model]["pipe"]
     time_start = time.time()
     print(f"### {datetime.now()} TTS", language, audio_file)
-    transcription = transcribe(audio_microphone, audio_file)
+    transcription = transcribe(pipe, audio_microphone, audio_file)
     print(f"### {datetime.now()} transcribed:", transcription)
     translation = translate(transcription, language, "id")
     # return output
@@ -113,6 +129,8 @@ with gr.Blocks() as blocks:
     audio_microphone = gr.Audio(label="Microphone", source="microphone", type="filepath", optional=True)
     audio_upload = gr.Audio(label="Upload", source="upload", type="filepath", optional=True)
     language = gr.Dropdown([lang for lang in languages.keys()], label="Target Language", value="English")
+    model = gr.Dropdown([model for model in whisper_models.keys()],
+                        label="Whisper Model", value="Indonesian Whisper Medium")
     with gr.Row(): # mobile_collapse=False
         submit = gr.Button("Submit", variant="primary")
     examples = gr.Examples(examples=["data/Jokowi - 2022.mp3", "data/Soekarno - 1963.mp3", "data/JFK.mp3"],
@@ -131,8 +149,8 @@ with gr.Blocks() as blocks:
 
     # actions
     submit.click(
-        tts,
-        [language, audio_microphone, audio_upload],
+        process,
+        [language, model, audio_microphone, audio_upload],
         [text_source, text_target, audio, system_info],
     )
 
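Note on the pattern the commit introduces: it keys a registry dict by display name and lets the dropdown's selected string index into that registry inside the click handler. A minimal, self-contained sketch of the same wiring, assuming the Gradio Blocks API used by app.py; the fake_transcribe helper and the two-entry registry here are hypothetical stand-ins, not code from this repo:

import gradio as gr

# Registry pattern from the commit: display name -> model config.
models = {
    "Tiny": {"name": "cahya/whisper-tiny-id"},
    "Medium": {"name": "cahya/whisper-medium-id"},
}

def fake_transcribe(model_choice: str, text: str) -> str:
    # The dropdown passes its selected *string*, which keys into the
    # registry, just as process() does with whisper_models[model]["pipe"].
    cfg = models[model_choice]
    return f"[{cfg['name']}] {text}"

with gr.Blocks() as demo:
    model = gr.Dropdown(list(models.keys()), label="Whisper Model", value="Medium")
    text = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    submit = gr.Button("Submit")
    # Inputs are passed positionally in the order listed, which is why the
    # commit adds `model` both to the inputs list and to the fn signature.
    submit.click(fake_transcribe, [model, text], [out])

demo.launch()

Because submit.click passes its inputs positionally, adding model to the inputs list and to the function signature, as the commit does, is all the plumbing required.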
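One tradeoff worth flagging: the startup loop instantiates all three pipelines eagerly, so the tiny, small, and medium checkpoints stay resident in memory even when only one is ever selected. A hedged alternative, not what this commit does, would reuse the "pipe": None placeholders to lazy-load on first selection; this sketch assumes the whisper_models, lang, and device globals from app.py:

from transformers import pipeline

def get_pipe(model: str):
    # Sketch of a lazy-loading alternative, not the commit's implementation.
    # Builds the pipeline on first use instead of at startup, caching it in
    # the existing "pipe" slot of the whisper_models registry.
    entry = whisper_models[model]
    if entry["pipe"] is None:
        entry["pipe"] = pipeline(
            task="automatic-speech-recognition",
            model=entry["name"],
            chunk_length_s=30,
            device=device,
        )
        entry["pipe"].model.config.forced_decoder_ids = \
            entry["pipe"].tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
    return entry["pipe"]

The eager version in the commit trades longer startup and a larger memory footprint for predictable per-request latency; the lazy variant makes the first request on each model slow instead.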