dreamerdeo commited on
Commit
f61d72c
1 Parent(s): e88da0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -29
app.py CHANGED
@@ -4,8 +4,8 @@ import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
5
  from threading import Thread
6
 
7
- # model_path = 'dreamerdeo/Sailor2-0.8B-Chat'
8
- model_path = 'sail/Sailor-0.5B-Chat'
9
 
10
  # Loading the tokenizer and model from Hugging Face's model hub.
11
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -78,33 +78,15 @@ def predict(message, history):
78
  repetition_penalty=1.1,
79
  )
80
 
81
- outputs = model.generate(**generate_kwargs)
82
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
83
- partial_message = generated_text
84
- final_message = partial_message.replace(sft_end_token, "").strip()
85
-
86
- return final_message
87
-
88
- # # 使用线程来运行生成过程
89
- # t = Thread(target=model.generate, kwargs=generate_kwargs)
90
- # t.start()
91
-
92
- # # 实时生成部分消息
93
- # partial_message = ""
94
- # for new_token in streamer:
95
- # partial_message += new_token
96
- # if sft_end_token in partial_message: # 检测到停止标志
97
- # break
98
- # # 将历史记录和当前消息转换为 tuple 格式并实时返回
99
- # # yield [(msg, bot) for msg, bot in history] + [(message, partial_message)]
100
- # # yield (message, partial_message)
101
- # yield partial_message
102
-
103
- # # 处理生成的最终回复
104
- # final_message = partial_message.replace(sft_end_token, "").strip()
105
- # history.append([message, final_message]) # 更新历史记录
106
- # # 返回最终的对话历史,确保格式为元组的列表
107
- # yield [(msg, bot) for msg, bot in history]
108
 
109
  css = """
110
  full-height {
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
5
  from threading import Thread
6
 
7
+ model_path = 'dreamerdeo/Sailor2-0.8B-Chat'
8
+ # model_path = 'sail/Sailor-0.5B-Chat'
9
 
10
  # Loading the tokenizer and model from Hugging Face's model hub.
11
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
78
  repetition_penalty=1.1,
79
  )
80
 
81
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
82
+ t.start() # Starting the generation in a separate thread.
83
+ partial_message = ""
84
+ for new_token in streamer:
85
+ partial_message += new_token
86
+ if sft_end_token in partial_message: # Breaking the loop if the stop token is generated.
87
+ break
88
+ yield partial_message
89
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  css = """
92
  full-height {