# NOTE: the following lines appear to be file-viewer page residue, kept as comments.
# Limour's picture
# Upload 2 files
# 91dc74f verified
# raw
# history blame
# 4.38 kB
def init(cfg):
chat_template = cfg['chat_template']
model = cfg['model']
gr = cfg['gr']
lock = cfg['session_lock']
# ========== 流式输出函数 ==========
def btn_com(_n_keep, _n_discard,
_temperature, _repeat_penalty, _frequency_penalty,
_presence_penalty, _repeat_last_n, _top_k,
_top_p, _min_p, _typical_p,
_tfs_z, _mirostat_mode, _mirostat_eta,
_mirostat_tau, _role, _max_tokens):
# ========== 初始化输出模版 ==========
t_bot = chat_template(_role)
completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字
history = ''
# ========== 流式输出 ==========
for token in model.generate_t(
tokens=t_bot,
n_keep=_n_keep,
n_discard=_n_discard,
im_start=chat_template.im_start_token,
top_k=_top_k,
top_p=_top_p,
min_p=_min_p,
typical_p=_typical_p,
temp=_temperature,
repeat_penalty=_repeat_penalty,
repeat_last_n=_repeat_last_n,
frequency_penalty=_frequency_penalty,
presence_penalty=_presence_penalty,
tfs_z=_tfs_z,
mirostat_mode=_mirostat_mode,
mirostat_tau=_mirostat_tau,
mirostat_eta=_mirostat_eta,
):
if token in chat_template.eos or token == chat_template.nlnl:
t_bot.extend(completion_tokens)
print('token in eos', token)
break
completion_tokens.append(token)
all_text = model.str_detokenize(completion_tokens)
if not all_text:
continue
t_bot.extend(completion_tokens)
history += all_text
yield history
if token in chat_template.onenl:
# ========== 移除末尾的换行符 ==========
if t_bot[-2] in chat_template.onenl:
model.venv_pop_token()
break
if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
model.venv_pop_token()
break
if history[-2:] == '\n\n': # 各种 'x\n\n' 的token,比如'。\n\n'
print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
repr(model.str_detokenize(t_bot[-1:])))
break
if len(t_bot) > _max_tokens:
break
completion_tokens = []
# ========== 查看末尾的换行符 ==========
print('history', repr(history))
# ========== 给 kv_cache 加上输出结束符 ==========
model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
t_bot.extend(chat_template.im_end_nl)
cfg['btn_com'] = btn_com
def btn_start_or_finish(finish):
tmp = gr.update(interactive=finish)
def _inner():
with lock:
if cfg['session_active'] != finish:
raise RuntimeError
cfg['session_active'] = not cfg['session_active']
return tmp, tmp, tmp, tmp
return _inner
btn_start_or_finish_outputs = [cfg['btn_submit'], cfg['btn_vo'],
cfg['btn_suggest'], cfg['btn_retry']]
cfg['btn_start'] = {
'fn': btn_start_or_finish(False),
'outputs': btn_start_or_finish_outputs
}
cfg['btn_finish'] = {
'fn': btn_start_or_finish(True),
'outputs': btn_start_or_finish_outputs
}
cfg['setting'] = [cfg[x] for x in ('setting_n_keep', 'setting_n_discard',
'setting_temperature', 'setting_repeat_penalty', 'setting_frequency_penalty',
'setting_presence_penalty', 'setting_repeat_last_n', 'setting_top_k',
'setting_top_p', 'setting_min_p', 'setting_typical_p',
'setting_tfs_z', 'setting_mirostat_mode', 'setting_mirostat_eta',
'setting_mirostat_tau', 'role_usr', 'role_char',
'rag', 'setting_max_tokens')]