# aai/tabs/audios/events.py
import re
import tempfile

import spaces
import gradio as gr
import torch
import torchaudio
import numpy as np
from df.enhance import enhance, load_audio, save_audio

# The wildcard import provides the preloaded models used below:
# df_model, df_state, cv_sft, cv_vc and cv_instruct.
from .load_models import *
from .modules.CosyVoice.cosyvoice.utils.file_utils import load_wav


# Helper functions
def create_temp_file(suffix: str = '.wav'):
    # delete=False keeps the file on disk after closing so Gradio can serve
    # it; a real extension lets torchaudio and DeepFilterNet pick the format.
    return tempfile.NamedTemporaryFile(delete=False, suffix=suffix)


def assign_language_tags(text):
    """Prefix each language run in `text` with a CosyVoice language tag.

    Tags <|zh|>/<|en|>/<|jp|>/<|yue|>/<|ko|> mark Chinese/English/Japanese/
    Cantonese/Korean runs, e.g.
    input:  你好 Hello こんにちは 你好 안녕하세요
    output: <|zh|>你好 <|en|>Hello <|jp|>こんにちは <|yue|>你好 <|ko|>안녕하세요
    """
    # Character ranges for each language
    patterns = {
        'zh': r'[\u4e00-\u9fff]+',               # Han characters
        'en': r'[a-zA-Z]+',                      # Latin letters
        'jp': r'[\u3040-\u30ff\u31f0-\u31ff]+',  # Hiragana/Katakana
        'ko': r'[\uac00-\ud7a3]+',               # Hangul
    }
    # Collect every match as (start, end, lang, content)
    matches = []
    for lang, pattern in patterns.items():
        for match in re.finditer(pattern, text):
            matches.append((match.start(), match.end(), lang, match.group()))
    # Sort matches by start position
    matches.sort(key=lambda x: x[0])
    # Rebuild the string, inserting a tag before each run
    result = []
    last_end = 0
    zh_count = 0
    for start, end, lang, content in matches:
        if start > last_end:
            result.append(text[last_end:start])
        if lang == 'zh':
            # Crude heuristic: Mandarin and Cantonese share Han characters,
            # so every Han run after the first is tagged as Cantonese.
            zh_count += 1
            if zh_count > 1:
                lang = 'yue'
        result.append(f'<|{lang}|>{content}')
        last_end = end
    if last_end < len(text):
        result.append(text[last_end:])
    return ''.join(result)
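

# Illustrative behaviour, traced from the regex logic above: whitespace
# between runs is preserved, and because the 'en' pattern matches letter
# runs, each whitespace-separated English word gets its own tag:
#
#   >>> assign_language_tags('深度学习 means deep learning')
#   '<|zh|>深度学习 <|en|>means <|en|>deep <|en|>learning'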


def update_mode(mode, sft_speaker, speaker_audio, voice_instructions):
    # Show or hide the inputs each synthesis mode needs; updates are
    # returned in the order (sft_speaker, speaker_audio, voice_instructions).
    if mode == 'SFT':
        return (
            gr.update(  # sft_speaker
                visible=True,
            ),
            gr.update(  # speaker_audio
                visible=False,
            ),
            gr.update(  # voice_instructions
                visible=False,
            ),
        )
    elif mode == 'VC':
        return (
            gr.update(  # sft_speaker
                visible=False,
            ),
            gr.update(  # speaker_audio
                visible=True,
            ),
            gr.update(  # voice_instructions
                visible=True,
            ),
        )
    elif mode == 'VC-CrossLingual':
        return (
            gr.update(  # sft_speaker
                visible=False,
            ),
            gr.update(  # speaker_audio
                visible=True,
            ),
            gr.update(  # voice_instructions
                visible=False,
            ),
        )
    elif mode == 'Instruct':
        return (
            gr.update(  # sft_speaker
                visible=True,
            ),
            gr.update(  # speaker_audio
                visible=False,
            ),
            gr.update(  # voice_instructions
                visible=True,
            ),
        )
    else:
        raise gr.Error('Invalid mode')
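

# A minimal wiring sketch for the visibility handler (hypothetical component
# names; the real components are defined and bound in the UI layer):
#
#   mode.change(
#       update_mode,
#       inputs=[mode, sft_speaker, speaker_audio, voice_instructions],
#       outputs=[sft_speaker, speaker_audio, voice_instructions],
#   )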


@spaces.GPU(duration=10)
def clear_audio(audio: tuple[int, np.ndarray]):
    # Denoise the uploaded audio with DeepFilterNet. Gradio delivers numpy
    # audio as a (sample_rate, samples) tuple.
    sr, samples = audio
    samples = torch.from_numpy(samples)
    # torchaudio expects (channels, frames); Gradio gives (frames,) for mono
    # and (frames, channels) for multi-channel audio
    samples = samples.unsqueeze(0) if samples.ndim == 1 else samples.T
    # Write a real WAV file so load_audio can read it back
    audio_file = create_temp_file()
    torchaudio.save(audio_file.name, samples, sr, format='wav')
    # Reload at DeepFilterNet's expected sample rate and enhance
    audio, _ = load_audio(audio_file.name, sr=df_state.sr())
    enhanced = enhance(df_model, df_state, audio)
    # Overwrite the temp file with the enhanced audio
    save_audio(audio_file.name, enhanced, df_state.sr())
    return gr.update(  # speaker_audio / output_audio
        value=audio_file.name,
    )


@spaces.GPU(duration=20)
def gen_audio(text, mode, sft_speaker=None, speaker_audio=None, voice_instructions=None):
    if mode in ('VC', 'VC-CrossLingual'):
        if speaker_audio is None:
            raise gr.Error('Please upload an audio')
        # Persist the uploaded speaker audio as a WAV so CosyVoice's load_wav
        # can read it (again assuming Gradio's (sample_rate, samples) format)
        sr, samples = speaker_audio
        samples = torch.from_numpy(samples)
        samples = samples.unsqueeze(0) if samples.ndim == 1 else samples.T
        speaker_audio_file = create_temp_file()
        torchaudio.save(speaker_audio_file.name, samples, sr, format='wav')
        # The uploaded audio becomes the zero-shot prompt, resampled to 16 kHz
        prompt_speech_16k = load_wav(speaker_audio_file.name, 16000)
    else:
        prompt_speech_16k = None
    # Tag each language run so CosyVoice switches languages correctly
    text = assign_language_tags(text)
    # Run the inference matching the selected mode; every CosyVoice call
    # yields a stream of {'tts_speech': tensor} chunks
    if mode == 'SFT':
        if not sft_speaker:
            raise gr.Error('Please select a speaker')
        chunks = cv_sft.inference_sft(
            tts_text=text,
            spk_id=sft_speaker,
        )
    elif mode == 'VC':
        chunks = cv_vc.inference_zero_shot(
            tts_text=text,
            prompt_text=voice_instructions,
            prompt_speech_16k=prompt_speech_16k,
        )
    elif mode == 'VC-CrossLingual':
        chunks = cv_vc.inference_cross_lingual(
            tts_text=text,
            prompt_speech_16k=prompt_speech_16k,
        )
    elif mode == 'Instruct':
        if not sft_speaker:
            raise gr.Error('Please select a speaker')
        if not voice_instructions:
            raise gr.Error('Please enter voice instructions')
        chunks = cv_instruct.inference_instruct(
            tts_text=text,
            spk_id=sft_speaker,
            instruct_text=voice_instructions,
        )
    else:
        raise gr.Error('Invalid mode')
    # Concatenate the streamed chunks into one waveform and save it once;
    # writing every chunk to the same temp file would keep only the last one
    speech = torch.cat([chunk['tts_speech'] for chunk in chunks], dim=1)
    out_file = create_temp_file()
    torchaudio.save(out_file.name, speech, 22050, format='wav')
    return gr.update(  # output_audio
        value=out_file.name,
    )
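

# Likewise for generation (hypothetical names again; the UI layer binds the
# real button and components):
#
#   generate_btn.click(
#       gen_audio,
#       inputs=[text, mode, sft_speaker, speaker_audio, voice_instructions],
#       outputs=[output_audio],
#   )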