from transformers import pipeline, set_seed
from transformers import BioGptTokenizer, BioGptForCausalLM
from multilingual_translation import translate
from utils import lang_ids
import gradio as gr

biogpt_model_list = [
    "microsoft/biogpt",
    "microsoft/BioGPT-Large-PubMedQA"
]

lang_model_list = [
    "facebook/m2m100_1.2B",
    "facebook/m2m100_418M"
]

lang_list = list(lang_ids.keys())

def translate_to_english(text, lang_model_id, base_lang):
    """Translate `text` from `base_lang` into English with M2M100; English input is passed through unchanged."""
    if base_lang == "English":
        return text
    base_lang = lang_ids[base_lang]
    new_text = translate(lang_model_id, text, base_lang, "en")
    return new_text[0]
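
# A minimal, commented-out sketch of using the helper on its own; "Turkish" and
# the prompt "Kanser" are taken from the examples further down and are assumed
# to be a valid `lang_ids` key and a Turkish input, respectively:
#
#   english_text = translate_to_english("Kanser", "facebook/m2m100_418M", "Turkish")
#   print(english_text)  # expected to be the English translation ("Cancer")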

def biogpt(
    prompt: str,
    biogpt_model_id: str,
    max_length: int,
    num_return_sequences: int,
    base_lang: str,
    lang_model_id: str
):
    """Translate the prompt to English, generate continuations with BioGPT, and translate the outputs back to the base language."""
    en_prompt = translate_to_english(prompt, lang_model_id, base_lang)

    # Load the selected BioGPT checkpoint and build a text-generation pipeline.
    model = BioGptForCausalLM.from_pretrained(biogpt_model_id)
    tokenizer = BioGptTokenizer.from_pretrained(biogpt_model_id)
    generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    set_seed(42)

    # Slider values arrive as numbers; make sure they are integers before use.
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)

    output = generator(en_prompt, max_length=max_length, num_return_sequences=num_return_sequences, do_sample=True)
    output_dict = {}
    for i in range(num_return_sequences):
        output_dict[str(i + 1)] = output[i]['generated_text']

    output_text = ""
    for i in range(num_return_sequences):
        output_text += f'{output_dict[str(i + 1)]}\n\n'

    if base_lang == "English":
        base_lang_output = output_text
    else:
        # Translate each generated sequence back into the base language.
        base_lang_output = ""
        for i in range(num_return_sequences):
            base_lang_output += f'{translate(lang_model_id, output_dict[str(i + 1)], "en", lang_ids[base_lang])[0]}\n\n'

    return en_prompt, output_text, base_lang_output
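
# A hedged, commented-out sketch of calling the generation function outside the
# Gradio UI; the argument values mirror the first example below and are
# illustrative, not required by the app:
#
#   prompt_en, generations, translated = biogpt(
#       "COVID-19 is",            # prompt
#       "microsoft/biogpt",       # biogpt_model_id
#       25,                       # max_length
#       2,                        # num_return_sequences
#       "English",                # base_lang (English skips translation)
#       "facebook/m2m100_418M",   # lang_model_id (unused when base_lang is English)
#   )
#   print(generations)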


inputs = [
    gr.Textbox(lines=5, value="COVID-19 is", label="Prompt"),
    gr.Dropdown(biogpt_model_list, value="microsoft/biogpt", label="BioGPT Model ID"),
    gr.Slider(minimum=1, maximum=100, value=25, step=1, label="Max Length"),
    gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of Outputs"),
    gr.Dropdown(lang_list, value="English", label="Base Language"),
    gr.Dropdown(lang_model_list, value="facebook/m2m100_418M", label="Language Model ID")
]

outputs = [
    gr.Textbox(label="Prompt"),
    gr.Textbox(label="Output"),
    gr.Textbox(label="Translated Output")
]

examples = [
    ["COVID-19 is", "microsoft/biogpt", 25, 2, "English", "facebook/m2m100_418M"],
    ["Kanser", "microsoft/biogpt", 25, 2, "Turkish", "facebook/m2m100_1.2B"]
]

title = "M2M100 + BioGPT: Generative Pre-trained Transformer for Biomedical Text Generation and Mining"

description = "BioGPT is a domain-specific generative pre-trained Transformer language model for biomedical text generation and mining. It follows the Transformer language-model backbone and is pre-trained from scratch on 15M PubMed abstracts. GitHub: github.com/microsoft/BioGPT Paper: https://arxiv.org/abs/2210.10341"

demo_app = gr.Interface(
    biogpt, 
    inputs, 
    outputs, 
    title=title, 
    description=description,
    examples=examples, 
    cache_examples=True,
)
demo_app.launch(debug=True, enable_queue=True)