|
import gradio as gr |
|
import torch |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
from transformers import ( |
|
MT5ForConditionalGeneration, |
|
MT5TokenizerFast, |
|
) |
|
|
|
model = MT5ForConditionalGeneration.from_pretrained( |
|
"minjibi/qa", |
|
return_dict=True, |
|
) |
|
tokenizer = MT5TokenizerFast.from_pretrained( |
|
"minjibi/qa" |
|
) |
|
|
|
def predict(text): |
|
input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True) |
|
generated_ids = model.generate( |
|
input_ids=input_ids, |
|
num_beams=5, |
|
max_length=1000, |
|
repetition_penalty=3.0, |
|
length_penalty=1.0, |
|
early_stopping=True, |
|
top_p=50, |
|
top_k=20, |
|
num_return_sequences=3, |
|
) |
|
|
|
preds = [ |
|
tokenizer.decode( |
|
g, |
|
skip_special_tokens=True, |
|
clean_up_tokenization_spaces=True, |
|
) |
|
for g in generated_ids |
|
] |
|
|
|
output = [text.replace('A', 'Answer') for text in preds] |
|
|
|
final_str = '\n'.join([f"{i+1}. Question: {s.split('Answer')[0].strip()}\n Answer{s.split('Answer')[1].strip()}" for i, s in enumerate(output)]) |
|
|
|
return final_str |
|
examples = [ |
|
["ไพธอนเป็นภาษาการเขียนโปรแกรมที่มีการตีความระดับสูง ภาษาคอมพิวเตอร์นี้สร้างโดย Guido van Rossum และเปิดตัวครั้งแรกในปี 1991"], |
|
["แมว ชื่อวิทยาศาสตร์ (Felis catus) เป็นสปีชีส์สัตว์เลี้ยงของสัตว์เลี้ยงลูกด้วยนมกินเนื้อขนาดเล็ก โดยเป็นแมวสปีชีส์เดียวในวงศ์ Felidae ที่ถูกปรับเป็นสัตว์เลี้ยง และมักเรียกเป็น แมวบ้าน เพื่อแยกมันจากสมาชิกที่อยู่ในป่า"], |
|
] |
|
|
|
|
|
iface = gr.Interface(fn=predict, inputs="text", outputs="text", examples=examples) |
|
iface.launch() |