|
import time |
|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
import os |
|
from pathlib import Path |
|
from FastT5 import get_onnx_runtime_sessions, OnnxT5 |
|
|
|
|
|
# Directory containing the quantized T5 question-generation model (ONNX
# graphs) and the tokenizer files.
trained_model_path = './t5_squad_v1/'

# File-name prefix of the ONNX graphs: the directory's own name. Path()
# ignores the trailing slash, so this evaluates to "t5_squad_v1".
pretrained_model_name = Path(trained_model_path).stem

# One quantized graph per component, built from a shared prefix and a
# component-specific suffix, unpacked in encoder / decoder / init-decoder order.
encoder_path, decoder_path, init_decoder_path = (
    os.path.join(trained_model_path, pretrained_model_name + file_suffix)
    for file_suffix in (
        "-encoder_quantized.onnx",
        "-decoder_quantized.onnx",
        "-init-decoder_quantized.onnx",
    )
)

# Tuple handed to get_onnx_runtime_sessions (same order as above).
model_paths = (encoder_path, decoder_path, init_decoder_path)
|
# Create the runtime sessions for the three ONNX graphs listed in
# model_paths, then wrap them in FastT5's OnnxT5 so generation can be driven
# through a single model object.
model_sessions = get_onnx_runtime_sessions(model_paths)
model = OnnxT5(trained_model_path, model_sessions)

# The tokenizer is read from the same directory as the ONNX graphs.
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
|
|
|
|
|
def get_question(sentence, answer, mdl, tknizer):
    """Generate a question whose answer is *answer*, using *sentence* as context.

    The input prompt is ``"context: <sentence> answer: <answer>"``. It is
    tokenized (truncated to 256 tokens), passed to ``mdl.generate`` with beam
    search, and the first generated sequence is decoded. Any ``"question:"``
    tag is removed, and surrounding whitespace is stripped from the result.

    Args:
        sentence: Context paragraph the question should be about.
        answer: Answer/keyword the generated question should target.
        mdl: Model exposing ``generate(...)``; in this script, the FastT5
            ``OnnxT5`` instance.
        tknizer: Tokenizer matching *mdl*; in this script, the
            ``AutoTokenizer`` loaded from the model directory.

    Returns:
        The generated question text, with no ``"question:"`` tag and no
        leading or trailing whitespace.
    """
    text = "context: {} answer: {}".format(sentence, answer)
    print(text)

    max_len = 256

    # padding=False is the non-deprecated spelling of pad_to_max_length=False:
    # no padding is applied, and inputs longer than max_len are truncated.
    encoding = tknizer.encode_plus(
        text, max_length=max_len, padding=False, truncation=True, return_tensors="pt")
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

    outs = mdl.generate(input_ids=input_ids,
                        attention_mask=attention_mask,
                        early_stopping=True,
                        num_beams=5,
                        num_return_sequences=1,
                        no_repeat_ngram_size=2,
                        max_length=300)

    # Only one sequence is requested (num_return_sequences=1), so decode just
    # the first one.
    decoded = tknizer.decode(outs[0], skip_special_tokens=True)

    # Remove the "question:" tag and return the *stripped* text. (The stripped
    # value must actually be the one returned.)
    return decoded.replace("question:", "").strip()
|
|
|
|
|
|
|
|
|
# Start-up smoke test: run one generation on a fixed example at import time,
# before the Gradio interface is built, and print the result.
# (`context` and `answer` are rebound to Gradio textboxes further down.)
context = "Donald Trump is an American media personality and businessman who served as the 45th president of the United States."
answer = "Donald Trump"
ques = get_question(sentence=context, answer=answer, mdl=model, tknizer=tokenizer)
print("question: ", ques)
|
|
|
|
|
# Gradio widgets: two input boxes (context paragraph and answer/keyword) and
# one output box that displays the generated question.
context = gr.Textbox(lines=5, placeholder="Enter paragraph/context here...")
answer = gr.Textbox(lines=3, placeholder="Enter answer/keyword here...")
question = gr.Textbox(label="Question", type="text")
|
|
|
|
|
def generate_question(context, answer):
    """Gradio callback: generate a question for *answer* within *context*.

    Delegates to ``get_question`` with the module-level ``model`` and
    ``tokenizer``, and prints how long the call took.

    Args:
        context: Paragraph entered in the context textbox.
        answer: Answer/keyword entered in the answer textbox.

    Returns:
        The generated question string, shown in the output textbox.
    """
    # perf_counter() is monotonic and high-resolution, so it suits
    # elapsed-time measurement; time.time() can jump when the system clock
    # is adjusted.
    start_time = time.perf_counter()
    result = get_question(context, answer, model, tokenizer)
    latency = time.perf_counter() - start_time
    print(f"Latency: {latency} seconds")
    return result
|
|
|
|
|
# Connect the callback to its widgets (fn, inputs, outputs) and start the
# web UI. share=True asks Gradio to also create a public share link.
iface = gr.Interface(generate_question, [context, answer], question)
iface.launch(share=True)
|
|