File size: 7,066 Bytes
0523803 670a6e9 0523803 670a6e9 0523803 670a6e9 0523803 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import gradio as gr
import threading
from llama_cpp_python_streamingllm import StreamingLLM
from mods.read_cfg import cfg
from mods.text_display import init as text_display_init
from mods.btn_rag import init as btn_rag_init
# ========== 按钮中用到的共同的函数 ==========
from mods.btn_com import init as btn_com_init
# ========== 输出一段回答 ==========
from mods.btn_submit import init as btn_submit_init
# ========== 输出一段旁白 ==========
from mods.btn_vo import init as btn_vo_init
# ========== 重新输出一段回答 ==========
from mods.btn_retry import init as btn_retry_init
# ========== 给用户提供默认回复的建议 ==========
from mods.btn_suggest import init as btn_suggest_init
# ========== 重置按钮 ==========
from mods.btn_reset import init as btn_reset_init
# ========== 聊天的模版 默认 chatml ==========
from chat_template import ChatTemplate
# ========== Global lock: ensures only one session can run at a time ==========
# Both values are stored on the shared cfg mapping, the same object that is
# later handed to every mods.* init function in this file.
cfg['session_lock'] = threading.Lock()
cfg['session_active'] = False  # starts out with no session marked active
# ========== Settings page: temperature, sampling and related options ==========
# Every cfg['setting_*'] entry read here is unpacked with ** as keyword
# arguments for its component, and is then overwritten by the component itself.
def _number_setting(key, label, **fixed):
    """Build a gr.Number from the keyword arguments in cfg[key] and store it back under key."""
    cfg[key] = gr.Number(label=label, **fixed, **cfg[key])


with gr.Blocks() as setting:
    # Row 1: model path, cache path, random seed, GPU layer count.
    with gr.Row():
        cfg['setting_path'] = gr.Textbox(
            label="模型路径", max_lines=1, scale=2, **cfg['setting_path']
        )
        cfg['setting_cache_path'] = gr.Textbox(
            label="缓存路径", max_lines=1, scale=2, **cfg['setting_cache_path']
        )
        _number_setting('setting_seed', "随机种子", scale=1)
        _number_setting('setting_n_gpu_layers', "n_gpu_layers", scale=1)

    # Row 2: context size, maximum response length, n_keep, n_discard.
    with gr.Row():
        _number_setting('setting_ctx', "上下文大小(Tokens)")
        _number_setting('setting_max_tokens', "最大响应长度(Tokens)", interactive=True)
        # n_keep is the only field not driven by cfg: fixed at 10 and read-only.
        cfg['setting_n_keep'] = gr.Number(value=10, label="n_keep", interactive=False)
        _number_setting('setting_n_discard', "n_discard", interactive=True)

    # Row 3: temperature and the repetition-related penalties.
    with gr.Row():
        for _key, _label in (
            ('setting_temperature', "温度"),
            ('setting_repeat_penalty', "重复惩罚"),
            ('setting_frequency_penalty', "频率惩罚"),
            ('setting_presence_penalty', "存在惩罚"),
            ('setting_repeat_last_n', "惩罚范围"),
        ):
            _number_setting(_key, _label, interactive=True)

    # Row 4: Top-K, Top P, Min P, Typical and TFS.
    with gr.Row():
        for _key, _label in (
            ('setting_top_k', "Top-K"),
            ('setting_top_p', "Top P"),
            ('setting_min_p', "Min P"),
            ('setting_typical_p', "Typical"),
            ('setting_tfs_z', "TFS"),
        ):
            _number_setting(_key, _label, interactive=True)

    # Row 5: Mirostat mode, eta and tau.
    with gr.Row():
        _number_setting('setting_mirostat_mode', "Mirostat 模式")
        _number_setting('setting_mirostat_eta', "Mirostat 学习率", interactive=True)
        _number_setting('setting_mirostat_tau', "Mirostat 目标熵", interactive=True)
# ========== Load the model ==========
# The constructor arguments are taken from the `.value` of the settings
# components that cfg holds under the corresponding 'setting_*' keys.
_model_options = {
    'model_path': cfg['setting_path'].value,
    'seed': cfg['setting_seed'].value,
    'n_gpu_layers': cfg['setting_n_gpu_layers'].value,
    'n_ctx': cfg['setting_ctx'].value,
}
cfg['model'] = StreamingLLM(**_model_options)
cfg['chat_template'] = ChatTemplate(cfg['model'])
# Write the context size reported by the loaded model back into the settings field.
cfg['setting_ctx'].value = cfg['model'].n_ctx()
# ========== Role card page ==========
# NOTE(review): the source reached review with its indentation stripped, so
# the nesting below (names side by side in one row, the two long text areas
# under it, init calls inside the Blocks context) is reconstructed - confirm.
with gr.Blocks() as role:
    with gr.Row():
        # User name and character name: single-line, non-interactive.
        for _key, _label in (
            ('role_usr', "用户名称"),
            ('role_char', "角色名称"),
        ):
            cfg[_key] = gr.Textbox(label=_label, max_lines=1, interactive=False, **cfg[_key])

    # Story description and reply examples: ten-line text areas.
    for _key, _label in (
        ('role_char_d', "故事描述"),
        ('role_chat_style', "回复示例"),
    ):
        cfg[_key] = gr.Textbox(lines=10, label=_label, **cfg[_key])

    # ========== Load the role card / cache ==========
    # This import is kept at this position, after the model has been created,
    # exactly where the original script performs it.
    from mods.load_cache import init as load_cache_init

    text_display_init(cfg)
    load_cache_init(cfg)
# ========== Chat page ==========
# NOTE(review): the source reached review with its indentation stripped; the
# layout nesting below (chat log beside a side column, prompt box and button
# row underneath) is reconstructed - confirm against the intended layout.
with gr.Blocks() as chatting:
    with gr.Row(equal_height=True):
        # cfg['chatbot'] supplies the initial value and is then replaced by the component.
        cfg['chatbot'] = gr.Chatbot(
            height='60vh',
            scale=2,
            value=cfg['chatbot'],
            avatar_images=(r'assets/user.png', r'assets/chatbot.webp'),
        )
        # Side column; its elem_id values are what the custom CSS selectors target.
        with gr.Column(scale=1, elem_id="area"):
            cfg['rag'] = gr.Textbox(label='RAG', show_copy_button=True, elem_id="RAG-area")
            cfg['vo'] = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
            cfg['s_info'] = gr.Textbox(
                value=cfg['model'].venv_info, max_lines=1, label='info', interactive=False
            )

    cfg['msg'] = gr.Textbox(
        label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg']
    )

    with gr.Row():
        # Buttons are created in their on-screen order.
        for _key, _caption in (
            ('btn_vo', "旁白"),
            ('btn_rag', "RAG"),
            ('btn_retry', "Retry"),
            ('btn_com1', "自定义1"),
            ('btn_reset', "Reset"),
            ('btn_debug', "Debug"),
            ('btn_submit', "Submit"),
            ('btn_suggest', "建议"),
        ):
            cfg[_key] = gr.Button(_caption)

    # Expose the gradio module itself to the handler modules through cfg.
    cfg['gr'] = gr

    # Initialise the button modules in a fixed order, each receiving the shared cfg.
    # Presumably they attach event listeners, which Gradio only permits inside
    # a Blocks context - hence the calls stay within this `with`; verify.
    for _init in (
        btn_com_init,
        btn_rag_init,
        btn_submit_init,
        btn_vo_init,
        btn_suggest_init,
        btn_retry_init,
        btn_reset_init,  # the original script labels this call "for debugging"
    ):
        _init(cfg)
# ========== Make the text boxes on the chat page equal height ==========
# The selectors refer to the elem_id values "area", "RAG-area", "VO-area" and
# "prompt" assigned to components on the chat page. The string is passed as
# the `css` argument when the tabbed interface is created.
custom_css = r'''
#area > div {
height: 100%;
}
#RAG-area {
flex-grow: 1;
}
#RAG-area > label {
height: 100%;
display: flex;
flex-direction: column;
}
#RAG-area > label > textarea {
flex-grow: 1;
max-height: 20vh;
}
#VO-area {
flex-grow: 1;
}
#VO-area > label {
height: 100%;
display: flex;
flex-direction: column;
}
#VO-area > label > textarea {
flex-grow: 1;
max-height: 20vh;
}
#prompt > label > textarea {
max-height: 63px;
}
'''
# ========== Start the app ==========
# Combine the three pages into one tabbed interface: chat, settings, role.
demo = gr.TabbedInterface(
    [chatting, setting, role],
    ["聊天", "设置", '角色'],
    css=custom_css,
)
gr.close_all()
# Configure the queue (max_size=1, api_open=False), then launch without a share link.
demo.queue(api_open=False, max_size=1)
demo.launch(share=False)
|