# Spaces: Sleeping
from transformers import pipeline, set_seed | |
from transformers import BioGptTokenizer, BioGptForCausalLM | |
from multilingual_translation import translate | |
from utils import lang_ids | |
import gradio as gr | |
# BioGPT checkpoints selectable in the "BioGPT Model ID" dropdown.
biogpt_model_list = ["microsoft/biogpt", "microsoft/BioGPT-Large-PubMedQA"]

# M2M100 checkpoints selectable in the "Language Model ID" dropdown.
lang_model_list = ["facebook/m2m100_1.2B", "facebook/m2m100_418M"]
# Language names (the keys of lang_ids) offered in the "Base Language" dropdown.
lang_list = list(lang_ids)
def translate_to_english(text, lang_model_id, base_lang):
    """Return ``text`` in English.

    Args:
        text: The text to convert.
        lang_model_id: Identifier of the translation model handed to
            ``translate``.
        base_lang: Name of the language ``text`` is written in; must be
            ``"English"`` or a key of ``lang_ids``.

    Returns:
        ``text`` unchanged when ``base_lang`` is ``"English"``; otherwise the
        first element of the result returned by ``translate``.
    """
    if base_lang != "English":
        # Map the human-readable language name to the id `translate` expects.
        source_lang_id = lang_ids[base_lang]
        translations = translate(lang_model_id, text, source_lang_id, "en")
        return translations[0]
    return text
def biogpt(
    prompt: str,
    biogpt_model_id: str,
    max_length: int,
    num_return_sequences: int,
    base_lang: str,
    lang_model_id: str,
):
    """Generate biomedical text from a prompt, translating in and out of English.

    The prompt is first converted to English with ``translate_to_english``,
    then fed to a ``text-generation`` pipeline built from the selected BioGPT
    checkpoint, and finally each generated sequence is translated back to
    ``base_lang`` (unless that is already English).

    Args:
        prompt: Prompt text, written in ``base_lang``.
        biogpt_model_id: Checkpoint id passed to ``from_pretrained`` for both
            the model and the tokenizer.
        max_length: ``max_length`` forwarded to the generation pipeline.
        num_return_sequences: Number of sequences to sample.
        base_lang: Language name of the prompt; ``"English"`` or a key of
            ``lang_ids``.
        lang_model_id: Translation model id forwarded to ``translate``.

    Returns:
        A 3-tuple ``(en_prompt, output_text, base_lang_output)``: the English
        prompt, the English generations, and the generations in ``base_lang``.
        In both text fields every generation is followed by a blank line.
    """
    # range()-style consumers and the generation arguments need real ints;
    # coerce so whole-number floats (or numeric strings) from a caller such as
    # a UI widget do not raise further down.
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)

    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)

    # NOTE: the model and tokenizer are loaded on every call.
    model = BioGptForCausalLM.from_pretrained(biogpt_model_id)
    tokenizer = BioGptTokenizer.from_pretrained(biogpt_model_id)
    generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

    # Fixed seed so sampling is reproducible for identical inputs.
    set_seed(42)
    output = generator(
        en_prompt,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        do_sample=True,
    )
    generated_texts = [candidate['generated_text'] for candidate in output]
    output_text = "".join(f'{text}\n\n' for text in generated_texts)

    if base_lang == "English":
        base_lang_output = output_text
    else:
        # Resolve the target language id once instead of once per sequence.
        target_lang = lang_ids[base_lang]
        base_lang_output = "".join(
            f'{translate(lang_model_id, text, "en", target_lang)[0]}\n\n'
            for text in generated_texts
        )

    return en_prompt, output_text, base_lang_output
# Input widgets, listed in the same order as the parameters of biogpt().
inputs = [
    gr.Textbox(lines=5, value="COVID-19 is", label="Prompt"),
    gr.Dropdown(biogpt_model_list, value="microsoft/biogpt", label="BioGPT Model ID"),
    # The keyword is `minimum`; it was misspelled `minumum`, so the intended
    # lower bound of 1 was never passed under the name the slider reads.
    gr.Slider(minimum=1, maximum=100, value=25, step=1, label="Max Length"),
    gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of Outputs"),
    gr.Dropdown(lang_list, value="English", label="Base Language"),
    gr.Dropdown(lang_model_list, value="facebook/m2m100_418M", label="Language Model ID"),
]
# Output widgets, in the order of the tuple returned by biogpt():
# (English prompt, English generations, generations in the base language).
# Uses gr.Textbox, the same component class as the inputs, rather than the
# deprecated gr.outputs namespace.
outputs = [
    gr.Textbox(label="Prompt"),
    gr.Textbox(label="Output"),
    gr.Textbox(label="Translated Output"),
]
# Example rows; each row supplies one value per input widget, in widget order.
examples = [
    ["COVID-19 is", "microsoft/biogpt", 25, 2, "English", "facebook/m2m100_418M"],
    ["Kanser", "microsoft/biogpt", 25, 2, "Turkish", "facebook/m2m100_1.2B"],
]

# Heading and blurb shown at the top of the interface.
title = (
    "M2M100 + BioGPT: Generative Pre-trained Transformer "
    "for Biomedical Text Generation and Mining"
)
description = (
    "BioGPT is a domain-specific generative pre-trained Transformer language "
    "model for biomedical text generation and mining. "
    "BioGPT follows the Transformer language model backbone, and is "
    "pre-trained on 15M PubMed abstracts from scratch. "
    "Github: github.com/microsoft/BioGPT "
    "Paper: https://arxiv.org/abs/2210.10341"
)
# Wire biogpt() to the input/output widgets. With cache_examples=True the
# example rows are run through biogpt() so their outputs can be served from
# cache.
demo_app = gr.Interface(
    fn=biogpt,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    cache_examples=True,
)

# Enable request queuing through .queue(); the `enable_queue` argument of
# launch() is deprecated in favour of this method.
demo_app.queue()
demo_app.launch(debug=True)