WuChengyue committed
Commit 3dfbcaf · 1 Parent(s): 763f1c5

first commit

Files changed (1): app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import gradio as gr
+ import torch
+ import sys
+ import html
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+ from threading import Thread
+
+ model_name_or_path = 'TencentARC/LLaMA-Pro-8B-Instruct'
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+
+ model.half().cuda()  # fp16 weights on the GPU
+
+ def convert_message(message):  # render one chat message in the <|role|> prompt format
+     message_text = ""
+     if message["content"] is None and message["role"] == "assistant":
+         message_text += "<|assistant|>\n"  # final msg: open the assistant turn for generation
+     elif message["role"] == "system":
+         message_text += "<|system|>\n" + message["content"].strip() + "\n"
+     elif message["role"] == "user":
+         message_text += "<|user|>\n" + message["content"].strip() + "\n"
+     elif message["role"] == "assistant":
+         message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
+     else:
+         raise ValueError("Invalid role: {}".format(message["role"]))
+     # Gradio cleaning - it escapes text into HTML entities;
+     # we would need special handling anywhere we want to keep the HTML...
+     message_text = html.unescape(message_text)
+     # It also converts newlines to <br>; undo this.
+     message_text = message_text.replace("<br>", "\n")
+     return message_text
+
+ def convert_history(chat_history, max_input_length=1024):
+     history_text = ""
+     idx = len(chat_history) - 1
+     # Add messages in reverse order (newest first) until we hit max_input_length tokens.
+     while len(tokenizer(history_text).input_ids) < max_input_length and idx >= 0:
+         user_message, chatbot_message = chat_history[idx]
+         user_message = convert_message({"role": "user", "content": user_message})
+         chatbot_message = convert_message({"role": "assistant", "content": chatbot_message})
+         history_text = user_message + chatbot_message + history_text
+         idx = idx - 1
+     # If nothing was added, add <|assistant|> to start generation.
+     if history_text == "":
+         history_text = "<|assistant|>\n"
+     return history_text
+
+ @torch.inference_mode()
+ def instruct(instruction, max_token_output=1024):
+     input_text = instruction
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+     inputs = tokenizer(input_text, return_tensors='pt', truncation=False)
+     inputs["input_ids"] = inputs["input_ids"].cuda()
+     inputs["attention_mask"] = inputs["attention_mask"].cuda()
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_token_output)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)  # generate in the background
+     thread.start()
+     return streamer  # caller iterates over tokens as they arrive
+
+
+ with gr.Blocks() as demo:
+     # Recreating the original QA demo in Blocks.
+     with gr.Tab("QA Demo"):
+         with gr.Row():
+             instruction = gr.Textbox(label="Input")
+             output = gr.Textbox(label="Output")
+         greet_btn = gr.Button("Submit")
+         def yield_instruct(instruction):
+             # Quick prompt hack: wrap the raw input in a single-turn template.
+             instruction = "<|user|>\n" + instruction + "\n<|assistant|>\n"
+             output = ""
+             for token in instruct(instruction):
+                 output += token
+                 yield output
+         greet_btn.click(fn=yield_instruct, inputs=[instruction], outputs=output, api_name="greet")
+     # Chatbot-style model.
+     with gr.Tab("Chatbot"):
+         chatbot = gr.Chatbot([], elem_id="chatbot")
+         msg = gr.Textbox()
+         clear = gr.Button("Clear")
+         # Fn to add the user message to the history.
+         def user(user_message, history):
+             return "", history + [[user_message, None]]
+
+         def bot(history):
+             prompt = convert_history(history)
+             streaming_out = instruct(prompt)
+             history[-1][1] = ""
+             for new_token in streaming_out:
+                 history[-1][1] += new_token
+                 yield history
+
+         msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+             bot, chatbot, chatbot
+         )
+
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+ if __name__ == "__main__":
+     demo.queue().launch(share=True)