import torch
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
from tts_voice import tts_order_voice
import edge_tts
import gradio as gr
import tempfile
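
# OpenVoice Gradio demo: clone a reference voice from a short clip and
# re-synthesize text with it, either via the English base speaker TTS with
# selectable emotion styles or via edge-tts for multilingual input. The
# BaseSpeakerTTS model generates the source speech and the ToneColorConverter
# transfers the timbre of the reference recording onto it. The nested
# 'checkpoints/checkpoints/...' paths below appear to reflect how the checkpoint
# archive was unpacked for this deployment, not a layout required by OpenVoice.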

ckpt_base_en = 'checkpoints/checkpoints/base_speakers/EN'
ckpt_converter_en = 'checkpoints/checkpoints/converter'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present

base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')

tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')

def vc_en(text, audio_ref, style_mode):
  # Pick the source speaker embedding, base speaker and speed for the requested style
  if style_mode == "default":
    source_se = torch.load(f'{ckpt_base_en}/en_default_se.pth').to(device)
    speaker, speed = 'default', 1.0
  else:
    source_se = torch.load(f'{ckpt_base_en}/en_style_se.pth').to(device)
    speaker, speed = style_mode, 0.9

  # Extract the tone color embedding of the uploaded reference voice
  target_se, audio_name = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)

  # Run the base speaker TTS
  src_path = "tmp.wav"
  base_speaker_tts.tts(text, src_path, speaker=speaker, language='English', speed=speed)

  # Run the tone color converter
  save_path = "output.wav"
  encode_message = "@MyShell"
  tone_color_converter.convert(
      audio_src_path=src_path,
      src_se=source_se,
      tgt_se=target_se,
      output_path=save_path,
      message=encode_message)

  return save_path

language_dict = tts_order_voice

base_speaker = "base_audio.mp3"
source_se, audio_name = se_extractor.get_se(base_speaker, tone_color_converter, vad=True)
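# NOTE: "base_audio.mp3" is assumed to sit next to this script; its speaker
# embedding is reused as the conversion source for every edge-tts voice below.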

async def text_to_speech_edge(text, audio_ref, language_code):
    # Synthesize the text with the edge-tts voice mapped to the selected language
    voice = language_dict[language_code]
    communicate = edge_tts.Communicate(text, voice)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    await communicate.save(tmp_path)

    # Extract the tone color embedding of the uploaded reference voice
    reference_speaker = audio_ref
    target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
    save_path = "output.wav"

    # Run the tone color converter
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=tmp_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message)

    return "output.wav"

app = gr.Blocks()

with app:
  gr.Markdown("# <center>🥳💕🎶 OpenVoice 3秒语音情感真实复刻</center>")
  gr.Markdown("## <center>🌟 只需3秒语音,一键复刻说话语气及情感,喜怒哀乐、应有尽有! </center>")
  gr.Markdown("### <center>🌊 更多精彩应用,敬请关注[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕</center>")
  with gr.Tab("💕语音情感合成"):
    with gr.Row():
      with gr.Column():
        inp1 = gr.Textbox(lines=3, label="Enter the English text you want to convert")
        inp2 = gr.Audio(label="Upload a reference voice clip you like", type="filepath")
        inp3 = gr.Dropdown(label="Choose a speech emotion", info="🙂default😊friendly🤫whispering😄cheerful😱terrified😡angry😢sad", choices=["default", "friendly", "whispering", "cheerful", "terrified", "angry", "sad"], value="default")

        btn1 = gr.Button("Clone the voice and emotion!", variant="primary")

      with gr.Column():
        out1 = gr.Audio(label="Your custom synthesized voice", type="filepath")
    btn1.click(vc_en, [inp1, inp2, inp3], out1)

  with gr.Tab("🌟多语言声音复刻"):
    with gr.Row():
      with gr.Column():
        inp4 = gr.Textbox(lines=3, label="Enter the text you want to convert")
        inp5 = gr.Audio(label="Upload a reference voice clip you like", type="filepath")
        inp6 = gr.Dropdown(choices=list(language_dict.keys()), value=list(language_dict.keys())[15], label="Select the language of the text")

        btn2 = gr.Button("Clone the voice and emotion!", variant="primary")

      with gr.Column():
        out2 = gr.Audio(label="Your custom synthesized voice", type="filepath")
    btn2.click(text_to_speech_edge, [inp4, inp5, inp6], out2)

    gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。Get your OpenAI API Key [here](https://platform.openai.com/api-keys).</center>")
    gr.HTML('''
        <div class="footer">
                    <p>🌊🏞️🎶 - "The river rushes eastward, its surging voice never ends." (顾璘, Ming dynasty)
                    </p>
        </div>
    ''')

app.launch(show_error=True)