rgallardo commited on
Commit
c152a6e
·
1 Parent(s): 92360e8

Create chatbot interface

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +34 -13
  3. requirements.txt +7 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ onnx
app.py CHANGED
@@ -1,28 +1,39 @@
1
- from transformers import LongT5ForConditionalGeneration, AutoTokenizer
2
  import time
 
 
 
 
 
 
 
 
 
 
3
 
4
  N = 2 # Number of previous QA pairs to use for context
5
  MAX_NEW_TOKENS = 128 # Maximum number of tokens for each answer
6
 
7
- tokenizer = AutoTokenizer.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa")
8
- model = LongT5ForConditionalGeneration.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa")
 
9
 
10
  with open("context_short.txt", "r") as f:
11
  context = f.read()
12
 
13
- def build_input(question, user_history=[], bot_history=[]):
14
  model_input = f"{context} || "
15
- previous = min(len(bot_history[1:]), N)
16
  for i in range(previous, 0, -1):
17
- prev_question = user_history[-i-1]
18
- prev_answer = bot_history[-i]
19
  model_input += f"<Q{i}> {prev_question} <A{i}> {prev_answer} "
20
  model_input += f"<Q> {question} <A> "
21
  return model_input
22
 
23
- def get_model_answer(question, user_history=[], bot_history=[]):
24
  start = time.perf_counter()
25
- model_input = build_input(question, user_history, bot_history)
26
  end = time.perf_counter()
27
  print(f"Build input: {end-start}")
28
  start = time.perf_counter()
@@ -34,11 +45,21 @@ def get_model_answer(question, user_history=[], bot_history=[]):
34
  end = time.perf_counter()
35
  print(f"Tokenize: {end-start}")
36
  start = time.perf_counter()
37
- encoded_output = model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=True, max_new_tokens=MAX_NEW_TOKENS)
38
  answer = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
39
  end = time.perf_counter()
40
  print(f"Generate: {end-start}")
41
- user_history.append(question)
42
- bot_history.append(answer)
43
- return answer, user_history, bot_history
 
 
 
 
 
 
 
 
 
44
 
 
 
1
+ from transformers import AutoTokenizer
2
  import time
3
+ import gradio as gr
4
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM
5
+ from optimum.utils import NormalizedConfigManager
6
+
7
+ @classmethod
8
+ def _new_get_normalized_config_class(cls, model_type):
9
+ return cls._conf["t5"]
10
+
11
+ NormalizedConfigManager.get_normalized_config_class = _new_get_normalized_config_class
12
+
13
 
14
  N = 2 # Number of previous QA pairs to use for context
15
  MAX_NEW_TOKENS = 128 # Maximum number of tokens for each answer
16
 
17
+ tokenizer = AutoTokenizer.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx")
18
+ model = ORTModelForSeq2SeqLM.from_pretrained("tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx")
19
+
20
 
21
  with open("context_short.txt", "r") as f:
22
  context = f.read()
23
 
24
+ def build_input(question, state=[[],[]]):
25
  model_input = f"{context} || "
26
+ previous = min(len(state[1][1:]), N)
27
  for i in range(previous, 0, -1):
28
+ prev_question = state[0][-i-1]
29
+ prev_answer = state[1][-i]
30
  model_input += f"<Q{i}> {prev_question} <A{i}> {prev_answer} "
31
  model_input += f"<Q> {question} <A> "
32
  return model_input
33
 
34
+ def get_model_answer(question, state=[[],[]]):
35
  start = time.perf_counter()
36
+ model_input = build_input(question, state)
37
  end = time.perf_counter()
38
  print(f"Build input: {end-start}")
39
  start = time.perf_counter()
 
45
  end = time.perf_counter()
46
  print(f"Tokenize: {end-start}")
47
  start = time.perf_counter()
48
+ encoded_output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=MAX_NEW_TOKENS)
49
  answer = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
50
  end = time.perf_counter()
51
  print(f"Generate: {end-start}")
52
+ state[0].append(question)
53
+ state[1].append(answer)
54
+ responses = [(state[0][i], state[1][i]) for i in range(len(state[0]))]
55
+ return responses, state
56
+
57
+ with gr.Blocks() as demo:
58
+ state = gr.State([[],[]])
59
+ chatbot = gr.Chatbot()
60
+ text = gr.Textbox(label="Ask a question (press enter to submit)", default_value="How are you?")
61
+
62
+ text.submit(get_model_answer, [text, state], [chatbot, state])
63
+ text.submit(lambda x: "", text, text)
64
 
65
+ demo.launch()
requirements.txt CHANGED
@@ -1,2 +1,8 @@
1
  transformers
2
- torch
 
 
 
 
 
 
 
1
  transformers
2
+ torch
3
+ onnx==1.12.0
4
+ onnxconverter-common==1.13.0
5
+ onnxruntime==1.13.1
6
+ onnxruntime-tools==1.7.0
7
+ openvino==2022.2.0
8
+ optimum @ git+https://github.com/huggingface/optimum.git@4c3b1c14f07c8e3780d9c9765b3992a90fab3349