# biogpt-testing / app.py
# Commit e45d82f by flash64: "Duplicate from kadirnar/BioGpt" (3.22 kB)
from transformers import pipeline, set_seed
from transformers import BioGptTokenizer, BioGptForCausalLM
from multilingual_translation import translate
from utils import lang_ids
import gradio as gr
import torch
# Model IDs offered in the "BioGPT Model ID" dropdown.
biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large",
    "microsoft/BioGPT-Large-PubMedQA",
]

# Model IDs offered in the "Language Model ID" dropdown (used for translation).
lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M",
]
# Language names (the keys of lang_ids) shown in the "Base Language" dropdown.
lang_list = list(lang_ids)
def translate_to_english(text, lang_model_id, base_lang):
    """Return ``text`` in English, translating it when ``base_lang`` is not English.

    Args:
        text: Prompt written in ``base_lang``.
        lang_model_id: Model ID forwarded to ``translate``.
        base_lang: Language name. ``"English"`` returns ``text`` untouched;
            any other value is looked up in ``lang_ids`` to obtain the
            source-language code passed to ``translate``.

    Returns:
        ``text`` itself for English input, otherwise the first element of the
        value returned by ``translate``.
    """
    # English prompts skip the translation model entirely.
    if base_lang == "English":
        return text

    source_code = lang_ids[base_lang]
    translated = translate(lang_model_id, text, source_code, "en")
    return translated[0]
def biogpt(
    prompt: str,
    biogpt_model_id: str,
    max_length: int,
    num_return_sequences: int,
    base_lang: str,
    lang_model_id: str,
):
    """Generate text with a BioGPT model, translating to and from ``base_lang``.

    The prompt is converted to English with ``translate_to_english``, passed to
    a ``text-generation`` pipeline built for ``biogpt_model_id``, and every
    sampled continuation is translated back to ``base_lang`` unless that
    language is English.

    Args:
        prompt: Text to continue, written in ``base_lang``.
        biogpt_model_id: Model ID handed to ``pipeline`` as ``model``.
        max_length: ``max_length`` argument forwarded to the generator.
        num_return_sequences: Number of continuations to sample.
        base_lang: Language name of the prompt; ``"English"`` disables
            translation, other values are looked up in ``lang_ids``.
        lang_model_id: Model ID forwarded to ``translate``.

    Returns:
        A 3-tuple ``(en_prompt, output_text, base_lang_output)``: the English
        prompt, the generated texts each followed by a blank line, and the
        same texts translated to ``base_lang`` (identical to ``output_text``
        when ``base_lang`` is ``"English"``).
    """
    # NOTE(review): UI sliders may hand these over as floats depending on the
    # Gradio version -- coerce so the generator always receives integers.
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)

    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)

    # Use the first GPU when one is present; otherwise run on the CPU rather
    # than failing on hosts without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # The pipeline is built on every call, so the selected model is loaded
    # per request.
    generator = pipeline("text-generation", model=biogpt_model_id, device=device)
    output = generator(
        en_prompt,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        do_sample=True,
    )
    generated_texts = [candidate["generated_text"] for candidate in output]

    # Each generated text is followed by a blank line in the displayed output.
    output_text = "".join(f"{text}\n\n" for text in generated_texts)

    if base_lang == "English":
        base_lang_output = output_text
    else:
        target_code = lang_ids[base_lang]
        translated_texts = [
            translate(lang_model_id, text, "en", target_code)[0]
            for text in generated_texts
        ]
        base_lang_output = "".join(f"{text}\n\n" for text in translated_texts)

    return en_prompt, output_text, base_lang_output
# Input widgets, listed in the same order as the parameters of `biogpt`.
inputs = [
    gr.Textbox(lines=5, value="COVID-19 is", label="Prompt"),
    gr.Dropdown(biogpt_model_list, value="microsoft/biogpt", label="BioGPT Model ID"),
    # Both sliders are bounded below by 1: a length or output count of 0
    # gives the generator nothing to produce.
    gr.Slider(minimum=1, maximum=100, value=25, step=1, label="Max Length"),
    gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of Outputs"),
    gr.Dropdown(lang_list, value="English", label="Base Language"),
    gr.Dropdown(lang_model_list, value="facebook/m2m100_418M", label="Language Model ID"),
]
# Output widgets, in the order of the tuple returned by `biogpt`:
# English prompt, generated text, generated text in the base language.
# Uses the same top-level `gr.Textbox` component as the inputs instead of the
# legacy `gr.outputs` namespace.
outputs = [
    gr.Textbox(label="Prompt"),
    gr.Textbox(label="Output"),
    gr.Textbox(label="Translated Output"),
]
# Example rows; each row supplies one value per input widget, in widget order.
examples = [
    ["COVID-19 is", "microsoft/biogpt", 25, 2, "English", "facebook/m2m100_418M"],
    ["Kanser", "microsoft/biogpt", 25, 2, "Turkish", "facebook/m2m100_1.2B"],
]

# Heading and blurb passed to the interface.
title = (
    "M2M100 + BioGPT: Generative Pre-trained Transformer for Biomedical "
    "Text Generation and Mining"
)
description = (
    "BioGPT is a domain-specific generative pre-trained Transformer language "
    "model for biomedical text generation and mining. BioGPT follows the "
    "Transformer language model backbone, and is pre-trained on 15M PubMed "
    "abstracts from scratch. Github: github.com/microsoft/BioGPT "
    "Paper: https://arxiv.org/abs/2210.10341"
)
# Wire the `biogpt` function to its input/output widgets.
demo_app = gr.Interface(
    fn=biogpt,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)

# NOTE(review): `enable_queue` appears to be deprecated in later Gradio
# releases in favour of `Interface.queue()` -- confirm against the pinned
# Gradio version before changing it.
demo_app.launch(debug=True, enable_queue=True)