"""Gradio demo: Whisper speech-to-text + M2M100 translation + BioGPT generation."""
import functools

import gradio as gr
import torch
from transformers import pipeline

from multilingual_translation import text_to_text_generation
from utils import lang_ids

biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
    "microsoft/BioGPT-Large-PubMedQA",
]

lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M",
]

whisper_model_list = [
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-tiny",
    "openai/whisper-large",
]

lang_list = list(lang_ids.keys())

# Use the first GPU when one is visible; otherwise fall back to CPU so the app
# can still start on CPU-only hosts instead of failing on 'cuda:0'.
# NOTE(review): assumes text_to_text_generation accepts "cpu" as well as
# "cuda:0" for its device argument — confirm against multilingual_translation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Number of loaded pipelines of each kind kept alive between requests.
# Reloading a model from disk on every click dominates latency; keeping one
# model per kind bounds the extra memory held. Set to 0 to disable caching.
PIPELINE_CACHE_SIZE = 1


@functools.lru_cache(maxsize=PIPELINE_CACHE_SIZE)
def _asr_pipeline(model_id):
    """Load (or reuse) a Whisper ASR pipeline configured for English transcription."""
    pipe = pipeline(task="automatic-speech-recognition", model=model_id, device=DEVICE)
    # Force English transcription. Done once at load time because the cached
    # pipeline object is reused by later requests.
    pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
        language='en', task="transcribe"
    )
    return pipe


@functools.lru_cache(maxsize=PIPELINE_CACHE_SIZE)
def _text_generation_pipeline(model_id):
    """Load (or reuse) a BioGPT text-generation pipeline."""
    return pipeline("text-generation", model=model_id, device=DEVICE)


def _first_text(result):
    """Return the first decoded string from a ``text_to_text_generation`` result.

    The helper's result is indexed with ``[0]`` before being used as a prompt,
    which indicates a sequence of strings; a bare string is passed through
    unchanged rather than being truncated to its first character.
    """
    if isinstance(result, str):
        return result
    return result[0]


def whisper_demo(input_audio, model_id):
    """Transcribe the audio file at ``input_audio`` to English text with Whisper ``model_id``."""
    pipe = _asr_pipeline(model_id)
    return pipe(input_audio)['text']


def translate_to_english(prompt, lang_model_id, base_lang):
    """Translate ``prompt`` to English with M2M100, or return it as-is if already English."""
    if base_lang == "English":
        return prompt
    output_text = text_to_text_generation(
        prompt=prompt,
        model_id=lang_model_id,
        device=DEVICE,
        target_lang='en',
    )
    return _first_text(output_text)


def _generate_and_translate(en_prompt, biogpt_model_id, lang_model_id, base_lang):
    """Continue ``en_prompt`` with BioGPT and translate the result back to ``base_lang``.

    Returns ``(en_prompt, biogpt_output, translated_output)``, matching the three
    output textboxes of the UI.
    """
    generator = _text_generation_pipeline(biogpt_model_id)
    output = generator(en_prompt, max_length=250, num_return_sequences=1, do_sample=True)
    output = output[0]['generated_text']
    if base_lang == "English":
        output_text = output
    else:
        output_text = _first_text(
            text_to_text_generation(
                prompt=output,
                model_id=lang_model_id,
                device=DEVICE,
                target_lang=lang_ids[base_lang],
            )
        )
    return en_prompt, output, output_text


def biogpt_text(
    prompt: str,
    biogpt_model_id: str,
    lang_model_id: str,
    base_lang: str,
):
    """Text tab handler: translate the prompt to English, generate, translate back."""
    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)
    return _generate_and_translate(en_prompt, biogpt_model_id, lang_model_id, base_lang)


def biogpt_audio(
    input_audio: str,
    biogpt_model_id: str,
    whisper_model_id: str,
    base_lang: str,
    lang_model_id: str,
):
    """Audio tab handler: transcribe with Whisper, generate, translate back."""
    en_prompt = whisper_demo(input_audio=input_audio, model_id=whisper_model_id)
    return _generate_and_translate(en_prompt, biogpt_model_id, lang_model_id, base_lang)


examples = [["COVID-19 is", biogpt_model_list[0], lang_model_list[1], "English"]]

app = gr.Blocks()
with app:
    gr.Markdown(
        "# **Whisper + M2M100 + BioGPT: Generative Pre-trained Transformer "
        "for Biomedical Text Generation and Mining**"
    )
    # NOTE(review): this header originally held social links whose markup was
    # lost; only the labels survived. Restore the URLs here.
    gr.Markdown("# **twitter | github | linkedin |**")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Text"):
                input_text = gr.Textbox(lines=3, value="COVID-19 is", label="Text")
                text_biogpt = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                text_m2m100 = gr.Dropdown(choices=lang_model_list, value=lang_model_list[1], label='Language Model')
                text_lang = gr.Dropdown(lang_list, value="English", label="Base Language")
                text_button = gr.Button(value="Predict")

            with gr.Tab("Audio"):
                input_audio = gr.Audio(source="microphone", type="filepath", label='Audio')
                audio_biogpt = gr.Dropdown(choices=biogpt_model_list, value=biogpt_model_list[0], label='BioGpt Model')
                audio_whisper = gr.Dropdown(choices=whisper_model_list, value=whisper_model_list[0], label='Audio Model')
                audio_lang = gr.Dropdown(lang_list, value="English", label="Base Language")
                audio_m2m100 = gr.Dropdown(choices=lang_model_list, value=lang_model_list[1], label='Language Model')
                audio_button = gr.Button(value="Predict")

        with gr.Column():
            prompt_text = gr.Textbox(lines=3, label="Prompt")
            output_text = gr.Textbox(lines=3, label="BioGpt Text")
            translated_text = gr.Textbox(lines=3, label="Translated Text")

    gr.Examples(
        examples,
        inputs=[input_text, text_biogpt, text_m2m100, text_lang],
        outputs=[prompt_text, output_text, translated_text],
        fn=biogpt_text,
        cache_examples=False,
    )

    text_button.click(
        biogpt_text,
        inputs=[input_text, text_biogpt, text_m2m100, text_lang],
        outputs=[prompt_text, output_text, translated_text],
    )
    audio_button.click(
        biogpt_audio,
        inputs=[input_audio, audio_biogpt, audio_whisper, audio_lang, audio_m2m100],
        outputs=[prompt_text, output_text, translated_text],
    )

app.launch()