wmpscc commited on
Commit
6f161b0
·
1 Parent(s): 60fbce6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -4,14 +4,20 @@ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
4
 
5
  import torch
6
  import gradio as gr
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
8
 
9
 
10
  def init_model():
11
  model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
12
  torch_dtype=torch.bfloat16, trust_remote_code=True)
13
  tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
14
- return model, tokenizer
 
15
 
16
 
17
  def process(message, history):
@@ -19,29 +25,31 @@ def process(message, history):
19
  for interaction in history:
20
  input_prompt = f"{input_prompt} User: {str(interaction[0]).strip(' ')} Bot: {str(interaction[1]).strip(' ')}"
21
  input_prompt = f"{input_prompt} ### Instruction:{message.strip()} ### Response:"
 
22
  inputs = tokenizer(input_prompt, return_tensors="pt").to("cuda:0")
 
 
 
23
  try:
24
- generate_ids = model.generate(inputs.input_ids, max_new_tokens=2048, do_sample=True, top_k=20, top_p=0.84,
25
- temperature=1, repetition_penalty=1.15, eos_token_id=2, bos_token_id=1,
26
- pad_token_id=0)
27
- response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
28
- print('log:', response)
29
- response = response.split("### Response:")[-1]
30
- return response
31
- except:
32
- return "Error: 会话超长,请重试!"
 
33
 
34
 
35
  if __name__ == '__main__':
36
- examples = ["Python和JavaScript编程语言的主要区别是什么?", "影响消费者行为的主要因素是什么?", "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
37
- "请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
38
- "如何使用ssh -L,请用具体例子说明",
39
- "应对压力最有效的方法是什么?"]
40
- model, tokenizer = init_model()
41
  demo = gr.ChatInterface(
42
  process,
43
- chatbot=gr.Chatbot(height=600),
44
- textbox=gr.Textbox(placeholder="Input", container=False, scale=7),
45
  title="Linly ChatFlow",
46
  description="",
47
  theme="soft",
@@ -51,4 +59,4 @@ if __name__ == '__main__':
51
  undo_btn="Delete Previous",
52
  clear_btn="Clear",
53
  )
54
- demo.queue(concurrency_count=75).launch()
 
4
 
5
  import torch
6
  import gradio as gr
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
+ from threading import Thread
9
+
10
+ examples = ["Python和JavaScript编程语言的主要区别是什么?", "影响消费者行为的主要因素是什么?", "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
11
+ "请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
12
+ "如何使用ssh -L,请用具体例子说明", "应对压力最有效的方法是什么?"]
13
 
14
 
15
  def init_model():
16
  model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
17
  torch_dtype=torch.bfloat16, trust_remote_code=True)
18
  tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
19
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30.)
20
+ return model, tokenizer, streamer
21
 
22
 
23
  def process(message, history):
 
25
  for interaction in history:
26
  input_prompt = f"{input_prompt} User: {str(interaction[0]).strip(' ')} Bot: {str(interaction[1]).strip(' ')}"
27
  input_prompt = f"{input_prompt} ### Instruction:{message.strip()} ### Response:"
28
+
29
  inputs = tokenizer(input_prompt, return_tensors="pt").to("cuda:0")
30
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True,
31
+ top_k=20, top_p=0.84, temperature=1.0, repetition_penalty=1.15, eos_token_id=2,
32
+ bos_token_id=1, pad_token_id=0)
33
  try:
34
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
35
+ t.start()
36
+ response = ""
37
+ for text in streamer:
38
+ response += text
39
+ yield response
40
+ print('-log:', response)
41
+ except Exception as e:
42
+ print('-error:', str(e))
43
+ return "Error: 遇到错误,请开启新的会话重新尝试~"
44
 
45
 
46
  if __name__ == '__main__':
47
+ model, tokenizer, streamer = init_model()
48
+
 
 
 
49
  demo = gr.ChatInterface(
50
  process,
51
+ chatbot=gr.Chatbot(height=600, show_label=True, label="Linly"),
52
+ textbox=gr.Textbox(placeholder="Input", container=True, scale=7, lines=3, show_label=False),
53
  title="Linly ChatFlow",
54
  description="",
55
  theme="soft",
 
59
  undo_btn="Delete Previous",
60
  clear_btn="Clear",
61
  )
62
+ demo.queue().launch()