from transformers import pipeline
from multilingual_translation import text_to_text_generation
from utils import lang_ids
import gradio as gr

# Hugging Face Hub checkpoints offered in the "BioGpt Model" dropdown.
# Order matters: index 0 is the dropdown default and is used in the example row.
biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
    "microsoft/BioGPT-Large-PubMedQA",
]

# M2M100 checkpoints used for translation; index 1 is the dropdown default.
lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M",
]

# Whisper checkpoints available for speech recognition.
whisper_model_list = [
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-tiny",
    "openai/whisper-large",
]

# Language names for the "Base Language" dropdown: the keys of ``utils.lang_ids``.
lang_list = [*lang_ids.keys()]

def whisper_demo(input_audio, model_id):
    """Transcribe an audio clip to English text with a Whisper checkpoint.

    Args:
        input_audio: Audio handed straight to the ASR pipeline (presumably the
            file path produced by ``gr.Microphone`` -- confirm).
        model_id: Hugging Face model id of the Whisper checkpoint.

    Returns:
        The transcription taken from the pipeline result's ``'text'`` field.

    NOTE(review): the pipeline (and model weights) is built on every call.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=model_id,
        device='cuda:0',
    )
    # Pin the decoder prompt to English transcription; the result is used
    # as the English BioGPT prompt by the caller.
    english_prompt_ids = asr.tokenizer.get_decoder_prompt_ids(language='en', task="transcribe")
    asr.model.config.forced_decoder_ids = english_prompt_ids

    transcription = asr(input_audio)
    return transcription['text']
    
    
def translate_to_english(prompt, lang_model_id, base_lang):
    """Translate ``prompt`` into English using an M2M100 checkpoint.

    Args:
        prompt: Input text written in ``base_lang``.
        lang_model_id: Hugging Face model id of the M2M100 checkpoint.
        base_lang: Language name chosen in the UI (a key of ``utils.lang_ids``);
            ``"English"`` skips translation entirely.

    Returns:
        The English text as a single string.
    """
    if base_lang == "English":
        return prompt

    # NOTE(review): no source language is passed; confirm how
    # text_to_text_generation determines the language of ``prompt``.
    translated = text_to_text_generation(
        prompt=prompt,
        model_id=lang_model_id,
        device='cuda:0',
        target_lang='en',
    )
    # The return type of text_to_text_generation is not visible from this file
    # (this function indexes it, biogpt_text does not). Accept either a batch
    # of strings or a plain string so a str result is never truncated to its
    # first character.
    if isinstance(translated, (list, tuple)):
        return translated[0]
    return translated


def biogpt_text(
    prompt: str,
    biogpt_model_id: str,
    lang_model_id: str,
    base_lang: str,
    max_length: int = 250,
):
    """Generate biomedical text with BioGPT, translating in and out of ``base_lang``.

    Args:
        prompt: User prompt written in ``base_lang``.
        biogpt_model_id: Hugging Face model id of the BioGPT checkpoint.
        lang_model_id: Hugging Face model id of the M2M100 translation checkpoint.
        base_lang: Language name chosen in the UI; ``"English"`` disables
            both translation steps.
        max_length: ``max_length`` passed to the text-generation pipeline
            (default 250).

    Returns:
        Tuple ``(en_prompt, output, output_text)``: the English prompt fed to
        BioGPT, BioGPT's English continuation, and that continuation
        translated to ``base_lang`` (equal to ``output`` for English).
    """
    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)

    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    generations = generator(en_prompt, max_length=max_length, num_return_sequences=1, do_sample=True)
    output = generations[0]['generated_text']

    if base_lang == "English":
        return en_prompt, output, output

    translated = text_to_text_generation(
        prompt=output,
        model_id=lang_model_id,
        device='cuda:0',
        # NOTE(review): a language *name* is passed here, whereas
        # translate_to_english passes the code 'en'; confirm which form
        # text_to_text_generation expects.
        target_lang=base_lang,
    )
    # translate_to_english indexes this helper's result with [0], implying a
    # batch; normalize to one string so the textbox does not show a list repr.
    if isinstance(translated, (list, tuple)):
        output_text = translated[0]
    else:
        output_text = translated

    return en_prompt, output, output_text


def biogpt_audio(
    input_audio: str,
    biogpt_model_id: str,
    whisper_model_id: str,
    max_length: int,
    num_return_sequences: int
):
    """Transcribe speech with Whisper, then continue the transcript with BioGPT.

    Args:
        input_audio: Audio forwarded to ``whisper_demo``.
        biogpt_model_id: Hugging Face model id of the BioGPT checkpoint.
        whisper_model_id: Hugging Face model id of the Whisper checkpoint.
        max_length: ``max_length`` for the text-generation pipeline. Coerced
            to ``int`` because UI widgets may deliver floats or strings.
        num_return_sequences: Number of sampled continuations. Coerced to
            ``int`` for the same reason (it is also used as a count).

    Returns:
        ``(en_prompt, output_text, output_text)``: the transcript, then all
        generated continuations joined together, each followed by a blank
        line. The joined text fills both the second and third positions; no
        translation is performed on this path.
    """
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)

    en_prompt = whisper_demo(input_audio=input_audio, model_id=whisper_model_id)
    generator = pipeline("text-generation", model=biogpt_model_id, device="cuda:0")
    generations = generator(
        en_prompt,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        do_sample=True,
    )

    output_text = "".join(
        f"{generation['generated_text']}\n\n"
        for generation in generations[:num_return_sequences]
    )
    return en_prompt, output_text, output_text

# Example row for the Text tab, in biogpt_text's positional input order.
# cache_examples=True below makes Gradio precompute its output.
examples = [
    ["COVID-19 is", biogpt_model_list[0], lang_model_list[1], "English"]
]

app = gr.Blocks()
with app:
    gr.Markdown("# **<p align='center'>Whisper + M2M100 + BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining</p>**")
    gr.Markdown(
        """
        <p style='text-align: center'>
        Follow me for more! 
        <br> <a href='https://twitter.com/kadirnar_ai' target='_blank'>twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>linkedin</a> | 
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text"):
                input_text = gr.Textbox(lines=3, value="COVID-19 is", label="Text")
                input_text_button = gr.Button(value="Predict")
                input_biogpt_model = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                input_m2m100_model = gr.Dropdown(choices=lang_model_list, value=lang_model_list[1], label='Language Model')
                input_base_lang = gr.Dropdown(lang_list, value="English", label="Base Language")

            with gr.Tab("Audio"):
                input_audio = gr.Microphone(label='Audio')
                input_audio_button = gr.Button(value="Predict")
                # biogpt_audio needs a Whisper checkpoint and generation settings,
                # not the translation model / base language used by the Text tab.
                input_audio_biogpt_model = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                input_whisper_model = gr.Dropdown(choices=whisper_model_list, value=whisper_model_list[0], label='Whisper Model')
                input_max_length = gr.Slider(minimum=50, maximum=1024, value=250, step=1, label='Max Length')
                input_num_return_sequences = gr.Slider(minimum=1, maximum=5, value=1, step=1, label='Number of Return Sequences')

        with gr.Column():
            prompt_text = gr.Textbox(lines=3, label="Prompt")
            output_text = gr.Textbox(lines=3, label="BioGpt Text")
            translated_text = gr.Textbox(lines=3, label="Translated Text")

    gr.Examples(
        examples,
        inputs=[input_text, input_biogpt_model, input_m2m100_model, input_base_lang],
        outputs=[prompt_text, output_text, translated_text],
        fn=biogpt_text,
        cache_examples=True,
    )
    input_text_button.click(
        biogpt_text,
        inputs=[input_text, input_biogpt_model, input_m2m100_model, input_base_lang],
        outputs=[prompt_text, output_text, translated_text],
    )
    # Inputs map positionally onto biogpt_audio(input_audio, biogpt_model_id,
    # whisper_model_id, max_length, num_return_sequences).
    input_audio_button.click(
        biogpt_audio,
        inputs=[
            input_audio,
            input_audio_biogpt_model,
            input_whisper_model,
            input_max_length,
            input_num_return_sequences,
        ],
        outputs=[prompt_text, output_text, translated_text],
    )

app.launch()