import hashlib
import os
import re
import json

import gradio as gr

from chat_template import ChatTemplate
from llama_cpp_python_streamingllm import StreamingLLM

# ========== Make the chat UI textboxes equal height ==========
custom_css = r'''
#area > div {
    height: 100%;
}
#RAG-area {
    flex-grow: 1;
}
#RAG-area > label {
    height: 100%;
    display: flex;
    flex-direction: column;
}
#RAG-area > label > textarea {
    flex-grow: 1;
    max-height: 20vh;
}
#VO-area {
    flex-grow: 1;
}
#VO-area > label {
    height: 100%;
    display: flex;
    flex-direction: column;
}
#VO-area > label > textarea {
    flex-grow: 1;
    max-height: 20vh;
}
#prompt > label > textarea {
    max-height: 63px;
}
'''


# ========== SillyTavern-style template macros ==========
def text_format(text: str, _env=None, **env):
    if _env is not None:
        for k, v in _env.items():
            text = text.replace(r'{{' + k + r'}}', v)
    for k, v in env.items():
        text = text.replace(r'{{' + k + r'}}', v)
    return text
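# Illustrative example (not from the original source): the {{...}} macros follow
# SillyTavern conventions, so
#   text_format('{{char}} waves at {{user}}.', char='Alice', user='Bob')
# returns 'Alice waves at Bob.'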


# ========== Hash helper ==========
def x_hash(x: str):
    return hashlib.sha1(x.encode('utf-8')).hexdigest()
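# Example (the standard SHA-1 test vector):
#   x_hash('abc') -> 'a9993e364706816aba3e25717850c26c9cd0d89d'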


# ========== Load the config files ==========
# rp_sample_config.json supplies the defaults and rp_config.json overrides them;
# the cache path is suffixed with a hash of the user config so that each distinct
# config gets its own session cache.
with open('rp_config.json', encoding='utf-8') as f:
    tmp = f.read()
with open('rp_sample_config.json', encoding='utf-8') as f:
    cfg = json.load(f)
cfg['setting_cache_path']['value'] += x_hash(tmp)
cfg.update(json.loads(tmp))
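# Each cfg entry holds keyword arguments for the matching Gradio component, e.g.
# (illustrative values, not taken from the real config files):
#   "setting_temperature": {"value": 0.8}
#   "setting_path": {"value": "models/model.gguf"}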

# ========== Bold the text inside full-width quotes ==========
reg_q = re.compile(r'“(.+?)”')


def chat_display_format(text: str):
    return reg_q.sub(r' **\g<0>** ', text)
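# Example: each “…” span is wrapped in Markdown bold for the chat display, e.g.
#   chat_display_format('她说“你好”。') -> '她说 **“你好”** 。'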


# ========== Temperature, sampling and related settings ==========
with gr.Blocks() as setting:
    with gr.Row():
        setting_path = gr.Textbox(label="Model path", max_lines=1, scale=2, **cfg['setting_path'])
        setting_cache_path = gr.Textbox(label="Cache path", max_lines=1, scale=2, **cfg['setting_cache_path'])
        setting_seed = gr.Number(label="Random seed", scale=1, **cfg['setting_seed'])
        setting_n_gpu_layers = gr.Number(label="n_gpu_layers", scale=1, **cfg['setting_n_gpu_layers'])
    with gr.Row():
        setting_ctx = gr.Number(label="Context size (tokens)", **cfg['setting_ctx'])
        setting_max_tokens = gr.Number(label="Max response length (tokens)", interactive=True, **cfg['setting_max_tokens'])
        setting_n_keep = gr.Number(value=10, label="n_keep", interactive=False)
        setting_n_discard = gr.Number(label="n_discard", interactive=True, **cfg['setting_n_discard'])
    with gr.Row():
        setting_temperature = gr.Number(label="Temperature", interactive=True, **cfg['setting_temperature'])
        setting_repeat_penalty = gr.Number(label="Repeat penalty", interactive=True, **cfg['setting_repeat_penalty'])
        setting_frequency_penalty = gr.Number(label="Frequency penalty", interactive=True, **cfg['setting_frequency_penalty'])
        setting_presence_penalty = gr.Number(label="Presence penalty", interactive=True, **cfg['setting_presence_penalty'])
        setting_repeat_last_n = gr.Number(label="Penalty range (last n)", interactive=True, **cfg['setting_repeat_last_n'])
    with gr.Row():
        setting_top_k = gr.Number(label="Top-K", interactive=True, **cfg['setting_top_k'])
        setting_top_p = gr.Number(label="Top-P", interactive=True, **cfg['setting_top_p'])
        setting_min_p = gr.Number(label="Min-P", interactive=True, **cfg['setting_min_p'])
        setting_typical_p = gr.Number(label="Typical-P", interactive=True, **cfg['setting_typical_p'])
        setting_tfs_z = gr.Number(label="TFS", interactive=True, **cfg['setting_tfs_z'])
    with gr.Row():
        setting_mirostat_mode = gr.Number(label="Mirostat mode", **cfg['setting_mirostat_mode'])
        setting_mirostat_eta = gr.Number(label="Mirostat learning rate", interactive=True, **cfg['setting_mirostat_eta'])
        setting_mirostat_tau = gr.Number(label="Mirostat target entropy", interactive=True, **cfg['setting_mirostat_tau'])

    # ========== Load the model ==========
    model = StreamingLLM(model_path=setting_path.value,
                         seed=setting_seed.value,
                         n_gpu_layers=setting_n_gpu_layers.value,
                         n_ctx=setting_ctx.value)
    setting_ctx.value = model.n_ctx()  # reflect the context size actually allocated by the model

# ========== Chat template (defaults to chatml) ==========
chat_template = ChatTemplate(model)
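# Inferred from the chatml default and the im_start_token/im_end_nl attributes used
# below (not from ChatTemplate's documentation): chat_template('system', 'hello')
# presumably tokenizes a turn like '<|im_start|>system\nhello<|im_end|>\n', while
# chat_template(role) alone yields the opening '<|im_start|>role\n' so the model
# completes that turn.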

# ========== Character card display ==========
with gr.Blocks() as role:
    with gr.Row():
        role_usr = gr.Textbox(label="User name", max_lines=1, interactive=False, **cfg['role_usr'])
        role_char = gr.Textbox(label="Character name", max_lines=1, interactive=False, **cfg['role_char'])

    role_char_d = gr.Textbox(lines=10, label="Story description", **cfg['role_char_d'])
    role_chat_style = gr.Textbox(lines=10, label="Reply examples", **cfg['role_chat_style'])

    # model.eval_t([1])  # warm-up BOS token [1]; generation misbehaved when this was removed
    if os.path.exists(setting_cache_path.value):
        # ========== Load character card - cached session ==========
        tmp = model.load_session(setting_cache_path.value)
        print(f'load cache from {setting_cache_path.value} {tmp}')
        tmp = chat_template('system',
                            text_format(role_char_d.value,
                                        char=role_char.value,
                                        user=role_usr.value))
        setting_n_keep.value = len(tmp)
        tmp = chat_template(role_char.value,
                            text_format(role_chat_style.value,
                                        char=role_char.value,
                                        user=role_usr.value))
        setting_n_keep.value += len(tmp)
        # ========== Load character card - first messages ==========
        chatbot = []
        for one in cfg["role_char_first"]:
            one['name'] = text_format(one['name'],
                                      char=role_char.value,
                                      user=role_usr.value)
            one['value'] = text_format(one['value'],
                                       char=role_char.value,
                                       user=role_usr.value)
            if one['name'] == role_char.value:
                chatbot.append((None, chat_display_format(one['value'])))
            print(one)
    else:
        # ========== Load character card - character description ==========
        tmp = chat_template('system',
                            text_format(role_char_d.value,
                                        char=role_char.value,
                                        user=role_usr.value))
        setting_n_keep.value = model.eval_t(tmp)  # this content is kept permanently (within n_keep)

        # ========== Load character card - reply examples ==========
        tmp = chat_template(role_char.value,
                            text_format(role_chat_style.value,
                                        char=role_char.value,
                                        user=role_usr.value))
        setting_n_keep.value = model.eval_t(tmp)  # this content is kept permanently (within n_keep)

        # ========== Load character card - first messages ==========
        chatbot = []
        for one in cfg["role_char_first"]:
            one['name'] = text_format(one['name'],
                                      char=role_char.value,
                                      user=role_usr.value)
            one['value'] = text_format(one['value'],
                                       char=role_char.value,
                                       user=role_usr.value)
            if one['name'] == role_char.value:
                chatbot.append((None, chat_display_format(one['value'])))
            print(one)
            tmp = chat_template(one['name'], one['value'])
            model.eval_t(tmp)  # this content will be discarded as the context grows

        # ========== Save character card - session cache ==========
        with open(setting_cache_path.value, 'wb') as f:
            pass  # create/truncate the cache file before saving the session
        tmp = model.save_session(setting_cache_path.value)
        print(f'save cache {tmp}')
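        # With the session saved, later runs that share the same config hash skip the
        # costly character-card evaluation and restore the kv_cache via load_session above.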


# ========== Streaming generation function ==========
def btn_submit_com(_n_keep, _n_discard,
                   _temperature, _repeat_penalty, _frequency_penalty,
                   _presence_penalty, _repeat_last_n, _top_k,
                   _top_p, _min_p, _typical_p,
                   _tfs_z, _mirostat_mode, _mirostat_eta,
                   _mirostat_tau, _role, _max_tokens):
    # ========== Initialize the output template ==========
    t_bot = chat_template(_role)
    completion_tokens = []  # several tokens may be needed to form one complete UTF-8 character
    history = ''
    # ========== Stream the output ==========
    for token in model.generate_t(
            tokens=t_bot,
            n_keep=_n_keep,
            n_discard=_n_discard,
            im_start=chat_template.im_start_token,
            top_k=_top_k,
            top_p=_top_p,
            min_p=_min_p,
            typical_p=_typical_p,
            temp=_temperature,
            repeat_penalty=_repeat_penalty,
            repeat_last_n=_repeat_last_n,
            frequency_penalty=_frequency_penalty,
            presence_penalty=_presence_penalty,
            tfs_z=_tfs_z,
            mirostat_mode=_mirostat_mode,
            mirostat_tau=_mirostat_tau,
            mirostat_eta=_mirostat_eta,
    ):
        if token in chat_template.eos or token == chat_template.nlnl:
            t_bot.extend(completion_tokens)
            print('token in eos', token)
            break
        completion_tokens.append(token)
        all_text = model.str_detokenize(completion_tokens)
        if not all_text:
            continue
        t_bot.extend(completion_tokens)
        history += all_text
        yield history
        if token in chat_template.onenl:
            # ========== Strip trailing newline tokens ==========
            if t_bot[-2] in chat_template.onenl:
                model.venv_pop_token()
                break
            if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
                model.venv_pop_token()
                break
        if history[-2:] == '\n\n':  # catches tokens that end in '\n\n', e.g. '。\n\n'
            print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
                  repr(model.str_detokenize(t_bot[-1:])))
            break
        if len(t_bot) > _max_tokens:
            break
        completion_tokens = []
    # ========== Inspect the trailing newlines ==========
    print('history', repr(history))
    # ========== Append the end-of-turn marker to the kv_cache ==========
    model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
    t_bot.extend(chat_template.im_end_nl)


# ========== Display the user message ==========
def btn_submit_usr(message: str, history):
    # print('btn_submit_usr', message, history)
    if history is None:
        history = []
    return "", history + [[message.strip(), '']], gr.update(interactive=False)


# ========== Streamed model response ==========
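# Note (inferred from how StreamingLLM is used below, not from its documentation):
# venv_create() marks an isolation point in the kv_cache, venv_remove() evicts the
# tokens evaluated since such a point, and venv_disband() merges the current level
# back into the permanent context. This lets the RAG injection, the narration, and
# the suggested reply be discarded again after each response.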
def btn_submit_bot(history, _n_keep, _n_discard,
                   _temperature, _repeat_penalty, _frequency_penalty,
                   _presence_penalty, _repeat_last_n, _top_k,
                   _top_p, _min_p, _typical_p,
                   _tfs_z, _mirostat_mode, _mirostat_eta,
                   _mirostat_tau, _usr, _char,
                   _rag, _max_tokens):
    # ========== Content to inject temporarily (RAG) ==========
    rag_idx = None
    if len(_rag) > 0:
        rag_idx = model.venv_create()  # remember the venv index
        t_rag = chat_template('system', _rag)
        model.eval_t(t_rag, _n_keep, _n_discard)
    model.venv_create()  # keep the reply isolated from t_rag
    # ========== User input ==========
    t_msg = history[-1][0]
    t_msg = chat_template(_usr, t_msg)
    model.eval_t(t_msg, _n_keep, _n_discard)
    # ========== Model output ==========
    _tmp = btn_submit_com(_n_keep, _n_discard,
                          _temperature, _repeat_penalty, _frequency_penalty,
                          _presence_penalty, _repeat_last_n, _top_k,
                          _top_p, _min_p, _typical_p,
                          _tfs_z, _mirostat_mode, _mirostat_eta,
                          _mirostat_tau, _char, _max_tokens)
    for _h in _tmp:
        history[-1][1] = _h
        yield history, str((model.n_tokens, model.venv))
    # ========== Format the output once generation finishes ==========
    history[-1][1] = chat_display_format(history[-1][1])
    yield history, str((model.n_tokens, model.venv))
    # ========== Promptly clean up the narration from the previous turn ==========
    if vo_idx > 0:
        print('vo_idx', vo_idx, model.venv)
        model.venv_remove(vo_idx)
        print('vo_idx', vo_idx, model.venv)
        if rag_idx and vo_idx < rag_idx:
            rag_idx -= 1
    # ========== Remove the injected content once the response is done ==========
    if rag_idx is not None:
        model.venv_remove(rag_idx)  # destroy the corresponding venv
    model.venv_disband()  # leave the isolated environment
    yield history, str((model.n_tokens, model.venv))
    print('venv_disband', vo_idx, model.venv)


# ========== RAG retrieval: not yet implemented ==========
def btn_rag_(_rag, _msg):
    retn = ''  # stub: always returns an empty string for now
    return retn


vo_idx = 0


# ========== Generate a piece of narration ==========
def btn_submit_vo(_n_keep, _n_discard,
                  _temperature, _repeat_penalty, _frequency_penalty,
                  _presence_penalty, _repeat_last_n, _top_k,
                  _top_p, _min_p, _typical_p,
                  _tfs_z, _mirostat_mode, _mirostat_eta,
                  _mirostat_tau, _max_tokens):
    global vo_idx
    vo_idx = model.venv_create()  # create an isolated environment
    # ========== Model generates the narration ==========
    _tmp = btn_submit_com(_n_keep, _n_discard,
                          _temperature, _repeat_penalty, _frequency_penalty,
                          _presence_penalty, _repeat_last_n, _top_k,
                          _top_p, _min_p, _typical_p,
                          _tfs_z, _mirostat_mode, _mirostat_eta,
                          _mirostat_tau, '旁白', _max_tokens)  # '旁白' (narrator) role; kept as-is since it presumably matches the character card
    for _h in _tmp:
        yield _h, str((model.n_tokens, model.venv))


# ========== Suggest a default reply for the user ==========
def btn_submit_suggest(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _max_tokens):
    model.venv_create()  # create an isolated environment
    # ========== Model output ==========
    _tmp = btn_submit_com(_n_keep, _n_discard,
                          _temperature, _repeat_penalty, _frequency_penalty,
                          _presence_penalty, _repeat_last_n, _top_k,
                          _top_p, _min_p, _typical_p,
                          _tfs_z, _mirostat_mode, _mirostat_eta,
                          _mirostat_tau, _usr, _max_tokens)
    _h = ''
    for _h in _tmp:
        yield _h, str((model.n_tokens, model.venv))
    model.venv_remove()  # destroy the isolated environment
    yield _h, str((model.n_tokens, model.venv))


# ========== Chat page ==========
with gr.Blocks() as chatting:
    with gr.Row(equal_height=True):
        chatbot = gr.Chatbot(height='60vh', scale=2, value=chatbot,
                             avatar_images=(r'assets/user.png', r'assets/chatbot.webp'))
        with gr.Column(scale=1, elem_id="area"):
            rag = gr.Textbox(label='RAG', show_copy_button=True, elem_id="RAG-area")
            vo = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
            s_info = gr.Textbox(value=str((model.n_tokens, model.venv)), max_lines=1, label='info', interactive=False)
    msg = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
    with gr.Row():
        btn_rag = gr.Button("RAG")
        btn_submit = gr.Button("Submit")
        btn_retry = gr.Button("Retry")
        btn_com1 = gr.Button("自定义1")
        btn_com2 = gr.Button("自定义2")
        btn_com3 = gr.Button("自定义3")

    btn_rag.click(fn=btn_rag_, outputs=rag,
                  inputs=[rag, msg])

    btn_submit.click(
        fn=btn_submit_usr, api_name="submit",
        inputs=[msg, chatbot],
        outputs=[msg, chatbot, btn_submit]
    ).then(
        fn=btn_submit_bot,
        inputs=[chatbot, setting_n_keep, setting_n_discard,
                setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                setting_top_p, setting_min_p, setting_typical_p,
                setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                setting_mirostat_tau, role_usr, role_char,
                rag, setting_max_tokens],
        outputs=[chatbot, s_info]
    ).then(
        fn=btn_submit_vo,
        inputs=[setting_n_keep, setting_n_discard,
                setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                setting_top_p, setting_min_p, setting_typical_p,
                setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                setting_mirostat_tau, setting_max_tokens],
        outputs=[vo, s_info]
    ).then(
        fn=btn_submit_suggest,
        inputs=[setting_n_keep, setting_n_discard,
                setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                setting_top_p, setting_min_p, setting_typical_p,
                setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                setting_mirostat_tau, role_usr, setting_max_tokens],
        outputs=[msg, s_info]
    ).then(
        fn=lambda: gr.update(interactive=True),
        outputs=btn_submit
    )

    # ========== Debug helpers ==========
    btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)


    @btn_com2.click(inputs=setting_cache_path,
                    outputs=s_info)
    def btn_com2_click(_cache_path):  # renamed so it no longer shadows the Button
        _tmp = model.load_session(_cache_path)  # use the value wired in by Gradio
        print(f'load cache from {_cache_path} {_tmp}')
        global vo_idx
        vo_idx = 0
        model.venv = [0]
        return str((model.n_tokens, model.venv))

# ========== Launch ==========
demo = gr.TabbedInterface([chatting, setting, role],
                          ["Chat", "Settings", "Role"],
                          css=custom_css)
gr.close_all()
demo.queue().launch(share=False)