# questgen/app.py
import os
import time
from pathlib import Path

import gradio as gr
from transformers import AutoTokenizer

from FastT5 import get_onnx_runtime_sessions, OnnxT5
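
# NOTE: `FastT5` is assumed to be a local module exposing the fastT5-style
# ONNX helpers (the pip package is `fastt5`). fastT5 splits T5 into three
# ONNX graphs, matching the filenames below: the encoder, an init-decoder
# for the first generation step, and a decoder with cached key/values for
# the remaining steps.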
trained_model_path = './t5_squad_v1/'
pretrained_model_name = Path(trained_model_path).stem

# Quantized ONNX graphs exported next to the tokenizer/config files
# in trained_model_path.
encoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-encoder_quantized.onnx")
decoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-decoder_quantized.onnx")
init_decoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-init-decoder_quantized.onnx")

model_paths = encoder_path, decoder_path, init_decoder_path
model_sessions = get_onnx_runtime_sessions(model_paths)

model = OnnxT5(trained_model_path, model_sessions)
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
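
# Build the "context: ... answer: ..." prompt that the SQuAD-fine-tuned T5
# expects, run beam search, and decode the top beam into a question.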
def get_question(sentence, answer, mdl, tknizer):
    text = "context: {} answer: {}".format(sentence, answer)
    print(text)

    max_len = 256
    encoding = tknizer.encode_plus(
        text, max_length=max_len, padding=False, truncation=True,
        return_tensors="pt")
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

    outs = mdl.generate(input_ids=input_ids,
                        attention_mask=attention_mask,
                        early_stopping=True,
                        num_beams=5,
                        num_return_sequences=1,
                        no_repeat_ngram_size=2,
                        max_length=300)

    dec = [tknizer.decode(ids, skip_special_tokens=True) for ids in outs]
    question = dec[0].replace("question:", "").strip()
    return question
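
# Quick smoke test at startup so any model/tokenizer problem surfaces
# in the logs before the UI launches.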
# context = "Ramsri loves to watch cricket during his free time"
# answer = "cricket"
context = "Donald Trump is an American media personality and businessman who served as the 45th president of the United States."
answer = "Donald Trump"
ques = get_question(context, answer, model, tokenizer)
print("question: ", ques)
context = gr.components.Textbox(
    lines=5, placeholder="Enter paragraph/context here...")
answer = gr.components.Textbox(
    lines=3, placeholder="Enter answer/keyword here...")
question = gr.components.Textbox(type="text", label="Question")
def generate_question(context, answer):
    start_time = time.time()  # Record the start time
    result = get_question(context, answer, model, tokenizer)
    end_time = time.time()  # Record the end time
    latency = end_time - start_time  # Calculate latency
    print(f"Latency: {latency} seconds")
    return result
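
# Wire everything into a Gradio Interface; share=True also serves the app
# through a temporary public share link in addition to the local server.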
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer],
    outputs=question
)
iface.launch(share=True)