fb700 committed on
Commit
71b517a
·
1 Parent(s): 9638896

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +70 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import streamlit as st
3
+ from streamlit_chat import message
4
+
5
+
6
# Configure the Streamlit page (tab title and icon) before any other UI call.
st.set_page_config(
    page_title="帛凡 ChatGLM-6b-fitness-RLHF 演示",
    page_icon=":robot:",
)
10
+
11
+
12
# Single source of truth for the checkpoint locations.
# NOTE(review): tokenizer and model are loaded from DIFFERENT sources (a Hub
# repo id vs. a local Windows checkout) — confirm this mismatch is intentional.
TOKENIZER_PATH = "fb700/chatglm-fitness-RLHF/chatglm_rlhf"
# Raw string: in the original, "\g" and "\c" were invalid escape sequences
# (DeprecationWarning today, a SyntaxError in future Python versions).
MODEL_PATH = r"D:\glm\chatglm_webui\chatglm-6b"


@st.cache_resource
def get_model():
    """Load and cache the ChatGLM tokenizer and 8-bit-quantized model.

    Cached with ``st.cache_resource`` so the expensive load runs once per
    Streamlit server process instead of on every script rerun.

    Returns:
        tuple: ``(tokenizer, model)`` — the tokenizer and the eval-mode,
        half-precision model moved to CUDA.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    # Original full-precision Hub load kept for reference:
    # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).quantize(8).half().cuda()
    model = model.eval()
    return tokenizer, model
19
+
20
+
21
# Cap on remembered conversation turns; each turn renders two chat bubbles
# (user query + model response), hence the derived box limit.
MAX_TURNS = 20
MAX_BOXES = 2 * MAX_TURNS
23
+
24
+
25
def predict(input, max_length, top_p, temperature, history=None):
    """Run one chat turn: replay prior history, then stream the model's reply.

    Args:
        input: the user's prompt text. (Name kept for caller compatibility
            even though it shadows the ``input`` builtin.)
        max_length: generation length cap forwarded to ``stream_chat``.
        top_p: nucleus-sampling threshold forwarded to ``stream_chat``.
        temperature: sampling temperature forwarded to ``stream_chat``.
        history: list of ``(query, response)`` pairs; a fresh list when None.

    Returns:
        list: the updated history as yielded by the model's streaming API.
    """
    tokenizer, model = get_model()
    if history is None:
        history = []

    # Render inside the module-level chat container (defined later in the
    # script, but bound by the time the button callback invokes us).
    with container:
        # Replay the conversation so far (a no-op on an empty history —
        # the original's `if len(history) > 0` guard was redundant).
        for i, (query, response) in enumerate(history):
            message(query, avatar_style="big-smile", key=str(i) + "_user")
            message(response, avatar_style="bottts", key=str(i))

        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
        st.write("AI正在回复:")
        with st.empty():
            # stream_chat yields (partial_response, updated_history) tuples;
            # overwrite the placeholder with each partial response.
            for response, history in model.stream_chat(tokenizer, input, history,
                                                       max_length=max_length, top_p=top_p,
                                                       temperature=temperature):
                _, response = history[-1]  # latest turn's response text
                st.write(response)

    return history
45
+
46
+
47
# Chat transcript area; predict() renders all messages into this container.
container = st.container()

# Free-form prompt box for the user's command.
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

# Sampling controls, rendered in the sidebar in this order.
max_length = st.sidebar.slider('max_length', 0, 40960, 20480, step=1)
top_p = st.sidebar.slider('top_p', 0.0, 1.0, 0.6, step=0.01)
temperature = st.sidebar.slider('temperature', 0.0, 1.0, 0.95, step=0.01)

# Conversation history, persisted across Streamlit script reruns.
if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("发送", key="predict"):
    with st.spinner("AI正在思考,请稍等........"):
        # Run one generation turn and store the updated history.
        st.session_state["state"] = predict(prompt_text, max_length, top_p,
                                            temperature, st.session_state["state"])
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ protobuf
2
+ transformers==4.27.1
3
+ cpm_kernels
4
+ torch>=1.10
5
+ gradio
6
+ mdtex2html
7
+ sentencepiece
8
+ accelerate
9
+ peft
10
+ streamlit
11
+ streamlit_chat