Update README.md

bugfix generation with speaker

README.md (CHANGED)
```diff
@@ -155,42 +155,24 @@ whisper_turbo_pipe = pipeline(
     device="cuda",
 )
 
-def extract_speech_ids(speech_tokens_str_list):
-    speech_ids = []
-    for token_str in speech_tokens_str_list:
-        if token_str.startswith("<|s_") and token_str.endswith("|>"):
-            num_str = token_str[4:-2]
-            try:
-                speech_ids.append(int(num_str))
-            except ValueError:
-                print("Error converting token:", token_str)
-        else:
-            print(f"Unexpected token: {token_str}")
-    return speech_ids
-
+def ids_to_speech_tokens(speech_ids):
+    speech_tokens_str = []
+    for speech_id in speech_ids:
+        speech_tokens_str.append(f"<|s_{speech_id}|>")
+    return speech_tokens_str
 
 waveform, sample_rate = torchaudio.load(sample_audio_path)
 
 max_secs = 15
-if waveform
-    print("Trimming audio to
-    waveform = waveform[:, : sample_rate *
-
-    waveform = torch.nn.functional.pad(
-        waveform, (0, int(sample_rate * 0.5)), "constant", 0
-    )
+if len(waveform[0]) / sample_rate > 15:
+    print("Warning: Trimming audio to first 15secs.")
+    waveform = waveform[:, : sample_rate * 15]
+    waveform = torch.nn.functional.pad( waveform, (0, int(sample_rate * 0.5)), "constant", 0)
 
-if waveform.
-    waveform =
+if waveform.size(0) > 1:
+    waveform = torch.mean(waveform, dim=0, keepdim=True)
 
-resampler = torchaudio.transforms.Resample(orig_freq=sample_rate,
-                                           new_freq=16000)
-waveform = resampler(waveform)
-sample_rate = 16000
+prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
 if sample_audio_text is None:
     print("Transcribing audio...")
```
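The new `ids_to_speech_tokens` helper is the inverse of the `extract_speech_ids` parser that the script still calls further down: one renders integer codec IDs as `<|s_N|>` strings from the model vocabulary, the other parses them back. A minimal round-trip sketch of the pair (the helper bodies mirror the README code; the sample IDs are made up):

```python
def ids_to_speech_tokens(speech_ids):
    # Render integer codec IDs as <|s_N|> tokens from the model vocabulary.
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

def extract_speech_ids(speech_tokens_str):
    # Parse <|s_N|> tokens back to integer codec IDs; skip anything else.
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

ids = [100, 2048, 65535]  # made-up codec IDs
assert ids_to_speech_tokens(ids) == ["<|s_100|>", "<|s_2048|>", "<|s_65535|>"]
assert extract_speech_ids(ids_to_speech_tokens(ids)) == ids
```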
```diff
@@ -208,44 +190,41 @@ elif len(target_text) > 500:
 
 input_text = transcription + " " + target_text
 
-chat = [
-    {"role": "user", "content": "Convert the text to speech:" + formatted_text},
-    {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
-]
-
-with torch.no_grad():
+with torch.no_grad():
+    vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
+    vq_code_prompt = vq_code_prompt[0, 0, :]
+    speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
+
+    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
+
+    chat = [
+        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
+        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + "".join(speech_ids_prefix)}
+    ]
+
+    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, return_tensors="pt", continue_final_message=True)
+    input_ids = input_ids.to("cuda")
+    speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
 
     outputs = model.generate(
         input_ids,
-        max_length=2048,
+        max_length=2048,
         eos_token_id=speech_end_id,
         do_sample=True,
         top_p=1,
         temperature=0.8,
+        min_new_tokens=4, # Fix so the model does not directly stop
     )
 
-generated_ids = outputs[0][input_ids.shape[1] : -1]
-
-raw_speech_tokens = tokenizer.batch_decode(generated_ids,
-                                           skip_special_tokens=True)
-speech_ids = extract_speech_ids(raw_speech_tokens)
-
-if len(speech_ids) == 0:
-    raise ValueError("No valid speech tokens were generated!")
-
-gen_wav = Codec_model.decode_code(speech_tokens_tensor).cpu().squeeze()
+    generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]
+
+    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    speech_tokens = extract_speech_ids(speech_tokens)
+    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
+
+    gen_wav = Codec_model.decode_code(speech_tokens)
+    gen_wav = gen_wav[:, :, prompt_wav.shape[1] :]
+
+sf.write(output_filename, gen_wav[0, 0, :].cpu().numpy(), 16000)
```
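The second hunk is the substance of the "generation with speaker" fix: the reference clip is encoded by the codec, rendered as `<|s_N|>` tokens, and pre-filled into the open assistant turn, so the model continues the reference speaker's speech instead of starting from nothing (`min_new_tokens=4` additionally stops it from emitting the end token right away). Because that prefix is now part of `input_ids`, the extraction slice backs up by `len(speech_ids_prefix)` so the codec decodes prompt and continuation together, and the re-synthesized prompt is trimmed off afterwards in the waveform domain. A toy check of the index arithmetic, with all lengths made up:

```python
# Slice under test: outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]
text_len = 30     # chat-template and text tokens (made up)
prefix_len = 200  # <|s_N|> tokens encoding the reference audio (made up)
new_len = 500     # speech tokens the model generates (made up)

input_len = text_len + prefix_len    # input_ids.shape[1]
total_len = input_len + new_len + 1  # +1 for <|SPEECH_GENERATION_END|>

start = input_len - prefix_len       # back up over the speech prefix
kept = total_len - 1 - start         # the trailing end token is dropped

assert start == text_len             # slice begins right after the text
assert kept == prefix_len + new_len  # prompt + continuation reach the codec
```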
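The reference-audio preparation from the first hunk also bundles naturally into a single function. A sketch with the same steps (15 s cap, 0.5 s silence pad, stereo-to-mono downmix, 16 kHz resample); the function name and defaults are ours, the operations come straight from the diff:

```python
import torch
import torchaudio

def load_prompt_wav(path, max_secs=15, target_sr=16000):
    """Prepare a reference clip the way the example builds prompt_wav."""
    waveform, sample_rate = torchaudio.load(path)

    # Cap the prompt length, then pad 0.5 s of silence after the cut
    # so the codec sees a clean boundary.
    if waveform.shape[1] / sample_rate > max_secs:
        print(f"Warning: Trimming audio to first {max_secs} secs.")
        waveform = waveform[:, : sample_rate * max_secs]
        waveform = torch.nn.functional.pad(
            waveform, (0, int(sample_rate * 0.5)), "constant", 0
        )

    # Downmix stereo to mono by averaging the channels.
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # The codec consumes 16 kHz audio.
    resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)
    return resample(waveform)
```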