Limour committed on
Commit
d3399cd
·
verified ·
1 Parent(s): 0523803

Delete sub_app.py

Browse files
Files changed (1) hide show
  1. sub_app.py +0 -496
sub_app.py DELETED
@@ -1,496 +0,0 @@
1
- import hashlib
2
- import os
3
- import re
4
- import json
5
- import threading
6
- from hf_api import restart_space
7
-
8
- import gradio as gr
9
-
10
- from chat_template import ChatTemplate
11
- from llama_cpp_python_streamingllm import StreamingLLM
12
-
13
# ========== Global lock: ensure at most one conversation turn runs at a time ==========
lock = threading.Lock()
# True while a user/bot exchange is in progress; always read/written under `lock`.
session_active = False
16
-
17
# ========== Make the chat page's text boxes equal height ==========
# Flex layout so the RAG/VO boxes stretch to fill the side column; the prompt
# box is capped. Element ids are set on the gradio components below.
custom_css = r'''
#area > div {
height: 100%;
}
#RAG-area {
flex-grow: 1;
}
#RAG-area > label {
height: 100%;
display: flex;
flex-direction: column;
}
#RAG-area > label > textarea {
flex-grow: 1;
max-height: 20vh;
}
#VO-area {
flex-grow: 1;
}
#VO-area > label {
height: 100%;
display: flex;
flex-direction: column;
}
#VO-area > label > textarea {
flex-grow: 1;
max-height: 20vh;
}
#prompt > label > textarea {
max-height: 63px;
}
'''
50
-
51
-
52
# ========== Template helper compatible with SillyTavern macros ==========
def text_format(text: str, _env=None, **env):
    """Substitute every ``{{key}}`` placeholder in *text*.

    Replacements are applied first from the optional mapping *_env*, then
    from keyword arguments; later substitutions operate on the text produced
    by earlier ones.
    """
    mappings = [] if _env is None else [_env]
    mappings.append(env)
    for mapping in mappings:
        for key, val in mapping.items():
            text = text.replace('{{' + key + '}}', val)
    return text
60
-
61
-
62
# ========== Hash helper ==========
def x_hash(x: str):
    """Return the hexadecimal SHA-1 digest of *x* encoded as UTF-8."""
    digest = hashlib.sha1()
    digest.update(x.encode('utf-8'))
    return digest.hexdigest()
65
-
66
-
67
# ========== Read configuration files ==========
# rp_sample_config.json supplies the defaults; rp_config.json is layered on
# top of it. The cache path is made unique per user config by appending the
# SHA-1 hash of the raw config text.
with open('rp_config.json', encoding='utf-8') as f:
    tmp = f.read()
with open('rp_sample_config.json', encoding='utf-8') as f:
    cfg = json.load(f)
cfg['setting_cache_path']['value'] += x_hash(tmp)
cfg.update(json.loads(tmp))
74
-
75
# ========== Bold text wrapped in full-width quotation marks ==========
reg_q = re.compile(r'“(.+?)”')


def chat_display_format(text: str):
    """Surround each “…” quotation in *text* with markdown bold markers."""
    return reg_q.sub(lambda m: ' **' + m.group(0) + '** ', text)
81
-
82
-
83
# ========== Settings tab: model paths, context, sampling parameters ==========
# Each component's initial value/limits come from `cfg`; labels stay as the
# user-facing (Chinese) strings.
with gr.Blocks() as setting:
    with gr.Row():
        setting_path = gr.Textbox(label="模型路径", max_lines=1, scale=2, **cfg['setting_path'])
        setting_cache_path = gr.Textbox(label="缓存路径", max_lines=1, scale=2, **cfg['setting_cache_path'])
        setting_seed = gr.Number(label="随机种子", scale=1, **cfg['setting_seed'])
        setting_n_gpu_layers = gr.Number(label="n_gpu_layers", scale=1, **cfg['setting_n_gpu_layers'])
    with gr.Row():
        setting_ctx = gr.Number(label="上下文大小(Tokens)", **cfg['setting_ctx'])
        setting_max_tokens = gr.Number(label="最大响应长度(Tokens)", interactive=True, **cfg['setting_max_tokens'])
        # n_keep is recomputed below from the role card's permanent prompt,
        # so it is shown read-only here.
        setting_n_keep = gr.Number(value=10, label="n_keep", interactive=False)
        setting_n_discard = gr.Number(label="n_discard", interactive=True, **cfg['setting_n_discard'])
    with gr.Row():
        setting_temperature = gr.Number(label="温度", interactive=True, **cfg['setting_temperature'])
        setting_repeat_penalty = gr.Number(label="重复惩罚", interactive=True, **cfg['setting_repeat_penalty'])
        setting_frequency_penalty = gr.Number(label="频率惩罚", interactive=True, **cfg['setting_frequency_penalty'])
        setting_presence_penalty = gr.Number(label="存在惩罚", interactive=True, **cfg['setting_presence_penalty'])
        setting_repeat_last_n = gr.Number(label="惩罚范围", interactive=True, **cfg['setting_repeat_last_n'])
    with gr.Row():
        setting_top_k = gr.Number(label="Top-K", interactive=True, **cfg['setting_top_k'])
        setting_top_p = gr.Number(label="Top P", interactive=True, **cfg['setting_top_p'])
        setting_min_p = gr.Number(label="Min P", interactive=True, **cfg['setting_min_p'])
        setting_typical_p = gr.Number(label="Typical", interactive=True, **cfg['setting_typical_p'])
        setting_tfs_z = gr.Number(label="TFS", interactive=True, **cfg['setting_tfs_z'])
    with gr.Row():
        setting_mirostat_mode = gr.Number(label="Mirostat 模式", **cfg['setting_mirostat_mode'])
        setting_mirostat_eta = gr.Number(label="Mirostat 学习率", interactive=True, **cfg['setting_mirostat_eta'])
        setting_mirostat_tau = gr.Number(label="Mirostat 目标熵", interactive=True, **cfg['setting_mirostat_tau'])
111
-
112
# ========== Download the model if it is not present locally ==========
if os.path.exists(setting_path.value):
    print(f"The file {setting_path.value} exists.")
else:
    from huggingface_hub import snapshot_download

    # NOTE(review): os.mkdir raises if these directories already exist —
    # presumably they never do on a fresh Space; confirm before reuse.
    os.mkdir("downloads")
    os.mkdir("cache")
    snapshot_download(repo_id='TheBloke/CausalLM-7B-GGUF', local_dir=r'downloads',
                      allow_patterns='causallm_7b.Q5_K_M.gguf')
    snapshot_download(repo_id='Limour/llama-python-streamingllm-cache', repo_type='dataset', local_dir=r'cache')
123
-
124
# ========== Load the model ==========
model = StreamingLLM(model_path=setting_path.value,
                     seed=setting_seed.value,
                     n_gpu_layers=setting_n_gpu_layers.value,
                     n_ctx=setting_ctx.value)
# Reflect the model's actual context size back into the settings UI.
setting_ctx.value = model.n_ctx()

# ========== Chat template (defaults to chatml) ==========
chat_template = ChatTemplate(model)
133
-
134
# ========== Role-card display tab ==========
# Read-only name fields plus the story description and reply-style example,
# all initialised from `cfg`.
with gr.Blocks() as role:
    with gr.Row():
        role_usr = gr.Textbox(label="用户名称", max_lines=1, interactive=False, **cfg['role_usr'])
        role_char = gr.Textbox(label="角色名称", max_lines=1, interactive=False, **cfg['role_char'])

    role_char_d = gr.Textbox(lines=10, label="故事描述", **cfg['role_char_d'])
    role_chat_style = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
142
-
143
# model.eval_t([1])  # warm-up bos [1]; removing it reportedly breaks generation
if os.path.exists(setting_cache_path.value):
    # ========== Load the role card from the session cache ==========
    tmp = model.load_session(setting_cache_path.value)
    print(f'load cache from {setting_cache_path.value} {tmp}')
    # Recompute n_keep (the permanently kept prompt length) by re-tokenizing
    # the system description and the reply-style example.
    tmp = chat_template('system',
                        text_format(role_char_d.value,
                                    char=role_char.value,
                                    user=role_usr.value))
    setting_n_keep.value = len(tmp)
    tmp = chat_template(role_char.value,
                        text_format(role_chat_style.value,
                                    char=role_char.value,
                                    user=role_usr.value))
    setting_n_keep.value += len(tmp)
    # ========== Role card: first message(s), display only ==========
    chatbot = []
    for one in cfg["role_char_first"]:
        one['name'] = text_format(one['name'],
                                  char=role_char.value,
                                  user=role_usr.value)
        one['value'] = text_format(one['value'],
                                   char=role_char.value,
                                   user=role_usr.value)
        if one['name'] == role_char.value:
            chatbot.append((None, chat_display_format(one['value'])))
        print(one)
else:
    # ========== Role card: character description ==========
    tmp = chat_template('system',
                        text_format(role_char_d.value,
                                    char=role_char.value,
                                    user=role_usr.value))
    setting_n_keep.value = model.eval_t(tmp)  # kept permanently in the context

    # ========== Role card: reply-style example ==========
    tmp = chat_template(role_char.value,
                        text_format(role_chat_style.value,
                                    char=role_char.value,
                                    user=role_usr.value))
    # NOTE(review): the cache branch above uses `+=` here while this branch
    # overwrites — presumably eval_t returns the cumulative token count, so
    # `=` is equivalent; confirm against StreamingLLM.eval_t.
    setting_n_keep.value = model.eval_t(tmp)  # kept permanently in the context

    # ========== Role card: first message(s), evaluated into the context ==========
    chatbot = []
    for one in cfg["role_char_first"]:
        one['name'] = text_format(one['name'],
                                  char=role_char.value,
                                  user=role_usr.value)
        one['value'] = text_format(one['value'],
                                   char=role_char.value,
                                   user=role_usr.value)
        if one['name'] == role_char.value:
            chatbot.append((None, chat_display_format(one['value'])))
        print(one)
        tmp = chat_template(one['name'], one['value'])
        model.eval_t(tmp)  # discarded as the context grows

    # ========== Save the role-card session cache ==========
    # Create/truncate the cache file first, then write the session into it.
    with open(setting_cache_path.value, 'wb') as f:
        pass
    tmp = model.save_session(setting_cache_path.value)
    print(f'save cache {tmp}')
    # ========== Upload the cache to the Hub ==========
    from huggingface_hub import login, CommitScheduler

    login(token=os.environ.get("HF_TOKEN"), write_permission=True)
    CommitScheduler(repo_id='Limour/llama-python-streamingllm-cache', repo_type='dataset', folder_path='cache')
210
-
211
-
212
# ========== Streaming generation helper ==========
def btn_submit_com(_n_keep, _n_discard,
                   _temperature, _repeat_penalty, _frequency_penalty,
                   _presence_penalty, _repeat_last_n, _top_k,
                   _top_p, _min_p, _typical_p,
                   _tfs_z, _mirostat_mode, _mirostat_eta,
                   _mirostat_tau, _role, _max_tokens):
    """Generate one reply spoken as *_role*, yielding the accumulated text.

    Requires an active session (raises RuntimeError otherwise). Streams
    tokens from the model, stopping on EOS / double-newline / trailing-newline
    heuristics or once *_max_tokens* is exceeded, then appends the im_end
    marker to the kv-cache.
    """
    with lock:
        if not session_active:
            raise RuntimeError
    # ========== Start the reply with the role's template prefix ==========
    t_bot = chat_template(_role)
    completion_tokens = []  # several tokens may be needed to form one utf-8 character
    history = ''
    # ========== Stream tokens ==========
    for token in model.generate_t(
            tokens=t_bot,
            n_keep=_n_keep,
            n_discard=_n_discard,
            im_start=chat_template.im_start_token,
            top_k=_top_k,
            top_p=_top_p,
            min_p=_min_p,
            typical_p=_typical_p,
            temp=_temperature,
            repeat_penalty=_repeat_penalty,
            repeat_last_n=_repeat_last_n,
            frequency_penalty=_frequency_penalty,
            presence_penalty=_presence_penalty,
            tfs_z=_tfs_z,
            mirostat_mode=_mirostat_mode,
            mirostat_tau=_mirostat_tau,
            mirostat_eta=_mirostat_eta,
    ):
        if token in chat_template.eos or token == chat_template.nlnl:
            t_bot.extend(completion_tokens)
            print('token in eos', token)
            break
        completion_tokens.append(token)
        all_text = model.str_detokenize(completion_tokens)
        if not all_text:
            # Incomplete utf-8 sequence; keep buffering tokens.
            continue
        t_bot.extend(completion_tokens)
        history += all_text
        yield history
        if token in chat_template.onenl:
            # ========== Strip the trailing newline ==========
            if t_bot[-2] in chat_template.onenl:
                model.venv_pop_token()
                break
            if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
                model.venv_pop_token()
                break
        if history[-2:] == '\n\n':  # tokens of the form 'x\n\n', e.g. '。\n\n'
            print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
                  repr(model.str_detokenize(t_bot[-1:])))
            break
        if len(t_bot) > _max_tokens:
            break
        completion_tokens = []
    # ========== Inspect trailing newlines ==========
    print('history', repr(history))
    # ========== Append the end-of-output marker to the kv_cache ==========
    model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
    t_bot.extend(chat_template.im_end_nl)
277
-
278
-
279
# ========== Show the user's message ==========
def btn_submit_usr(message: str, history):
    """Begin a new exchange: mark the session active and append the user's
    message (with an empty bot slot) to the chat history.

    Raises RuntimeError if a session is already active. Returns a cleared
    prompt box, the extended history, and an update disabling the submit
    button while the bot responds.
    """
    global session_active
    with lock:
        if session_active:
            raise RuntimeError
        session_active = True
    updated = list(history) if history else []
    updated.append([message.strip(), ''])
    return "", updated, gr.update(interactive=False)
290
-
291
-
292
# ========== Stream the bot's response ==========
def btn_submit_bot(history, _n_keep, _n_discard,
                   _temperature, _repeat_penalty, _frequency_penalty,
                   _presence_penalty, _repeat_last_n, _top_k,
                   _top_p, _min_p, _typical_p,
                   _tfs_z, _mirostat_mode, _mirostat_eta,
                   _mirostat_tau, _usr, _char,
                   _rag, _max_tokens):
    """Evaluate the user's last message (plus optional RAG context) and
    stream the character's reply into ``history[-1][1]``.

    Yields ``(history, info)`` pairs, where *info* is the token count and
    venv state. Temporarily injected RAG content and the previous narration
    venv are removed once the reply is finished.
    """
    with lock:
        if not session_active:
            raise RuntimeError
    # ========== Temporarily injected content ==========
    rag_idx = None
    if len(_rag) > 0:
        rag_idx = model.venv_create()  # remember the venv index
        t_rag = chat_template('system', _rag)
        model.eval_t(t_rag, _n_keep, _n_discard)
        model.venv_create()  # isolate the rest from t_rag
    # ========== User input ==========
    t_msg = history[-1][0]
    t_msg = chat_template(_usr, t_msg)
    model.eval_t(t_msg, _n_keep, _n_discard)
    # ========== Model output ==========
    _tmp = btn_submit_com(_n_keep, _n_discard,
                          _temperature, _repeat_penalty, _frequency_penalty,
                          _presence_penalty, _repeat_last_n, _top_k,
                          _top_p, _min_p, _typical_p,
                          _tfs_z, _mirostat_mode, _mirostat_eta,
                          _mirostat_tau, _char, _max_tokens)
    for _h in _tmp:
        history[-1][1] = _h
        yield history, str((model.n_tokens, model.venv))
    # ========== Format the finished reply for display ==========
    history[-1][1] = chat_display_format(history[-1][1])
    yield history, str((model.n_tokens, model.venv))
    # ========== Promptly remove the previous narration ==========
    if vo_idx > 0:
        print('vo_idx', vo_idx, model.venv)
        model.venv_remove(vo_idx)
        print('vo_idx', vo_idx, model.venv)
        # NOTE(review): `if rag_idx` would also skip index 0 — assumes
        # venv_create never returns 0 here; confirm against StreamingLLM.
        if rag_idx and vo_idx < rag_idx:
            rag_idx -= 1
    # ========== Clear the injected content after responding ==========
    if rag_idx is not None:
        model.venv_remove(rag_idx)  # destroy the corresponding venv
        model.venv_disband()  # leave the isolation environment
        yield history, str((model.n_tokens, model.venv))
        print('venv_disband', vo_idx, model.venv)
340
-
341
-
342
# ========== Not yet implemented ==========
def btn_rag_(_rag, _msg):
    """Placeholder RAG retrieval hook; currently always returns an empty string."""
    return ''
346
-
347
-
348
# Index of the venv holding the most recent narration; 0 means none exists yet.
vo_idx = 0


# ========== Generate a narration line ==========
def btn_submit_vo(_n_keep, _n_discard,
                  _temperature, _repeat_penalty, _frequency_penalty,
                  _presence_penalty, _repeat_last_n, _top_k,
                  _top_p, _min_p, _typical_p,
                  _tfs_z, _mirostat_mode, _mirostat_eta,
                  _mirostat_tau, _max_tokens):
    """Stream a narrator ('旁白') line inside a fresh venv, recording its
    index in the global ``vo_idx`` so the next bot turn can remove it.

    Requires an active session (raises RuntimeError otherwise).
    """
    with lock:
        if not session_active:
            raise RuntimeError
    global vo_idx
    vo_idx = model.venv_create()  # create an isolation environment
    # ========== Narration output from the model ==========
    _tmp = btn_submit_com(_n_keep, _n_discard,
                          _temperature, _repeat_penalty, _frequency_penalty,
                          _presence_penalty, _repeat_last_n, _top_k,
                          _top_p, _min_p, _typical_p,
                          _tfs_z, _mirostat_mode, _mirostat_eta,
                          _mirostat_tau, '旁白', _max_tokens)
    for _h in _tmp:
        yield _h, str((model.n_tokens, model.venv))
372
-
373
-
374
# ========== Suggest a default reply for the user ==========
def btn_submit_suggest(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _max_tokens):
    """Generate a suggested reply spoken as *_usr* inside a throwaway venv,
    streaming it to the prompt box; the venv is destroyed when done.

    Requires an active session (raises RuntimeError otherwise).
    """
    with lock:
        if not session_active:
            raise RuntimeError
    model.venv_create()  # create an isolation environment
    # ========== Model output ==========
    _tmp = btn_submit_com(_n_keep, _n_discard,
                          _temperature, _repeat_penalty, _frequency_penalty,
                          _presence_penalty, _repeat_last_n, _top_k,
                          _top_p, _min_p, _typical_p,
                          _tfs_z, _mirostat_mode, _mirostat_eta,
                          _mirostat_tau, _usr, _max_tokens)
    _h = ''
    for _h in _tmp:
        yield _h, str((model.n_tokens, model.venv))
    model.venv_remove()  # destroy the isolation environment
    yield _h, str((model.n_tokens, model.venv))
397
-
398
-
399
def btn_submit_finish():
    """End the current exchange: clear the session flag and re-enable the
    submit button.

    Raises RuntimeError if no session is active.
    """
    global session_active
    with lock:
        if session_active:
            session_active = False
        else:
            raise RuntimeError
    return gr.update(interactive=True)
406
-
407
-
408
# ========== Chat page ==========
with gr.Blocks() as chatting:
    with gr.Row(equal_height=True):
        chatbot = gr.Chatbot(height='60vh', scale=2, value=chatbot,
                             avatar_images=(r'assets/user.png', r'assets/chatbot.webp'))
        with gr.Column(scale=1, elem_id="area"):
            rag = gr.Textbox(label='RAG', show_copy_button=True, elem_id="RAG-area")
            vo = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
            s_info = gr.Textbox(value=str((model.n_tokens, model.venv)), max_lines=1, label='info', interactive=False)
            msg = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
    with gr.Row():
        btn_rag = gr.Button("RAG")
        btn_submit = gr.Button("Submit")
        btn_retry = gr.Button("Retry")
        btn_com1 = gr.Button("自定义1")
        btn_com2 = gr.Button("自定义2")
        btn_com3 = gr.Button("自定义3")

    btn_rag.click(fn=btn_rag_, outputs=rag,
                  inputs=[rag, msg])

    # Submit pipeline: user message -> bot reply -> narration -> suggested
    # reply -> re-enable the button. Each stage runs only if the previous
    # one succeeded.
    btn_submit.click(
        fn=btn_submit_usr, api_name="submit",
        inputs=[msg, chatbot],
        outputs=[msg, chatbot, btn_submit]
    ).success(
        fn=btn_submit_bot,
        inputs=[chatbot, setting_n_keep, setting_n_discard,
                setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                setting_top_p, setting_min_p, setting_typical_p,
                setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                setting_mirostat_tau, role_usr, role_char,
                rag, setting_max_tokens],
        outputs=[chatbot, s_info]
    ).success(
        fn=btn_submit_vo,
        inputs=[setting_n_keep, setting_n_discard,
                setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                setting_top_p, setting_min_p, setting_typical_p,
                setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                setting_mirostat_tau, setting_max_tokens],
        outputs=[vo, s_info]
    ).success(
        fn=btn_submit_suggest,
        inputs=[setting_n_keep, setting_n_discard,
                setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                setting_top_p, setting_min_p, setting_typical_p,
                setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                setting_mirostat_tau, role_usr, setting_max_tokens],
        outputs=[msg, s_info]
    ).success(
        fn=btn_submit_finish,
        outputs=btn_submit
    )
465
-
466
-
467
# ========== Debug helpers ==========
# btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)

@btn_com2.click(inputs=setting_cache_path,
                outputs=[s_info, btn_submit])
def btn_com2(_cache_path):
    """Reset the conversation by reloading the model state from the cache.

    Fix: use the ``_cache_path`` value supplied by the UI input — the
    parameter was previously ignored in favour of the static
    ``setting_cache_path.value``. Clears the narration index and the session
    flag, then re-enables the submit button. On any failure the Space is
    restarted to recover a clean state.
    """
    try:
        with lock:
            _tmp = model.load_session(_cache_path)
            print(f'load cache from {_cache_path} {_tmp}')
            global vo_idx
            vo_idx = 0
            model.venv = [0]
            global session_active
            session_active = False
            return str((model.n_tokens, model.venv)), gr.update(interactive=True)
    except Exception:
        restart_space()
        raise  # bare raise preserves the original traceback
486
-
487
# @btn_com3.click()
# def btn_com3():
#     restart_space()

# ========== Launch ==========
# Three tabs (chat / settings / role card); queue limited to one pending
# request since the model can serve only one session at a time.
demo = gr.TabbedInterface([chatting, setting, role],
                          ["聊天", "设置", '角色'],
                          css=custom_css)
gr.close_all()
demo.queue(max_size=1).launch(share=False)