wangrongsheng commited on
Commit
b62eec7
·
1 Parent(s): 2ff72d7

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +235 -293
  2. model.py +74 -0
  3. requirements.txt +8 -9
  4. style.css +16 -0
app.py CHANGED
@@ -1,329 +1,271 @@
1
- """Credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py while mistakes are mine."""
2
- # pylint: disable=broad-exception-caught, redefined-outer-name, missing-function-docstring, missing-module-docstring, too-many-arguments, line-too-long, invalid-name, redefined-builtin, redefined-argument-from-local
3
- # import gradio as gr
4
 
5
- # model_name = "models/THUDM/chatglm2-6b-int4"
6
- # gr.load(model_name).lauch()
7
 
8
- # %%writefile demo-4bit.py
9
 
10
- import os
11
- import time
12
- from textwrap import dedent
 
 
 
13
 
14
- import gradio as gr
15
- import mdtex2html
16
- import torch
17
- from loguru import logger
18
- from transformers import AutoModel, AutoTokenizer
19
 
20
- # fix timezone in Linux
21
- os.environ["TZ"] = "Asia/Shanghai"
22
- try:
23
- time.tzset() # type: ignore # pylint: disable=no-member
24
- except Exception:
25
- # Windows
26
- logger.warning("Windows, cant run time.tzset()")
27
 
28
- model_name = "wangrongsheng/IvyGPT-35"
29
- #model_name = "OpenMEDLab/PULSE-7bv5"
30
 
31
- RETRY_FLAG = False
 
32
 
33
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
34
- #model = AutoModel.from_pretrained(model_name, trust_remote_code=True).quantize(4).half().cuda()
35
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda()
36
- model = model.eval()
37
 
38
- _ = """Override Chatbot.postprocess"""
 
39
 
40
 
41
- def postprocess(self, y):
42
- if y is None:
43
- return []
44
- for i, (message, response) in enumerate(y):
45
- y[i] = (
46
- None if message is None else mdtex2html.convert((message)),
47
- None if response is None else mdtex2html.convert(response),
48
- )
49
- return y
50
-
51
-
52
- gr.Chatbot.postprocess = postprocess
53
-
54
-
55
- def parse_text(text):
56
- lines = text.split("\n")
57
- lines = [line for line in lines if line != ""]
58
- count = 0
59
- for i, line in enumerate(lines):
60
- if "```" in line:
61
- count += 1
62
- items = line.split("`")
63
- if count % 2 == 1:
64
- lines[i] = f'<pre><code class="language-{items[-1]}">'
65
- else:
66
- lines[i] = "<br></code></pre>"
67
- else:
68
- if i > 0:
69
- if count % 2 == 1:
70
- line = line.replace("`", r"\`")
71
- line = line.replace("<", "&lt;")
72
- line = line.replace(">", "&gt;")
73
- line = line.replace(" ", "&nbsp;")
74
- line = line.replace("*", "&ast;")
75
- line = line.replace("_", "&lowbar;")
76
- line = line.replace("-", "&#45;")
77
- line = line.replace(".", "&#46;")
78
- line = line.replace("!", "&#33;")
79
- line = line.replace("(", "&#40;")
80
- line = line.replace(")", "&#41;")
81
- line = line.replace("$", "&#36;")
82
- lines[i] = "<br>" + line
83
- text = "".join(lines)
84
- return text
85
-
86
-
87
- def predict(
88
- RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
89
- ):
90
- try:
91
- chatbot.append((parse_text(input), ""))
92
- except Exception as exc:
93
- logger.error(exc)
94
- logger.debug(f"{chatbot=}")
95
- _ = """
96
- if chatbot:
97
- chatbot[-1] = (parse_text(input), str(exc))
98
- yield chatbot, history, past_key_values
99
- # """
100
- yield chatbot, history, past_key_values
101
- """
102
- for response, history, past_key_values in model.stream_chat(
103
- tokenizer,
104
- input,
105
- history,
106
- past_key_values=past_key_values,
107
- return_past_key_values=True,
108
- max_length=max_length,
109
- top_p=top_p,
110
- temperature=temperature,
111
- ):
112
- """
113
- for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
114
- temperature=temperature):
115
- chatbot[-1] = (parse_text(input), parse_text(response))
116
-
117
- yield chatbot, history, past_key_values
118
-
119
-
120
- def trans_api(input, max_length=40960, top_p=0.7, temperature=0.95):
121
- if max_length < 10:
122
- max_length = 40960
123
- if top_p < 0.1 or top_p > 1:
124
- top_p = 0.7
125
- if temperature <= 0 or temperature > 1:
126
- temperature = 0.01
127
- try:
128
- res, _ = model.chat(
129
- tokenizer,
130
- input,
131
- history=[],
132
- past_key_values=None,
133
- max_length=max_length,
134
- top_p=top_p,
135
- temperature=temperature,
136
- )
137
- # logger.debug(f"{res=} \n{_=}")
138
- except Exception as exc:
139
- logger.error(f"{exc=}")
140
- res = str(exc)
141
-
142
- return res
143
-
144
-
145
- def reset_user_input():
146
- return gr.update(value="")
147
-
148
-
149
- def reset_state():
150
- return [], [], None
151
-
152
-
153
- # Delete last turn
154
- def delete_last_turn(chat, history):
155
- if chat and history:
156
- chat.pop(-1)
157
- history.pop(-1)
158
- return chat, history
159
-
160
-
161
- # Regenerate response
162
- def retry_last_answer(
163
- user_input, chatbot, max_length, top_p, temperature, history, past_key_values
164
- ):
165
- if chatbot and history:
166
- # Removing the previous conversation from chat
167
- chatbot.pop(-1)
168
- # Setting up a flag to capture a retry
169
- RETRY_FLAG = True
170
- # Getting last message from user
171
- user_input = history[-1][0]
172
- # Removing bot response from the history
173
- history.pop(-1)
174
-
175
- yield from predict(
176
- RETRY_FLAG, # type: ignore
177
- user_input,
178
- chatbot,
179
- max_length,
180
- top_p,
181
- temperature,
182
- history,
183
- past_key_values,
184
- )
185
 
186
 
187
- with gr.Blocks(title="IvyGPT", theme=gr.themes.Soft(text_size="sm")) as demo:
188
- # gr.HTML("""<h1 align="center">ChatGLM2-6B-int4</h1>""")
189
- gr.HTML(
190
- """<h1 align="center">IvyGPT医疗对话大模型</h1>"""
191
- )
192
 
193
- with gr.Accordion("🎈 Info", open=False):
194
- _ = f"""
195
- ## 欢迎体验IvyGPT
196
 
197
- 近期在通用领域中出现的大语言模型(LLMs),例如ChatGPT,在遵循指令和产生类人响应方面表现出了显著的成功。然而,这样的大型语言模型并没有被广泛应用于医学领域,导致响应的准确性较差,无法提供关于医学诊断、药物等合理的建议。IvyGPT是一个医疗大语言模型,它在高质量的医学问答数据上进行了监督微调,并使用人类反馈的强化学习进行了训练。
198
-
199
- [模型下载地址](https://huggingface.co/wangrongsheng/)
200
- """
201
- gr.Markdown(dedent(_))
202
- chatbot = gr.Chatbot()
203
- with gr.Row():
204
- with gr.Column(scale=4):
205
- with gr.Column(scale=12):
206
- user_input = gr.Textbox(
207
- show_label=False,
208
- placeholder="Input...",
209
- ).style(container=False)
210
- RETRY_FLAG = gr.Checkbox(value=False, visible=False)
211
- with gr.Column(min_width=32, scale=1):
212
- with gr.Row():
213
- submitBtn = gr.Button("Submit", variant="primary")
214
- deleteBtn = gr.Button("删除最后一条对话", variant="secondary")
215
- retryBtn = gr.Button("重新生成Regenerate", variant="secondary")
216
- with gr.Column(scale=1):
217
- emptyBtn = gr.Button("Clear History")
218
- max_length = gr.Slider(
219
- 0,
220
- 32768,
221
- value=8192,
222
- step=1.0,
223
- label="Maximum length",
224
- interactive=True,
225
- )
226
- top_p = gr.Slider(
227
- 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
228
- )
229
- temperature = gr.Slider(
230
- 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- history = gr.State([])
234
- past_key_values = gr.State(None)
 
 
 
 
 
 
 
235
 
236
- user_input.submit(
237
- predict,
238
- [
239
- RETRY_FLAG,
240
- user_input,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  chatbot,
242
- max_length,
243
- top_p,
244
  temperature,
245
- history,
246
- past_key_values,
247
  ],
248
- [chatbot, history, past_key_values],
249
- show_progress="full",
250
  )
251
- submitBtn.click(
252
- predict,
253
- [
254
- RETRY_FLAG,
255
- user_input,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  chatbot,
257
- max_length,
258
- top_p,
259
  temperature,
260
- history,
261
- past_key_values,
262
  ],
263
- [chatbot, history, past_key_values],
264
- show_progress="full",
265
- api_name="predict",
266
- )
267
- submitBtn.click(reset_user_input, [], [user_input])
268
-
269
- emptyBtn.click(
270
- reset_state, outputs=[chatbot, history, past_key_values], show_progress="full"
271
  )
272
 
273
- retryBtn.click(
274
- retry_last_answer,
 
 
 
 
 
 
 
 
 
 
 
 
275
  inputs=[
276
- user_input,
277
  chatbot,
278
- max_length,
279
- top_p,
280
  temperature,
281
- history,
282
- past_key_values,
283
  ],
284
- # outputs = [chatbot, history, last_user_message, user_message]
285
- outputs=[chatbot, history, past_key_values],
286
  )
287
- deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
288
-
289
- with gr.Accordion("Example inputs", open=True):
290
- examples = gr.Examples(
291
- examples=[
292
- ["熬夜对身体有什么危害? "],
293
- ["新冠肺炎怎么预防"],
294
- ["系统性红斑狼疮的危害和治疗方法是什么?"],
295
- ],
296
- inputs=[user_input],
297
- examples_per_page=50,
298
- )
299
 
300
- with gr.Accordion("For Chat/Translation API", open=False, visible=False):
301
- input_text = gr.Text()
302
- tr_btn = gr.Button("Go", variant="primary")
303
- out_text = gr.Text()
304
- tr_btn.click(
305
- trans_api,
306
- [input_text, max_length, top_p, temperature],
307
- out_text,
308
- # show_progress="full",
309
- api_name="tr",
310
- )
311
- _ = """
312
- input_text.submit(
313
- trans_api,
314
- [input_text, max_length, top_p, temperature],
315
- out_text,
316
- show_progress="full",
317
- api_name="tr1",
318
  )
319
- # """
320
 
321
- # demo.queue().launch(share=False, inbrowser=True)
322
- # demo.queue().launch(share=True, inbrowser=True, debug=True)
323
-
324
- # concurrency_count > 1 requires more memory, max_size: queue size
325
- # T4 medium: 30GB, model size: ~4G concurrency_count = 6
326
- # leave one for api access
327
- # reduce to 5 if OOM occurs to often
328
 
329
- demo.queue(concurrency_count=3, max_size=30).launch(debug=True)
 
1
+ from typing import Iterator
 
 
2
 
3
+ import gradio as gr
4
+ import torch
5
 
6
+ from model import get_input_token_length, run
7
 
8
+ DEFAULT_SYSTEM_PROMPT = """\
9
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
10
+ """
11
+ MAX_MAX_NEW_TOKENS = 2048
12
+ DEFAULT_MAX_NEW_TOKENS = 1024
13
+ MAX_INPUT_TOKEN_LENGTH = 4000
14
 
15
+ DESCRIPTION = """
16
+ # CareLlama-关怀羊驼
 
 
 
17
 
18
+ - CareLlama (关怀羊驼)是一个医疗大语言模型,同时它集合了数十个公开可用的医疗微调数据集和开放可用的医疗大语言模型以促进医疗LLM快速发展。
19
+ - Medical LLM, Open Source Driven for a Healthy Future.
 
 
 
 
 
20
 
21
+ """
 
22
 
23
+ LICENSE = """
24
+ <p/>
25
 
26
+ ---
27
+ 本项目相关资源仅供学术研究之用,严禁用于商业用途。使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目无法对其准确性作出保证。即使本项目模型输出符合医学事实,也不能被用作实际医学诊断的依据。对于模型输出的任何内容,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。
28
+ """
 
29
 
30
+ if not torch.cuda.is_available():
31
+ DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
32
 
33
 
34
+ def clear_and_save_textbox(message: str) -> tuple[str, str]:
35
+ return '', message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
+ def display_input(message: str,
39
+ history: list[tuple[str, str]]) -> list[tuple[str, str]]:
40
+ history.append((message, ''))
41
+ return history
 
42
 
 
 
 
43
 
44
+ def delete_prev_fn(
45
+ history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
46
+ try:
47
+ message, _ = history.pop()
48
+ except IndexError:
49
+ message = ''
50
+ return history, message or ''
51
+
52
+
53
+ def generate(
54
+ message: str,
55
+ history_with_input: list[tuple[str, str]],
56
+ system_prompt: str,
57
+ max_new_tokens: int,
58
+ temperature: float,
59
+ top_p: float,
60
+ top_k: int,
61
+ ) -> Iterator[list[tuple[str, str]]]:
62
+ if max_new_tokens > MAX_MAX_NEW_TOKENS:
63
+ raise ValueError
64
+
65
+ history = history_with_input[:-1]
66
+ generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
67
+ try:
68
+ first_response = next(generator)
69
+ yield history + [(message, first_response)]
70
+ except StopIteration:
71
+ yield history + [(message, '')]
72
+ for response in generator:
73
+ yield history + [(message, response)]
74
+
75
+
76
+ def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
77
+ generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
78
+ for x in generator:
79
+ pass
80
+ return '', x
81
+
82
+
83
+ def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
84
+ input_token_length = get_input_token_length(message, chat_history, system_prompt)
85
+ if input_token_length > MAX_INPUT_TOKEN_LENGTH:
86
+ raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
87
+
88
+
89
+ with gr.Blocks(css='style.css') as demo:
90
+ gr.Markdown(DESCRIPTION)
91
+ gr.DuplicateButton(value='Duplicate Space for private use',
92
+ elem_id='duplicate-button')
93
+
94
+ with gr.Group():
95
+ chatbot = gr.Chatbot(label='CareLlama')
96
+ with gr.Row():
97
+ textbox = gr.Textbox(
98
+ container=False,
99
+ show_label=False,
100
+ placeholder='请输入内容...',
101
+ scale=10,
102
  )
103
+ submit_button = gr.Button('Submit',
104
+ variant='primary',
105
+ scale=1,
106
+ min_width=0)
107
+ with gr.Row():
108
+ retry_button = gr.Button('🔄 重试', variant='secondary')
109
+ undo_button = gr.Button('↩️ 撤销', variant='secondary')
110
+ clear_button = gr.Button('🗑️ 清除', variant='secondary')
111
+
112
+ saved_input = gr.State()
113
+
114
+ with gr.Accordion(label='Advanced options', open=False):
115
+ system_prompt = gr.Textbox(label='System prompt',
116
+ value=DEFAULT_SYSTEM_PROMPT,
117
+ lines=6)
118
+ max_new_tokens = gr.Slider(
119
+ label='Max new tokens',
120
+ minimum=1,
121
+ maximum=MAX_MAX_NEW_TOKENS,
122
+ step=1,
123
+ value=DEFAULT_MAX_NEW_TOKENS,
124
+ )
125
+ temperature = gr.Slider(
126
+ label='Temperature',
127
+ minimum=0.1,
128
+ maximum=4.0,
129
+ step=0.1,
130
+ value=1.0,
131
+ )
132
+ top_p = gr.Slider(
133
+ label='Top-p (nucleus sampling)',
134
+ minimum=0.05,
135
+ maximum=1.0,
136
+ step=0.05,
137
+ value=0.95,
138
+ )
139
+ top_k = gr.Slider(
140
+ label='Top-k',
141
+ minimum=1,
142
+ maximum=1000,
143
+ step=1,
144
+ value=50,
145
+ )
146
 
147
+ gr.Examples(
148
+ examples=[
149
+ '你好'
150
+ ],
151
+ inputs=textbox,
152
+ outputs=[textbox, chatbot],
153
+ fn=process_example,
154
+ cache_examples=True,
155
+ )
156
 
157
+ gr.Markdown(LICENSE)
158
+
159
+ textbox.submit(
160
+ fn=clear_and_save_textbox,
161
+ inputs=textbox,
162
+ outputs=[textbox, saved_input],
163
+ api_name=False,
164
+ queue=False,
165
+ ).then(
166
+ fn=display_input,
167
+ inputs=[saved_input, chatbot],
168
+ outputs=chatbot,
169
+ api_name=False,
170
+ queue=False,
171
+ ).then(
172
+ fn=check_input_token_length,
173
+ inputs=[saved_input, chatbot, system_prompt],
174
+ api_name=False,
175
+ queue=False,
176
+ ).success(
177
+ fn=generate,
178
+ inputs=[
179
+ saved_input,
180
  chatbot,
181
+ system_prompt,
182
+ max_new_tokens,
183
  temperature,
184
+ top_p,
185
+ top_k,
186
  ],
187
+ outputs=chatbot,
188
+ api_name=False,
189
  )
190
+
191
+ button_event_preprocess = submit_button.click(
192
+ fn=clear_and_save_textbox,
193
+ inputs=textbox,
194
+ outputs=[textbox, saved_input],
195
+ api_name=False,
196
+ queue=False,
197
+ ).then(
198
+ fn=display_input,
199
+ inputs=[saved_input, chatbot],
200
+ outputs=chatbot,
201
+ api_name=False,
202
+ queue=False,
203
+ ).then(
204
+ fn=check_input_token_length,
205
+ inputs=[saved_input, chatbot, system_prompt],
206
+ api_name=False,
207
+ queue=False,
208
+ ).success(
209
+ fn=generate,
210
+ inputs=[
211
+ saved_input,
212
  chatbot,
213
+ system_prompt,
214
+ max_new_tokens,
215
  temperature,
216
+ top_p,
217
+ top_k,
218
  ],
219
+ outputs=chatbot,
220
+ api_name=False,
 
 
 
 
 
 
221
  )
222
 
223
+ retry_button.click(
224
+ fn=delete_prev_fn,
225
+ inputs=chatbot,
226
+ outputs=[chatbot, saved_input],
227
+ api_name=False,
228
+ queue=False,
229
+ ).then(
230
+ fn=display_input,
231
+ inputs=[saved_input, chatbot],
232
+ outputs=chatbot,
233
+ api_name=False,
234
+ queue=False,
235
+ ).then(
236
+ fn=generate,
237
  inputs=[
238
+ saved_input,
239
  chatbot,
240
+ system_prompt,
241
+ max_new_tokens,
242
  temperature,
243
+ top_p,
244
+ top_k,
245
  ],
246
+ outputs=chatbot,
247
+ api_name=False,
248
  )
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
+ undo_button.click(
251
+ fn=delete_prev_fn,
252
+ inputs=chatbot,
253
+ outputs=[chatbot, saved_input],
254
+ api_name=False,
255
+ queue=False,
256
+ ).then(
257
+ fn=lambda x: x,
258
+ inputs=[saved_input],
259
+ outputs=textbox,
260
+ api_name=False,
261
+ queue=False,
 
 
 
 
 
 
262
  )
 
263
 
264
+ clear_button.click(
265
+ fn=lambda: ([], ''),
266
+ outputs=[chatbot, saved_input],
267
+ queue=False,
268
+ api_name=False,
269
+ )
 
270
 
271
+ demo.queue(max_size=20).launch()
model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+ from typing import Iterator
3
+
4
+ import torch
5
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+
7
+ model_id = '../merge'
8
+
9
+ if torch.cuda.is_available():
10
+ config = AutoConfig.from_pretrained(model_id)
11
+ config.pretraining_tp = 1
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ model_id,
14
+ config=config,
15
+ torch_dtype=torch.float16,
16
+ load_in_4bit=True,
17
+ device_map='auto'
18
+ )
19
+ else:
20
+ model = None
21
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
22
+
23
+
24
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
25
+ system_prompt: str) -> str:
26
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
27
+ # The first user input is _not_ stripped
28
+ do_strip = False
29
+ for user_input, response in chat_history:
30
+ user_input = user_input.strip() if do_strip else user_input
31
+ do_strip = True
32
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
33
+ message = message.strip() if do_strip else message
34
+ texts.append(f'{message} [/INST]')
35
+ return ''.join(texts)
36
+
37
+
38
+ def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
39
+ prompt = get_prompt(message, chat_history, system_prompt)
40
+ input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
41
+ return input_ids.shape[-1]
42
+
43
+
44
+ def run(message: str,
45
+ chat_history: list[tuple[str, str]],
46
+ system_prompt: str,
47
+ max_new_tokens: int = 1024,
48
+ temperature: float = 0.8,
49
+ top_p: float = 0.95,
50
+ top_k: int = 50) -> Iterator[str]:
51
+ prompt = get_prompt(message, chat_history, system_prompt)
52
+ inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
53
+
54
+ streamer = TextIteratorStreamer(tokenizer,
55
+ timeout=10.,
56
+ skip_prompt=True,
57
+ skip_special_tokens=True)
58
+ generate_kwargs = dict(
59
+ inputs,
60
+ streamer=streamer,
61
+ max_new_tokens=max_new_tokens,
62
+ do_sample=True,
63
+ top_p=top_p,
64
+ top_k=top_k,
65
+ temperature=temperature,
66
+ num_beams=1,
67
+ )
68
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
69
+ t.start()
70
+
71
+ outputs = []
72
+ for text in streamer:
73
+ outputs.append(text)
74
+ yield ''.join(outputs)
requirements.txt CHANGED
@@ -1,9 +1,8 @@
1
- protobuf
2
- transformers==4.30.2
3
- cpm_kernels
4
- torch>=2.0
5
- gradio
6
- mdtex2html
7
- sentencepiece
8
- accelerate
9
- loguru
 
1
+ accelerate==0.21.0
2
+ bitsandbytes==0.40.2
3
+ gradio==3.37.0
4
+ protobuf==3.20.3
5
+ scipy==1.11.1
6
+ sentencepiece==0.1.99
7
+ torch==2.0.1
8
+ transformers==4.31.0
 
style.css ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: white;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
11
+
12
+ #component-0 {
13
+ max-width: 900px;
14
+ margin: auto;
15
+ padding-top: 1.5rem;
16
+ }