# blogpost-cqa / app.py

import time

import gradio as gr
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from optimum.utils import NormalizedConfigManager
from transformers import AutoTokenizer

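# Patch NormalizedConfigManager so the lookup always returns the T5 normalized
# config; without this, optimum does not recognize the LongT5 model type.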
@classmethod
def _new_get_normalized_config_class(cls, model_type):
    return cls._conf["t5"]

NormalizedConfigManager.get_normalized_config_class = _new_get_normalized_config_class

N = 2 # Number of previous QA pairs to use for context
MAX_NEW_TOKENS = 128 # Maximum number of tokens for each answer
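
# Load the tokenizer and the ONNX-exported LongT5 model from the Hugging Face Hub.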
tokenizer = AutoTokenizer.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx")
model = ORTModelForSeq2SeqLM.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx")
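
# The document the model answers questions about, read once at startup.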
with open("updated_context.txt", "r") as f:
    context = f.read()
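
# Prompts take the form:
#   "<context> || <Q2> q2 <A2> a2 <Q1> q1 <A1> a1 <Q> current question <A> "
# where <Qi>/<Ai> mark the i-th most recent previous turn.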
def build_input(question, state=[[], []]):
    model_input = f"{context} || "
    # Include at most the N most recent question/answer pairs, oldest first.
    previous = min(len(state[0]), N)
    for i in range(previous, 0, -1):
        prev_question = state[0][-i]
        prev_answer = state[1][-i]
        model_input += f"<Q{i}> {prev_question} <A{i}> {prev_answer} "
    model_input += f"<Q> {question} <A> "
    return model_input
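
# One chat turn: build the prompt, tokenize, generate an answer, and update the
# conversation history, timing each stage for debugging.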
def get_model_answer(question, state=[[], []]):
    start = time.perf_counter()
    model_input = build_input(question, state)
    end = time.perf_counter()
    print(f"Build input: {end - start}")

    start = time.perf_counter()
    encoded_inputs = tokenizer(model_input, max_length=7000, truncation=True, return_tensors="pt")
    input_ids, attention_mask = (
        encoded_inputs.input_ids,
        encoded_inputs.attention_mask,
    )
    end = time.perf_counter()
    print(f"Tokenize: {end - start}")

    start = time.perf_counter()
    encoded_output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    answer = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
    end = time.perf_counter()
    print(f"Generate: {end - start}")

    # Append the new pair and rebuild the (question, answer) history for the Chatbot.
    state[0].append(question)
    state[1].append(answer)
    responses = [(state[0][i], state[1][i]) for i in range(len(state[0]))]
    return responses, state
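
# Gradio UI: a Chatbot display plus a Textbox that submits on Enter, with a few
# example questions to get started.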
with gr.Blocks() as demo:
    # Per-session state: two parallel lists, [questions, answers].
    state = gr.State([[], []])
    chatbot = gr.Chatbot()
    text = gr.Textbox(label="Ask a question (press enter to submit)", value="How are you?")
    gr.Examples(
        [
            "What's the name of the dataset that was built?",
            "what task does it focus on?",
            "what is that task about?",
        ],
        text,
    )
    # The first handler answers the question; the second clears the textbox.
    text.submit(get_model_answer, [text, state], [chatbot, state])
    text.submit(lambda x: "", text, text)

demo.launch()