# NOTE: "Spaces: Sleeping" -- page-header residue from the web capture; not part of the module.
import numpy as np
def init(cfg):
    """Prepare the status-bar template and wire the status-bar button.

    Reads from ``cfg``: ``chat_template``, ``model``, ``s_info``,
    ``session_lock``, ``text_format``, ``btn_status_bar_list``, ``role_char``,
    ``role_usr``, ``setting``, ``status_bar``, ``btn_concurrency``,
    ``btn_status_bar``, ``btn_start`` and ``btn_finish`` (plus
    ``session_active`` and ``btn_stop_status`` each time the handler runs).

    Side effects:
      * every entry of ``cfg['btn_status_bar_list']`` is updated in place with
        formatted ``key`` / ``desc`` text, their token lists (``key_t``,
        ``desc_t``), and, for each ``combine`` part, ``prefix_t`` and a
        generator callable stored under ``fn``;
      * ``cfg['btn_status_bar_fn']`` is created;
      * ``cfg['btn_status_bar'].click(...)`` is chained with two
        ``.success(...)`` steps.

    An empty ``btn_status_bar_list`` (or items without ``combine`` parts) is
    accepted; earlier revisions raised ``NameError`` in that case because they
    ``del``-ed loop variables that had never been bound.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']
    status_items = cfg['btn_status_bar_list']

    # ========== Pre-process key / desc ==========
    def str_tokenize(s):
        """Tokenize ``s`` behind a newline prefix and drop that prefix token.

        The text is tokenized as ``chat_template.nl + s``; when the first
        resulting token is one of ``chat_template.onenl`` it is removed.
        """
        tokens = model.tokenize(
            (chat_template.nl + s).encode('utf-8'),
            add_bos=False,
            special=False,
        )
        if tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    def render(text):
        """Apply ``text_format`` with the current char / user role values."""
        return text_format(
            text,
            char=cfg['role_char'].value,
            user=cfg['role_usr'].value,
        )

    for item in status_items:
        item['key'] = render(item['key'])
        item['key_t'] = str_tokenize(item['key'])
        item['desc'] = render(item['desc'])
        if item['desc']:  # 'desc_t' only exists for a non-empty description
            item['desc_t'] = str_tokenize(item['desc'])

    # ========== Builder: logits mask ==========
    def btn_status_bar_fn_mask():
        """Return a float32 vector of ``-inf`` sized like the last axis of ``model.scores``."""
        width = model.scores.shape[-1]
        return np.full((width,), -np.inf, dtype=np.single)

    # ========== Builder: integer value ==========
    def btn_status_bar_fn_int(unit: str):
        """Build a generator that samples digit tokens and appends ``unit``.

        Sampling is restricted (via an additive mask) to the ten digit tokens,
        the ``chat_template.eos`` tokens and, when ``unit`` is non-empty, the
        first token of ``unit``.
        """
        digit_tokens = str_tokenize('0123456789')
        # Sanity check on the tokenizer: one token per decimal digit.
        assert len(digit_tokens) == 10, 'expected exactly one token per decimal digit'
        fn_int_mask = btn_status_bar_fn_mask()
        fn_int_mask[chat_template.eos] = 0
        fn_int_mask[digit_tokens] = 0
        unit_t = str_tokenize(unit) if unit else []
        if unit:
            fn_int_mask[unit_t[0]] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_int_mask

        def inner(eval_t, sample_t):
            retn = []
            # NOTE(review): there is no length cap here; the loop ends only on
            # an eos token or on the first token of the unit.
            while True:
                token = sample_t(logits_processor)
                # ========== Not a digit: stop ==========
                if token in chat_template.eos:
                    break
                if unit and token == unit_t[0]:
                    break
                # ========== A digit: keep going ==========
                retn.append(token)
                eval_t([token])
            if unit:
                eval_t(unit_t)  # append the unit
                retn.extend(unit_t)
            return model.str_detokenize(retn)

        return inner

    # ========== Builder: one value out of a set ==========
    def btn_status_bar_fn_set(value):
        """Build a generator that picks one entry of ``value``.

        The choice is made by sampling a single token restricted to the first
        token of each candidate; the chosen candidate's full token list is
        then evaluated and its text returned.
        """
        # Keyed by first token -> (tokens, text).
        # NOTE(review): candidates that share a first token overwrite each
        # other here, so only the last of them can ever be selected.
        value_t = {}
        for text in value:
            tokens = str_tokenize(text)
            value_t[tokens[0]] = (tokens, text)
        fn_set_mask = btn_status_bar_fn_mask()
        fn_set_mask[list(value_t)] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_set_mask

        def inner(eval_t, sample_t):
            token = sample_t(logits_processor)
            tokens, text = value_t[token]
            eval_t(tokens)
            return text

        return inner

    # ========== Builder: free-form string ==========
    def btn_status_bar_fn_str():
        """Build a generator that samples unconstrained text up to a line break or eos."""
        def inner(eval_t, sample_t):
            retn = []
            tmp = ''
            # NOTE(review): there is no length cap here; the loop ends only on
            # an eos token or when the decoded text ends with a line break.
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                retn.append(token)
                # Decode the whole token list each time and test the result.
                tmp = model.str_detokenize(retn)
                if tmp.endswith(('\n', '\r')):
                    break
                # ========== Continue ==========
                eval_t([token])
            return tmp.strip()

        return inner

    # ========== Pre-process values ==========
    for item in status_items:
        for part in item['combine']:
            if part['prefix']:
                part['prefix_t'] = str_tokenize(part['prefix'])
            kind = part['type']
            if kind == 'int':
                part['fn'] = btn_status_bar_fn_int(part['unit'])
            elif kind == 'set':
                part['fn'] = btn_status_bar_fn_set(part['value'])
            elif kind == 'str':
                part['fn'] = btn_status_bar_fn_str()
            # Any other type is left untouched (no 'fn' is assigned here).

    # ========== Add separator tokens ==========
    # Every item but the first gets the last token of im_end_nl in front of its key.
    for index, item in enumerate(status_items):
        if index == 0:  # skip the first one
            continue
        item['key_t'] = chat_template.im_end_nl[-1:] + item['key_t']

    # The loop variables above are deliberately not `del`-ed: they are unbound
    # when the list (or every `combine`) is empty, and the closures below use
    # their own `_x` / `_y` names.

    # ========== Emit the status bar ==========
    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status bar, yielding ``(rows, model.venv_info)`` as it grows.

        ``rows`` is a list of ``[key, value]`` pairs. ``_usr``, ``_char``,
        ``_rag`` and ``_max_tokens`` are accepted but not used here.

        Raises:
            RuntimeError: if ``cfg['session_active']`` is falsy.
        """
        # The whole generation runs while holding the session lock.
        with lock:
            if not cfg['session_active']:
                raise RuntimeError('session is not active')
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            # ========== Temporary eval / sample wrappers ==========
            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token,
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor,
                )

            # ========== Initialise the output template ==========
            model.venv_create('status')  # create an isolated environment
            eval_t(chat_template('状态'))  # start marker

            # ========== Streamed output ==========
            df = []  # start empty
            for _x in cfg['btn_status_bar_list']:
                # ========== Attribute ==========
                df.append([_x['key'], ''])
                eval_t(_x['key_t'])
                if _x['desc']:
                    eval_t(_x['desc_t'])
                yield df, model.venv_info
                # ========== Value ==========
                for _y in _x['combine']:
                    if _y['prefix']:
                        if df[-1][-1]:
                            df[-1][-1] += _y['prefix']
                        else:
                            # First part of the value: drop leading ':' from the shown text.
                            df[-1][-1] += _y['prefix'].lstrip(':')
                        eval_t(_y['prefix_t'])
                    df[-1][-1] += _y['fn'](eval_t, sample_t)
                    yield df, model.venv_info
            eval_t(chat_template.im_end_nl)  # end marker

            # ========== Clean up the previously generated status bar ==========
            model.venv_remove('status', keep_last=1)
            yield df, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info],
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])

    # Chain: start step -> status-bar generation -> finish step.
    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )