Spaces:
Runtime error
Runtime error
jhsheng
committed on
Commit
·
bc4af1f
1
Parent(s):
0b5042b
init
Browse files- app.py +158 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2023 South China University of Technology and
# Engineering Research Center of Ministry of Education on Human Body Perception.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Author: Chen Yirong <[email protected]>
# Date: 2023.06.07

''' How to run
```bash
pip install streamlit       # needed on first run
pip install streamlit_chat  # needed on first run
streamlit run bianque_v2_app.py --server.port 9005
```

## Test access

http://<your_ip>:9005

'''


import os
import torch
import streamlit as st
from streamlit_chat import message
from transformers import AutoModel, AutoTokenizer


# Default to GPU 0 so Windows users do not have to remember to edit this.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model name on the Hugging Face Hub, or a local path.
model_name_or_path = "scutcyr/BianQue-2"

# NOTE: the tokenizer is loaded once via the cached load_tokenizer() defined
# later in this file; the redundant eager AutoTokenizer.from_pretrained(...)
# call that used to live here (it loaded the tokenizer twice at startup)
# has been removed.
def answer(user_history, bot_history, sample=True, top_p=0.7, temperature=0.95):
    '''Generate the doctor's next reply from the dialogue history.

    Args:
        user_history: list of patient utterances; the last element is the
            current query.
        bot_history: list of previous doctor replies, aligned with
            user_history[:-1].
        sample: whether to sample during generation. top_p=0.7 with
            temperature=0.95 gives good generation quality; top_p=1 with
            temperature=0.7 improves the model's ability to ask questions.
        top_p: nucleus-sampling threshold in (0, 1]; larger means more
            diverse output.
        temperature: sampling temperature.

    Returns:
        The generated reply string (max_length=2048 tokens of context).
    '''
    if len(bot_history) > 0:
        context = "\n".join([f"病人:{user_history[i]}\n医生:{bot_history[i]}" for i in range(len(bot_history))])
        input_text = context + "\n病人:" + user_history[-1] + "\n医生:"
    else:
        input_text = "病人:" + user_history[-1] + "\n医生:"
        # BUG FIX: this condition had been commented out while its `return`
        # was left in place, so EVERY first-turn query returned the canned
        # greeting and the model was never called. The guard is restored so
        # only an actual "你好" greeting short-circuits.
        if user_history[-1] == "你好" or user_history[-1] == "你好!":
            return "我是利用人工智能技术,结合大数据训练得到的智能医疗问答模型扁鹊,你可以向我提问。"
            #return "我是生活空间健康对话大模型扁鹊,欢迎向我提问。"

    print(input_text)

    # `model` and `tokenizer` are module-level globals loaded at startup.
    if not sample:
        response, history = model.chat(tokenizer, query=input_text, history=None, max_length=2048, num_beams=1, do_sample=False, top_p=top_p, temperature=temperature, logits_processor=None)
    else:
        response, history = model.chat(tokenizer, query=input_text, history=None, max_length=2048, num_beams=1, do_sample=True, top_p=top_p, temperature=temperature, logits_processor=None)

    print('医生: ' + response)

    return response
# Streamlit page chrome: title, icon, wide layout, expanded sidebar, and an
# "About" menu entry describing the model version, lab, and authors.
st.set_page_config(
    page_title="扁鹊健康大模型(BianQue-2.0)",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': """
        - 版本:扁鹊健康大模型(BianQue) V2.0.0 Beta
        - 机构:广东省数字孪生人重点实验室
        - 作者:陈艺荣、王振宇、徐志沛、方凱、李思航、王骏宏、邢晓芬、徐向民
        """
    }
)

st.header("扁鹊健康大模型(BianQue-2.0)")

# Collapsible "about us" panel mirroring the menu's About text.
with st.expander("ℹ️ - 关于我们", expanded=False):
    st.write(
        """
        - 版本:扁鹊健康大模型(BianQue) V2.0.0 Beta
        - 机构:广东省数字孪生人重点实验室
        - 作者:陈艺荣、王振宇、徐志沛、方凱、李思航、王骏宏、邢晓芬、徐向民
        """
    )
# https://docs.streamlit.io/library/api-reference/performance/st.cache_resource
@st.cache_resource
def load_model():
    '''Load the BianQue-2 model once; st.cache_resource shares it across reruns.

    Uses fp16 on GPU only: the original unconditional `.half()` breaks CPU
    inference because many CPU ops have no half-precision kernels, so on CPU
    the model is kept in fp32.
    '''
    model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
    if device.type == "cuda":
        model = model.half()
    model.to(device)
    print('Model Load done!')
    return model
@st.cache_resource
def load_tokenizer():
    '''Load the tokenizer for the configured model; cached across reruns.'''
    tok = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    print('Tokenizer Load done!')
    return tok
# Load the (cached) model and tokenizer once at startup.
model = load_model()
tokenizer = load_tokenizer()

# Initialise the chat-history containers on first run:
# 'generated' holds bot replies, 'past' holds user messages.
for _key in ('generated', 'past'):
    if _key not in st.session_state:
        st.session_state[_key] = []


# Wide text-input column plus a narrow send-button column.
user_col, ensure_col = st.columns([5, 1])
def get_text():
    '''Render the input box and send button; return the submitted text.

    Returns the text only when the send button was clicked and the box is
    non-empty; otherwise falls through and returns None.
    '''
    input_text = user_col.text_area("请在下列文本框输入您的咨询内容:","", key="input", placeholder="请输入您的咨询内容,并且点击Ctrl+Enter(或者发送按钮)确认内容")
    # Single guard instead of nested ifs: button() is still evaluated exactly
    # once per rerun, so behavior is unchanged.
    if ensure_col.button("发送", use_container_width=True) and input_text:
        return input_text
user_input = get_text()

if user_input:
    # Record the new patient turn, then generate the doctor's reply for it.
    st.session_state.past.append(user_input)
    output = answer(st.session_state['past'], st.session_state["generated"])
    st.session_state.generated.append(output)

if st.session_state['generated']:
    # Replay the whole conversation; the very first bot reply carries an
    # auto-generation disclaimer appended to its text.
    for i, reply in enumerate(st.session_state['generated']):
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user', avatar_style="avataaars", seed=26)
        if i == 0:
            reply = reply + "\n\n------------------\n以下回答由扁鹊健康模型自动生成,仅供参考!"
        message(reply, key=str(i), avatar_style="avataaars", seed=5)

if st.button("清理对话缓存"):
    # Reset the in-memory chat history (both user and bot sides).
    st.session_state['generated'] = []
    st.session_state['past'] = []
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protobuf
transformers==4.28.0
cpm_kernels
torch>=1.10
gradio
mdtex2html
sentencepiece
accelerate
streamlit
streamlit_chat