SebastianBodza committed
Commit a4bab8e · verified · 1 Parent(s): 8832c28

Update README.md


Bugfix: generation with a speaker reference. The reference audio is now encoded to codec speech tokens and prepended to the assistant turn, so generation continues in the reference speaker's voice, and the prompt portion is trimmed from the decoded waveform.

Files changed (1)
  1. README.md +34 -55
README.md CHANGED
@@ -155,42 +155,24 @@ whisper_turbo_pipe = pipeline(
     device="cuda",
 )
 
-def extract_speech_ids(speech_tokens_str_list):
-    """
-    Convert tokens like "<|s_12345|>" into integer ids.
-    """
-    speech_ids = []
-    for token_str in speech_tokens_str_list:
-        if token_str.startswith("<|s_") and token_str.endswith("|>"):
-            num_str = token_str[4:-2]
-            try:
-                speech_ids.append(int(num_str))
-            except ValueError:
-                print("Error converting token:", token_str)
-        else:
-            print(f"Unexpected token: {token_str}")
-    return speech_ids
-
+def ids_to_speech_tokens(speech_ids):
+    speech_tokens_str = []
+    for speech_id in speech_ids:
+        speech_tokens_str.append(f"<|s_{speech_id}|>")
+    return speech_tokens_str
 
 waveform, sample_rate = torchaudio.load(sample_audio_path)
 
 max_secs = 15
-if waveform.shape[1] / sample_rate > max_secs:
-    print("Trimming audio to the first 15 seconds.")
-    waveform = waveform[:, : sample_rate * max_secs]
-    # Pad briefly (0.5 sec) at the end
-    waveform = torch.nn.functional.pad(
-        waveform, (0, int(sample_rate * 0.5)), "constant", 0
-    )
+if len(waveform[0]) / sample_rate > 15:
+    print("Warning: Trimming audio to first 15secs.")
+    waveform = waveform[:, : sample_rate * 15]
+    waveform = torch.nn.functional.pad(waveform, (0, int(sample_rate * 0.5)), "constant", 0)
 
-if waveform.shape[0] > 1:
-    waveform = waveform.mean(dim=0, keepdim=True)
+if waveform.size(0) > 1:
+    waveform = torch.mean(waveform, dim=0, keepdim=True)
 
-if sample_rate != 16000:
-    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
-    waveform = resampler(waveform)
-    sample_rate = 16000
+prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
 if sample_audio_text is None:
     print("Transcribing audio...")
@@ -208,44 +190,41 @@ elif len(target_text) > 500:
 
 input_text = transcription + " " + target_text
 
+with torch.no_grad():
+    vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
+    vq_code_prompt = vq_code_prompt[0, 0, :]
+    speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
+
 formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
 
 chat = [
     {"role": "user", "content": "Convert the text to speech:" + formatted_text},
-    {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
+    {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + "".join(speech_ids_prefix)}
 ]
 
-input_ids = tokenizer.apply_chat_template(
-    chat, tokenize=True, return_tensors="pt", continue_final_message=True
-)
+input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors="pt", continue_final_message=True)
 input_ids = input_ids.to("cuda")
 speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
 
-with torch.no_grad():
 outputs = model.generate(
     input_ids,
     max_length=2048,
     eos_token_id=speech_end_id,
     do_sample=True,
     top_p=1,
     temperature=0.8,
+    min_new_tokens=4,  # Fix so the model does not directly stop
 )
 
-generated_ids = outputs[0][input_ids.shape[1] : -1]
+generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]
 
-raw_speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-speech_ids = extract_speech_ids(raw_speech_tokens)
+speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+speech_tokens = extract_speech_ids(speech_tokens)
+speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
 
-if len(speech_ids) == 0:
-    raise ValueError("No valid speech tokens were generated!")
-
-speech_tokens_tensor = torch.tensor(speech_ids).cuda().unsqueeze(0).unsqueeze(0)
-
-gen_wav = Codec_model.decode_code(speech_tokens_tensor).cpu().squeeze()
-
-sf.write(output_filename, gen_wav, 16000)
+gen_wav = Codec_model.decode_code(speech_tokens)
+gen_wav = gen_wav[:, :, prompt_wav.shape[1] :]
+sf.write(output_filename, gen_wav[0, 0, :].cpu().numpy(), 16000)
 ```
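Note: the updated snippet still calls `extract_speech_ids` when decoding the generated tokens, while this diff removes the only definition of it shown in the hunk above. If the helper is not defined elsewhere in the README, a minimal sketch matching the removed definition would be:

```python
# Sketch of the helper the new snippet still relies on; mirrors the
# definition removed above. Converts tokens like "<|s_12345|>" to ints.
def extract_speech_ids(speech_tokens_str):
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
```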
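As a quick sanity check (illustrative values, not part of the commit), `ids_to_speech_tokens` and `extract_speech_ids` should round-trip:

```python
# Round-trip check: ids -> "<|s_...|>" token strings -> ids again.
ids = [12, 345, 6789]
tokens = ids_to_speech_tokens(ids)   # ['<|s_12|>', '<|s_345|>', '<|s_6789|>']
assert extract_speech_ids(tokens) == ids
```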