wmpscc commited on
Commit
944357d
·
1 Parent(s): 6f161b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -27
app.py CHANGED
@@ -4,20 +4,14 @@ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
4
 
5
  import torch
6
  import gradio as gr
7
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
- from threading import Thread
9
-
10
- examples = ["Python和JavaScript编程语言的主要区别是什么?", "影响消费者行为的主要因素是什么?", "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
11
- "请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
12
- "如何使用ssh -L,请用具体例子说明", "应对压力最有效的方法是什么?"]
13
 
14
 
15
  def init_model():
16
  model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
17
  torch_dtype=torch.bfloat16, trust_remote_code=True)
18
  tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
19
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30.)
20
- return model, tokenizer, streamer
21
 
22
 
23
  def process(message, history):
@@ -25,31 +19,29 @@ def process(message, history):
25
  for interaction in history:
26
  input_prompt = f"{input_prompt} User: {str(interaction[0]).strip(' ')} Bot: {str(interaction[1]).strip(' ')}"
27
  input_prompt = f"{input_prompt} ### Instruction:{message.strip()} ### Response:"
28
-
29
  inputs = tokenizer(input_prompt, return_tensors="pt").to("cuda:0")
30
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True,
31
- top_k=20, top_p=0.84, temperature=1.0, repetition_penalty=1.15, eos_token_id=2,
32
- bos_token_id=1, pad_token_id=0)
33
  try:
34
- t = Thread(target=model.generate, kwargs=generation_kwargs)
35
- t.start()
36
- response = ""
37
- for text in streamer:
38
- response += text
39
- yield response
40
- print('-log:', response)
41
- except Exception as e:
42
- print('-error:', str(e))
43
- return "Error: 遇到错误,请开启新的会话重新尝试~"
44
 
45
 
46
  if __name__ == '__main__':
47
- model, tokenizer, streamer = init_model()
48
-
 
 
 
49
  demo = gr.ChatInterface(
50
  process,
51
- chatbot=gr.Chatbot(height=600, show_label=True, label="Linly"),
52
- textbox=gr.Textbox(placeholder="Input", container=True, scale=7, lines=3, show_label=False),
53
  title="Linly ChatFlow",
54
  description="",
55
  theme="soft",
@@ -59,4 +51,4 @@ if __name__ == '__main__':
59
  undo_btn="Delete Previous",
60
  clear_btn="Clear",
61
  )
62
- demo.queue().launch()
 
4
 
5
  import torch
6
  import gradio as gr
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
8
 
9
 
10
  def init_model():
11
  model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
12
  torch_dtype=torch.bfloat16, trust_remote_code=True)
13
  tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
14
+ return model, tokenizer
 
15
 
16
 
17
  def process(message, history):
 
19
  for interaction in history:
20
  input_prompt = f"{input_prompt} User: {str(interaction[0]).strip(' ')} Bot: {str(interaction[1]).strip(' ')}"
21
  input_prompt = f"{input_prompt} ### Instruction:{message.strip()} ### Response:"
 
22
  inputs = tokenizer(input_prompt, return_tensors="pt").to("cuda:0")
 
 
 
23
  try:
24
+ generate_ids = model.generate(inputs.input_ids, max_new_tokens=2048, do_sample=True, top_k=20, top_p=0.84,
25
+ temperature=1, repetition_penalty=1.15, eos_token_id=2, bos_token_id=1,
26
+ pad_token_id=0)
27
+ response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
28
+ print('log:', response)
29
+ response = response.split("### Response:")[-1]
30
+ return response
31
+ except:
32
+ return "Error: 会话超长,请重试!"
 
33
 
34
 
35
  if __name__ == '__main__':
36
+ examples = ["Python和JavaScript编程语言的主要区别是什么?", "影响消费者行为的主要因素是什么?", "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
37
+ "请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
38
+ "如何使用ssh -L,请用具体例子说明",
39
+ "应对压力最有效的方法是什么?"]
40
+ model, tokenizer = init_model()
41
  demo = gr.ChatInterface(
42
  process,
43
+ chatbot=gr.Chatbot(height=600),
44
+ textbox=gr.Textbox(placeholder="Input", container=False, scale=7),
45
  title="Linly ChatFlow",
46
  description="",
47
  theme="soft",
 
51
  undo_btn="Delete Previous",
52
  clear_btn="Clear",
53
  )
54
+ demo.queue(concurrency_count=75).launch()