# aai/tabs/tts/events.py
import re
import tempfile

import torch
import spaces
import gradio as gr
import torchaudio
import numpy as np
from df.enhance import enhance, load_audio, save_audio
from config import Config
from .load_models import *
from .modules.CosyVoice.cosyvoice.utils.file_utils import load_wav
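# `from .load_models import *` is expected to provide the CosyVoice models
# (cosyvoice, cosyvoice_sft, cosyvoice_instruct) and the DeepFilterNet
# denoiser state (df_model, df_state) referenced below.
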
# Helper functions
def create_temp_file(suffix: str = ''):
    # delete=False keeps the file on disk after close so Gradio can serve it;
    # a suffix such as '.wav' lets downstream loaders infer the format
    return tempfile.NamedTemporaryFile(delete=False, suffix=suffix)

def assign_language_tags(text):
    # Currently a no-op; the disabled implementation below sketches the
    # intended per-language tagging.
    return text
# # Process the text
# # based on the language assign <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
# # at the start of the text for that language
# # e.g. input: 你好 Hello こんにちは 你好 안녕하세요
# # output: <|zh|>你好<|en|>Hello<|jp|>こんにちは<|yue|>你好<|ko|>안녕하세요
# # Define language patterns
# patterns = {
# 'zh': r'[\u4e00-\u9fff]+', # Chinese characters
# 'en': r'[a-zA-Z]+', # English letters
# 'jp': r'[\u3040-\u30ff\u31f0-\u31ff]+', # Japanese characters
# 'ko': r'[\uac00-\ud7a3]+', # Korean characters
# }
# # Find all matches
# matches = []
# for lang, pattern in patterns.items():
# for match in re.finditer(pattern, text):
# matches.append((match.start(), match.end(), lang, match.group()))
# # Sort matches by start position
# matches.sort(key=lambda x: x[0])
# # Build the result string
# result = []
# last_end = 0
# zh_count = 0
# for start, end, lang, content in matches:
# if start > last_end:
# result.append(text[last_end:start])
# if lang == 'zh':
# zh_count += 1
# if zh_count > 1:
# lang = 'yue'
# result.append(f'<|{lang}|>{content}')
# last_end = end
# if last_end < len(text):
# result.append(text[last_end:])
# return ''.join(result)

def update_mode(mode, sft_speaker, speaker_audio, voice_instructions):
    # Toggle input visibility to match the selected generation mode
    if mode == 'SFT':
        return (
            gr.update(  # sft_speaker
                visible=True,
            ),
            gr.update(  # speaker_audio
                visible=False,
            ),
            gr.update(  # voice_instructions
                visible=False,
            ),
        )
    elif mode == 'VC':
        return (
            gr.update(  # sft_speaker
                visible=False,
            ),
            gr.update(  # speaker_audio
                visible=True,
            ),
            gr.update(  # voice_instructions
                visible=True,
            ),
        )
    elif mode == 'VC-CrossLingual':
        return (
            gr.update(  # sft_speaker
                visible=False,
            ),
            gr.update(  # speaker_audio
                visible=True,
            ),
            gr.update(  # voice_instructions
                visible=False,
            ),
        )
    elif mode == 'Instruct':
        return (
            gr.update(  # sft_speaker
                visible=True,
            ),
            gr.update(  # speaker_audio
                visible=False,
            ),
            gr.update(  # voice_instructions
                visible=True,
            ),
        )
    else:
        raise gr.Error('Invalid mode')
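
# A sketch of how update_mode might be wired inside audio_tab(); the component
# names (mode, sft_speaker, speaker_audio, voice_instructions) are assumptions
# about the UI, not part of this module:
#
#   mode.change(
#       fn=update_mode,
#       inputs=[mode, sft_speaker, speaker_audio, voice_instructions],
#       outputs=[sft_speaker, speaker_audio, voice_instructions],
#   )
#
# The order of the returned gr.update() values must match this outputs list.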

@spaces.GPU(duration=10)
def clear_audio(audio: tuple):
    # Gradio delivers numpy audio as a (sample_rate, samples) tuple;
    # normalize integer samples to float32 in [-1, 1]
    sr, samples = audio
    if np.issubdtype(samples.dtype, np.integer):
        samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    waveform = torch.from_numpy(samples.astype(np.float32))
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # torchaudio expects (channels, samples)
    # Write the recording to a temporary WAV file so DeepFilterNet can load it
    audio_file = create_temp_file(suffix='.wav')
    torchaudio.save(audio_file.name, waveform, sr)
    # Load at the model's sample rate, denoise, and overwrite the temp file
    audio, _ = load_audio(audio_file.name, sr=df_state.sr())
    enhanced = enhance(df_model, df_state, audio)
    save_audio(audio_file.name, enhanced, df_state.sr())
    return gr.update(  # speaker_audio / output_audio
        value=audio_file.name,
    )
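
# A sketch of the denoise hookup, assuming a `clear_button` gr.Button and a
# `speaker_audio` gr.Audio (type="numpy") defined in audio_tab():
#
#   clear_button.click(
#       fn=clear_audio,
#       inputs=[speaker_audio],
#       outputs=[speaker_audio],
#   )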

@spaces.GPU(duration=20)
def gen_audio(text, mode, sft_speaker=None, speaker_audio=None, voice_instructions=None):
    if mode in ('VC', 'VC-CrossLingual'):
        if speaker_audio is None:
            raise gr.Error('Please upload an audio')
        # Save the uploaded speaker audio to a temporary WAV file and load it
        # as the 16 kHz voice-cloning prompt
        sr, samples = speaker_audio
        if np.issubdtype(samples.dtype, np.integer):
            samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
        waveform = torch.from_numpy(samples.astype(np.float32))
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # torchaudio expects (channels, samples)
        speaker_audio_file = create_temp_file(suffix='.wav')
        torchaudio.save(speaker_audio_file.name, waveform, sr)
        prompt_speech_16k = load_wav(speaker_audio_file.name, 16000)
    else:
        prompt_speech_16k = None
    # Assign language tags
    text = assign_language_tags(text)
    # Generate the audio
    out_file = create_temp_file(suffix='.wav')
    if mode == 'SFT':
        if not sft_speaker:
            raise gr.Error('Please select a speaker')
        # inference_sft yields one chunk per text segment; concatenate the
        # chunks along the time axis and save a single waveform
        chunks = [
            out['tts_speech']
            for out in cosyvoice_sft.inference_sft(
                tts_text=text,
                spk_id=sft_speaker,
            )
        ]
        torchaudio.save(out_file.name, torch.cat(chunks, dim=1), 22050)
    elif mode == 'VC':
        chunks = [
            out['tts_speech']
            for out in cosyvoice.inference_zero_shot(
                tts_text=text,
                prompt_text=voice_instructions,
                prompt_speech_16k=prompt_speech_16k,
            )
        ]
        torchaudio.save(out_file.name, torch.cat(chunks, dim=1), 22050)
    elif mode == 'VC-CrossLingual':
        chunks = [
            out['tts_speech']
            for out in cosyvoice.inference_cross_lingual(
                tts_text=text,
                prompt_speech_16k=prompt_speech_16k,
            )
        ]
        torchaudio.save(out_file.name, torch.cat(chunks, dim=1), 22050)
    elif mode == 'Instruct':
        if not sft_speaker:
            raise gr.Error('Please select a speaker')
        if not voice_instructions:
            raise gr.Error('Please enter voice instructions')
        chunks = [
            out['tts_speech']
            for out in cosyvoice_instruct.inference_instruct(
                tts_text=text,
                spk_id=sft_speaker,
                instruct_text=voice_instructions,
            )
        ]
        torchaudio.save(out_file.name, torch.cat(chunks, dim=1), 22050)
    else:
        raise gr.Error('Invalid mode')
    return gr.update(  # output_audio
        value=out_file.name,
    )
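
# A sketch of the generate-button hookup; the component names are assumptions:
#
#   generate_button.click(
#       fn=gen_audio,
#       inputs=[text, mode, sft_speaker, speaker_audio, voice_instructions],
#       outputs=[output_audio],
#   )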