Commit 5bcf187 · verified · Parent(s): f37cd83
Committed by ssolito

Upload 2 files

Files changed (2):
  1. app.py +19 -38
  2. whisper.py +218 -12
app.py CHANGED
@@ -1,58 +1,39 @@
-import spaces
-import torch
 import gradio as gr
-from AinaTheme import theme
-from transformers import pipeline
-
-MODEL_NAME = "projecte-aina/whisper-large-v3-ca-es-synth-cs"
-BATCH_SIZE = 8
-device = 0 if torch.cuda.is_available() else "cpu"
-
-pipe = pipeline(
-    task="automatic-speech-recognition",
-    model=MODEL_NAME,
-    chunk_length_s=30,
-    device=device,
-)
-
-@spaces.GPU
-def transcribe(inputs):
+from whisper import generate
+from AinaTheme import theme
+
+USE_V5 = False
+
+def transcribe(inputs, model_version):
     if inputs is None:
         raise gr.Error("Cap fitxer d'àudio introduit! Si us plau pengeu un fitxer "\
                        "o enregistreu un àudio abans d'enviar la vostra sol·licitud")
-    text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
-    return text

+    use_v5 = model_version == "v0.5"
+    return generate(audio_path=inputs, use_v5=use_v5)

 description_string = "Transcripció automàtica de micròfon o de fitxers d'àudio.\n Aquest demostrador s'ha desenvolupat per"\
-                     " comprovar els models de reconeixement de parla per a móbils. Per ara utilitza el checkpoint "\
-                     f"[{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) i la llibreria de 🤗 Transformers per a la transcripció."
-
+                     " comprovar els models de reconeixement de parla per a móbils."

 def clear():
-    return (
-        None
-    )
-
+    return None, "v1.0"

-with gr.Blocks(theme=theme) as demo:
+with gr.Blocks() as demo:
     gr.Markdown(description_string)
     with gr.Row():
         with gr.Column(scale=1):
-            #input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio")
-            input = gr.Audio(sources=["upload"], type="filepath", label="Audio")
+            model_version = gr.Dropdown(label="Model Version", choices=["v1.0", "v0.5"], value="v1.0")
+            input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio")

         with gr.Column(scale=1):
             output = gr.Textbox(label="Output", lines=8)
-
-    with gr.Row(variant="panel"):
-        clear_btn = gr.Button("Clear")
-        submit_btn = gr.Button("Submit", variant="primary")
-
-
-    submit_btn.click(fn=transcribe, inputs=[input], outputs=[output])
-    clear_btn.click(fn=clear,inputs=[], outputs=[input], queue=False,)

+    with gr.Row(variant="panel"):
+        clear_btn = gr.Button("Clear")
+        submit_btn = gr.Button("Submit", variant="primary")
+
+    submit_btn.click(fn=transcribe, inputs=[input, model_version], outputs=[output])
+    clear_btn.click(fn=clear, inputs=[], outputs=[input, model_version], queue=False)

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
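After this change the Gradio callback is a thin wrapper: it maps the "Model Version" dropdown to a boolean and defers all work to `whisper.generate`. A minimal sketch of driving that entry point without the UI, assuming the Space's dependencies and an `HF_TOKEN` for the gated pyannote pipeline are available; the WAV path is a made-up placeholder, and importing `whisper` loads the Whisper model and the VAD pipeline at import time:

```python
# Sketch only: exercise the new generate() contract outside Gradio.
# "call_recording.wav" is a hypothetical example file, not part of the Space.
from whisper import generate

if __name__ == "__main__":
    # The "v0.5" dropdown choice maps to use_v5=True: pyannote VAD segmentation
    # followed by prompt-conditioned Whisper decoding of each segment.
    text = generate(audio_path="call_recording.wav", use_v5=True)
    print(text)
```

Selecting "v1.0" maps to `use_v5=False`, which instead routes the audio through the chunked `transformers` pipeline defined in whisper.py before the same repetition clean-up is applied.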
whisper.py CHANGED
@@ -1,34 +1,240 @@
-import os
 from pyannote.audio import Pipeline
 from pydub import AudioSegment
+import os
 from transformers import WhisperForConditionalGeneration, WhisperProcessor
 import torchaudio
 import torch
+import re
+from transformers import pipeline
+

 device = 0 if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float32

-HF_TOKEN = os.getenv("HF_TOKEN")
-MODEL_NAME = "projecte-aina/whisper-large-v3-ca-es-synth-cs"
-model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype, token=HF_TOKEN).to(device)
+
+MODEL_NAME = "openai/whisper-large-v3"
+CKPT = "projecte-aina/whisper-large-v3-tiny-caesar"
+BATCH_SIZE = 1
+model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype).to(device)
 processor = WhisperProcessor.from_pretrained(MODEL_NAME)
+pipeline_vad = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token=os.environ.get("HF_TOKEN"))
+threshold = 10000
+segments_dir = "."
+
+pipe = pipeline(
+    task="automatic-speech-recognition",
+    model=CKPT,
+    chunk_length_s=30,
+    device=device
+)
+
+def post_process_transcription(example_transcription, max_repeats=1):
+    segments = re.findall(r'.+?[.,?]', example_transcription)
+
+    seen = set()
+    unique_segments = []
+    for segment in segments:
+        if segment not in seen:
+            unique_segments.append(segment)
+            seen.add(segment)
+
+    final_string = ''.join(unique_segments)
+
+    tokens = re.findall(r'\b\w+\b[.,!?]?', final_string)
+
+    cleaned_tokens = []
+    repetition_count = 0
+    previous_token = None
+
+    for token in tokens:
+        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", "", token)
+
+        if reduced_token == previous_token:
+            repetition_count += 1
+            if repetition_count <= max_repeats:
+                cleaned_tokens.append(reduced_token)
+        else:
+            repetition_count = 1
+            cleaned_tokens.append(reduced_token)
+
+        previous_token = reduced_token
+
+    cleaned_transcription = " ".join(cleaned_tokens)
+    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
+
+    return cleaned_transcription
+
+def convert_forced_to_tokens(forced_decoder_ids):
+    forced_decoder_tokens = []
+    for i, (idx, token) in enumerate(forced_decoder_ids):
+        if token is not None:
+            forced_decoder_tokens.append([idx, processor.tokenizer.decode(token)])
+        else:
+            forced_decoder_tokens.append([idx, token])
+    return forced_decoder_tokens
+
-def generate(audio_path):
-    input_audio, sample_rate = torchaudio.load(audio_path)
+def generate_1st_chunk(audio):
+
+    input_audio, sample_rate = torchaudio.load(audio)
     input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)

     input_speech = input_audio[0]

     input_features = processor(input_speech,
-                               sampling_rate=16_000,
-                               return_attention_mask=True,
+                               sampling_rate=16_000,
                                return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
+
+    forced_decoder_ids = []
+    forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
+    forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
+    forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']
+
+    forced_decoder_ids_modified = forced_decoder_ids
+    idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
+    forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
+    prompt = "Antes de 'digui'm', '112'. 112, digui'm. Hola, puc parlar en castellà? Sí, digui, diga. Sí, mire: a veces al abrir la puerta de mi piso tengo una persona ahí. Vale, avisamos a la Guàrdia Urbana, ¿de acuerdo? Vale, perfecto. Gracias. Gracias. Buen día."
+    prompt_tokens = processor.tokenizer(prompt, add_special_tokens=False).input_ids
+
+    # we need to force these tokens
+    forced_decoder_ids = []
+    for idx, token in enumerate(prompt_tokens):
+        # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
+        forced_decoder_ids.append([idx + 1, token])
+
+    # now we add the SOS token at the end
+    offset = len(forced_decoder_ids)
+    forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])
+
+    # now we need to append the rest of the prefix tokens (lang, task, timestamps)
+    offset = len(forced_decoder_ids)
+    for idx, token in forced_decoder_ids_modified:
+        forced_decoder_ids.append([idx + offset, token])
+
+    model.generation_config.forced_decoder_ids = forced_decoder_ids
+
+    pred_ids = model.generate(input_features,
+                              return_timestamps=True,
+                              max_new_tokens=128,
+                              decoder_start_token_id=forced_bos_token_id)
+    # exclude prompt from output
+    forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
+    output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)
+
+    return output[1:]
+
+def generate_2nd_chuk(audio):
+
+    input_audio, sample_rate = torchaudio.load(audio)
+    input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
+
+    input_speech = input_audio[0]
+
+    input_features = processor(input_speech,
+                               sampling_rate=16_000,
+                               return_tensors="pt", torch_dtype=torch_dtype).input_features.to(device)
+    forced_decoder_ids = []
+
+    forced_decoder_ids.append([1,50270]) #[1, '<|ca|>']
+    forced_decoder_ids.append([2,50262]) #[2, '<|es|>']
+    forced_decoder_ids.append([3,50360]) #[3, '<|transcribe|>']
+
+    forced_decoder_ids_modified = forced_decoder_ids
+    idx = processor.tokenizer.all_special_tokens.index("<|startofprev|>")
+    forced_bos_token_id = processor.tokenizer.all_special_ids[idx]
+
+    prompt = "112, digui'm. Hola, puc parlar en castellà? Sí, digui, diga. Sí, mire: a veces al abrir la puerta de mi piso tengo una persona ahí. Vale, avisamos a la Guàrdia Urbana, ¿de acuerdo? Vale, perfecto. Gracias. Gracias. Buen día."
+    prompt_tokens = processor.tokenizer(prompt, add_special_tokens=False).input_ids
+
+    # we need to force these tokens
+    forced_decoder_ids = []
+    for idx, token in enumerate(prompt_tokens):
+        # indexing starts from 1 for forced tokens (token at position 0 is the SOS token)
+        forced_decoder_ids.append([idx + 1, token])
+
+    # now we add the SOS token at the end
+    offset = len(forced_decoder_ids)
+    forced_decoder_ids.append([offset + 1, model.generation_config.decoder_start_token_id])
+
+    # now we need to append the rest of the prefix tokens (lang, task, timestamps)
+    offset = len(forced_decoder_ids)
+    for idx, token in forced_decoder_ids_modified:
+        forced_decoder_ids.append([idx + offset, token])
+
+    model.generation_config.forced_decoder_ids = forced_decoder_ids
+
     pred_ids = model.generate(input_features,
                               return_timestamps=True,
-                              max_new_tokens=128)
+                              max_new_tokens=128,
+                              decoder_start_token_id=forced_bos_token_id)
+    # exclude prompt from output
+    forced_decoder_tokens = convert_forced_to_tokens(forced_decoder_ids)
+    output = processor.decode(pred_ids[0][len(forced_decoder_tokens) + 1:], skip_special_tokens=True)

-    output = processor.batch_decode(pred_ids, skip_special_tokens=True)
-    line = output[0]
-    return line
+    return output[1:]
+
+def processing_vad_threshold(audio, output_vad, threshold, max_duration, concatenated_segment):
+
+    transcription_audio = ""
+    is_first_chunk = True
+    for speech in output_vad.get_timeline().support():
+        start, end = speech.start, speech.end
+        segment_duration = (end - start) * 1000
+        segment_audio = audio[start * 1000:end * 1000]
+
+        if max_duration + segment_duration < threshold:
+            concatenated_segment += audio[start * 1000:end * 1000]
+            max_duration += segment_duration
+
+        else:
+            if len(concatenated_segment) > 0:
+                temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
+                concatenated_segment.export(temp_segment_path, format="wav")
+
+                if is_first_chunk:
+                    output = generate_1st_chunk(temp_segment_path)
+                    is_first_chunk = False
+                else:
+                    output = generate_2nd_chuk(temp_segment_path)
+                transcription_audio = transcription_audio + output
+            max_duration = segment_duration
+            concatenated_segment = segment_audio
+
+    # Process any remaining audio in the concatenated_segment
+    if len(concatenated_segment) > 0:
+        temp_segment_path = os.path.join(segments_dir, f"temp_segment.wav")
+        concatenated_segment.export(temp_segment_path, format="wav")
+
+        output = generate_2nd_chuk(temp_segment_path)
+        transcription_audio = transcription_audio + output
+
+    return(transcription_audio)
+
+def format_audio(audio):
+    input_audio, sample_rate = torchaudio.load(audio)
+    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+    input_audio = resampler(input_audio)
+    input_audio = input_audio.squeeze().numpy()
+    return(input_audio)
+
+def transcribe_pipeline(audio, task):
+    text = pipe(audio, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
+    return text
+
+def generate(audio_path, use_v5):
+    audio = AudioSegment.from_wav(audio_path)
+
+    output_vad = pipeline_vad(audio_path)
+    concatenated_segment = AudioSegment.empty()
+    max_duration = 0
+
+    if use_v5:
+        output = processing_vad_threshold(audio, output_vad, threshold, max_duration, concatenated_segment)
+    else:
+        task = "transcribe"
+        output = transcribe_pipeline(format_audio(audio), task)
+
+    clean_output = post_process_transcription(output, max_repeats=1)
+
+    return clean_output
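The heaviest part of the new whisper.py is the prompt conditioning in `generate_1st_chunk` and `generate_2nd_chuk`: a fixed context prompt is tokenised, injected after `<|startofprev|>` through `forced_decoder_ids`, and then sliced back off the decoded ids. The lighter helper, `post_process_transcription`, drops duplicated sentence segments and caps consecutive repeated tokens; it is the clean-up applied to every transcription in `generate`. A small illustrative check of that helper on an invented string (note that importing `whisper` instantiates the Whisper model and the pyannote VAD pipeline at module load, so this is a sketch rather than a unit test):

```python
# Illustrative sketch: confirm the repetition clean-up on a made-up transcript.
from whisper import post_process_transcription

looped = "Gracias. Gracias. Gracias. Buen día."
# Segment dedup plus the consecutive-token cap (max_repeats=1) collapse the
# repeated "Gracias." down to a single occurrence.
print(post_process_transcription(looped, max_repeats=1))
# -> "Gracias. Buen día."
```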