File size: 7,577 Bytes
b9cb0bd
 
 
6481b74
 
 
b9cb0bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import numpy as np


def init(cfg):
    """Build the status-bar generator and attach it to the status-bar button.

    Every entry of ``cfg['btn_status_bar_list']`` is updated in place:

    * ``key`` / ``desc`` are passed through ``cfg['text_format']`` and
      tokenized into ``key_t`` / ``desc_t`` (``desc_t`` only when ``desc`` is
      non-empty);
    * each element of ``combine`` gets ``prefix_t`` (when ``prefix`` is
      non-empty) and, for the types ``'int'``, ``'set'`` and ``'str'``, an
      ``fn(eval_t, sample_t)`` callable that generates the value text;
    * every ``key_t`` except the first is prefixed with the last token of
      ``chat_template.im_end_nl`` as a separator.

    Finally ``cfg['btn_status_bar_fn']`` is set and chained onto
    ``cfg['btn_status_bar']`` via ``click(...).success(...).success(...)``.

    An empty ``cfg['btn_status_bar_list']`` is accepted.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    s_info = cfg['s_info']
    lock = cfg['session_lock']
    text_format = cfg['text_format']

    # ========== tokenization helper ==========
    def str_tokenize(s):
        """Tokenize ``s`` as if it followed a newline, without that newline token."""
        tokens = model.tokenize((chat_template.nl + s).encode('utf-8'), add_bos=False, special=False)
        if tokens[0] in chat_template.onenl:
            return tokens[1:]
        return tokens

    # ========== pre-process key / desc ==========
    for item in cfg['btn_status_bar_list']:
        item['key'] = text_format(item['key'],
                                  char=cfg['role_char'].value,
                                  user=cfg['role_usr'].value)
        item['key_t'] = str_tokenize(item['key'])
        item['desc'] = text_format(item['desc'],
                                   char=cfg['role_char'].value,
                                   user=cfg['role_usr'].value)
        if item['desc']:
            item['desc_t'] = str_tokenize(item['desc'])

    # ========== factory: mask ==========
    def btn_status_bar_fn_mask():
        """Return a float32 vector of -inf sized like the last axis of ``model.scores``."""
        vocab_size = model.scores.shape[-1]
        return np.full((vocab_size,), -np.inf, dtype=np.single)

    # ========== factory: integer value ==========
    def btn_status_bar_fn_int(unit: str):
        """Return a generator function for a run of digits followed by ``unit``."""
        digit_tokens = str_tokenize('0123456789')
        assert len(digit_tokens) == 10, 'expected exactly one token per decimal digit'
        fn_int_mask = btn_status_bar_fn_mask()
        fn_int_mask[chat_template.eos] = 0
        fn_int_mask[digit_tokens] = 0
        if unit:
            unit_t = str_tokenize(unit)
            # Only the first token of the unit is sampled; it acts as a stop marker.
            fn_int_mask[unit_t[0]] = 0

        def logits_processor(_input_ids, logits):
            return logits + fn_int_mask

        def inner(eval_t, sample_t):
            retn = []
            while True:
                token = sample_t(logits_processor)
                # ========== anything that is not a digit ends the number ==========
                if token in chat_template.eos:
                    break
                if unit and token == unit_t[0]:
                    break
                # ========== a digit: keep it and continue ==========
                retn.append(token)
                eval_t([token])

            if unit:
                eval_t(unit_t)  # append the full unit
                retn.extend(unit_t)

            return model.str_detokenize(retn)

        return inner

    # ========== factory: one value out of a set ==========
    def btn_status_bar_fn_set(value):
        """Return a generator function that picks exactly one string of ``value``.

        Candidates are told apart token by token. When several candidates
        share a token prefix, sampling continues at the next position until a
        single candidate is left, so every value stays reachable. A candidate
        that is a strict prefix of another one is selected when an EOS token
        is sampled at the position where it ends.
        """
        options = [(str_tokenize(text), text) for text in value]

        def split_at(candidates, depth):
            # Group candidates by their token at ``depth``; ``finished`` is a
            # candidate whose token sequence ends exactly at ``depth``.
            branches = {}
            finished = None
            for candidate in candidates:
                if len(candidate[0]) > depth:
                    branches.setdefault(candidate[0][depth], []).append(candidate)
                else:
                    finished = candidate
            return branches, finished

        def make_processor(branches, finished):
            mask = btn_status_bar_fn_mask()
            mask[list(branches)] = 0
            if finished is not None:
                mask[chat_template.eos] = 0

            def logits_processor(_input_ids, logits):
                return logits + mask

            return logits_processor

        # The first-position mask is built once, up front.
        first_branches, first_finished = split_at(options, 0)
        first_processor = make_processor(first_branches, first_finished)

        def inner(eval_t, sample_t):
            branches, finished, processor = first_branches, first_finished, first_processor
            depth = 0
            while True:
                token = sample_t(processor)
                if token not in branches:
                    if finished is not None and token in chat_template.eos:
                        return finished[1]
                    raise KeyError(token)
                candidates = branches[token]
                if len(candidates) == 1:
                    tokens, text = candidates[0]
                    eval_t(tokens[depth:])  # feed the not-yet-evaluated remainder
                    return text
                # Still ambiguous: commit this token and look at the next position.
                eval_t([token])
                depth += 1
                branches, finished = split_at(candidates, depth)
                if not branches:
                    # Only identical token sequences are left.
                    return finished[1]
                processor = make_processor(branches, finished)

        return inner

    # ========== factory: free-form string ==========
    def btn_status_bar_fn_str():
        """Return a generator function that samples until EOS or a line break."""
        def inner(eval_t, sample_t):
            retn = []
            tmp = ''
            while True:
                token = sample_t(None)
                if token in chat_template.eos:
                    break
                retn.append(token)
                tmp = model.str_detokenize(retn)
                if tmp.endswith('\n') or tmp.endswith('\r'):
                    break  # the token that closes the line is not evaluated
                # ========== continue ==========
                eval_t([token])
            return tmp.strip()

        return inner

    # ========== pre-process values ==========
    for item in cfg['btn_status_bar_list']:
        for part in item['combine']:
            if part['prefix']:
                part['prefix_t'] = str_tokenize(part['prefix'])

            if part['type'] == 'int':
                part['fn'] = btn_status_bar_fn_int(part['unit'])
            elif part['type'] == 'set':
                part['fn'] = btn_status_bar_fn_set(part['value'])
            elif part['type'] == 'str':
                part['fn'] = btn_status_bar_fn_str()
            # Any other type is left untouched; generation later reads
            # part['fn'], so such an entry has to provide 'fn' itself.

    # ========== add separator tokens ==========
    for i, item in enumerate(cfg['btn_status_bar_list']):
        if i == 0:  # the first entry needs no separator
            continue
        item['key_t'] = chat_template.im_end_nl[-1:] + item['key_t']

    # NOTE: the loops above use their own variable names, so no ``del`` is
    # needed afterwards (deleting an unbound loop variable would raise when
    # the list, or every ``combine`` list, is empty).

    # ========== generate the status bar ==========
    def btn_status_bar(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _char,
                       _rag, _max_tokens):
        """Generate the status table, yielding ``(rows, model.venv_info)`` as it grows.

        ``rows`` is a list of ``[key, value_text]`` pairs. Raises
        ``RuntimeError`` when ``cfg['session_active']`` is falsy. When
        ``cfg['btn_stop_status']`` is set, a single empty result is yielded.
        """
        with lock:
            if not cfg['session_active']:
                raise RuntimeError('session is not active')
            if cfg['btn_stop_status']:
                yield [], model.venv_info
                return

            # ========== eval / sample bound to this call's settings ==========
            def eval_t(tokens):
                return model.eval_t(
                    tokens=tokens,
                    n_keep=_n_keep,
                    n_discard=_n_discard,
                    im_start=chat_template.im_start_token
                )

            def sample_t(logits_processor):
                return model.sample_t(
                    top_k=_top_k,
                    top_p=_top_p,
                    min_p=_min_p,
                    typical_p=_typical_p,
                    temp=_temperature,
                    repeat_penalty=_repeat_penalty,
                    repeat_last_n=_repeat_last_n,
                    frequency_penalty=_frequency_penalty,
                    presence_penalty=_presence_penalty,
                    tfs_z=_tfs_z,
                    mirostat_mode=_mirostat_mode,
                    mirostat_tau=_mirostat_tau,
                    mirostat_eta=_mirostat_eta,
                    logits_processor=logits_processor
                )

            # ========== start the output template ==========
            model.venv_create('status')  # create an isolated environment
            eval_t(chat_template('状态'))  # opening marker
            # ========== streaming output ==========
            df = []
            for _x in cfg['btn_status_bar_list']:
                # ========== attribute name ==========
                df.append([_x['key'], ''])
                eval_t(_x['key_t'])
                if _x['desc']:
                    eval_t(_x['desc_t'])
                yield df, model.venv_info
                # ========== value ==========
                for _y in _x['combine']:
                    if _y['prefix']:
                        if df[-1][-1]:
                            df[-1][-1] += _y['prefix']
                        else:
                            # A leading ':' is not shown at the start of the cell.
                            df[-1][-1] += _y['prefix'].lstrip(':')
                        eval_t(_y['prefix_t'])
                    df[-1][-1] += _y['fn'](eval_t, sample_t)
                    yield df, model.venv_info
            eval_t(chat_template.im_end_nl)  # closing marker
            # ========== drop previously generated status bars ==========
            model.venv_remove('status', keep_last=1)
            yield df, model.venv_info

    cfg['btn_status_bar_fn'] = {
        'fn': btn_status_bar,
        'inputs': cfg['setting'],
        'outputs': [cfg['status_bar'], s_info]
    }
    cfg['btn_status_bar_fn'].update(cfg['btn_concurrency'])

    cfg['btn_status_bar'].click(
        **cfg['btn_start']
    ).success(
        **cfg['btn_status_bar_fn']
    ).success(
        **cfg['btn_finish']
    )