import time

import gradio as gr
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from optimum.utils import NormalizedConfigManager
from transformers import AutoTokenizer

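# optimum's NormalizedConfigManager has no entry for the "longt5" model type,
# so patch the lookup to always return the T5 normalized config, whose
# attributes (num_layers, num_heads, d_model, ...) LongT5 shares.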
@classmethod
def _new_get_normalized_config_class(cls, model_type):
    return cls._conf["t5"]

NormalizedConfigManager.get_normalized_config_class = _new_get_normalized_config_class


N = 2  # Number of previous QA pairs to prepend as context
MAX_NEW_TOKENS = 128  # Maximum number of generated tokens per answer

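# Tokenizer and ONNX-exported LongT5 model, downloaded from the Hugging Face Hub.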
tokenizer = AutoTokenizer.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx")
model = ORTModelForSeq2SeqLM.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx")


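# The document the chatbot will answer questions about.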
with open("updated_context.txt", "r") as f:
    context = f.read()

def build_input(question, state=[[], []]):
    """Prepend the document and up to N previous QA pairs to the new question."""
    model_input = f"{context} || "
    # state[0] holds the previous questions, state[1] the matching answers.
    previous = min(len(state[1]), N)
    # Walk from the oldest pair in the window (<Q{previous}>) to the most recent (<Q1>),
    # keeping question i paired with answer i.
    for i in range(previous, 0, -1):
        prev_question = state[0][-i]
        prev_answer = state[1][-i]
        model_input += f"<Q{i}> {prev_question} <A{i}> {prev_answer} "
    model_input += f"<Q> {question} <A> "
    return model_input

def get_model_answer(question, state=[[], []]):
    """Run one chat turn: build the prompt, tokenize, generate, and update the history."""
    start = time.perf_counter()
    model_input = build_input(question, state)
    end = time.perf_counter()
    print(f"Build input: {end - start:.3f}s")

    start = time.perf_counter()
    encoded_inputs = tokenizer(model_input, max_length=7000, truncation=True, return_tensors="pt")
    input_ids, attention_mask = (
        encoded_inputs.input_ids,
        encoded_inputs.attention_mask,
    )
    end = time.perf_counter()
    print(f"Tokenize: {end - start:.3f}s")

    start = time.perf_counter()
    encoded_output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=MAX_NEW_TOKENS)
    answer = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
    end = time.perf_counter()
    print(f"Generate: {end - start:.3f}s")

    # Record the new QA pair and rebuild the history as (question, answer)
    # tuples, the format gr.Chatbot expects.
    state[0].append(question)
    state[1].append(answer)
    responses = list(zip(state[0], state[1]))
    return responses, state

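# Minimal Gradio UI: a chat window, a textbox that submits on Enter, and example questions.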
with gr.Blocks() as demo:
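    # Persistent history shared across turns: state[0] collects questions, state[1] answers.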
    state = gr.State([[], []])
    chatbot = gr.Chatbot()
    text = gr.Textbox(label="Ask a question (press enter to submit)", value="How are you?")
    gr.Examples(
        ["What's the name of the dataset that was built?", "what task does it focus on?", "what is that task about?"],
        text
    )

    text.submit(get_model_answer, [text, state], [chatbot, state])
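    # A second submit handler clears the textbox once the question has been sent.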
    text.submit(lambda x: "", text, text)

demo.launch()