File size: 13,718 Bytes
a95340f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd33231
a95340f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350fb3a
a95340f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e66eacf
 
a95340f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# pylint: disable=line-too-long, broad-exception-caught, invalid-name, missing-function-docstring, too-many-instance-attributes, missing-class-docstring
# ruff: noqa: E501
import os # 导入os模块
import platform # 导入platform模块
import random # 导入random模块
import time # 导入time模块
from dataclasses import asdict, dataclass # 从dataclasses模块中导入asdict和dataclass
from pathlib import Path # 从pathlib模块中导入Path类

# from types import SimpleNamespace # 从types模块中导入SimpleNamespace类,但未使用
import gradio as gr #导入gradio模块并起别名gr
import psutil #导入psutil模块  
import getpass #导入 getpass模块
from about_time import about_time # 从about_time模块中导入about_time函数
from ctransformers import AutoModelForCausalLM # 从ctransformers模块中导入AutoModelForCausalLM类
from dl_hf_model import dl_hf_model # 从dl_hf_model模块中导入dl_hf_model函数
from loguru import logger # 从loguru模块中导入logger




# Quantized variants of the Wizard-Vicuna-7B-Uncensored GGML (v3) checkpoint.
# Every entry follows the "<model>.ggmlv3.<quant>.bin" naming scheme, so the
# list is generated from the quantization suffixes.
filename_list = [
    f"Wizard-Vicuna-7B-Uncensored.ggmlv3.{quant}.bin"
    for quant in (
        "q2_K",
        "q3_K_L",
        "q3_K_M",
        "q3_K_S",
        "q4_0",
        "q4_1",
        "q4_K_M",
        "q4_K_S",
        "q5_0",
        "q5_1",
        "q5_K_M",
        "q5_K_S",
        "q6_K",
        "q8_0",
    )
]

# Reference Wizard-Vicuna 7B checkpoint (4.05G). It is only a constant here;
# the model that is actually downloaded is selected through ``url`` below.
URL = "https://huggingface.co/TheBloke/Wizard-Vicuna-7B-Uncensored-GGML/raw/main/Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_K_M.bin"  # 4.05G

# Alternative checkpoints, kept for quick switching (uncomment one):
# url = "https://huggingface.co/savvamadar/ggml-gpt4all-j-v1.3-groovy/blob/main/ggml-gpt4all-j-v1.3-groovy.bin"
# url = "https://huggingface.co/TheBloke/Llama-2-13B-GGML/blob/main/llama-2-13b.ggmlv3.q4_K_S.bin"  # 7.37G
# url = "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/blob/main/llama-2-13b-chat.ggmlv3.q3_K_L.bin"  # 6.93G
# url = "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/blob/main/llama-2-13b-chat.ggmlv3.q4_K_M.bin"  # 7.87G
# url = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q2_K.bin"  # 2.87G
# url = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q4_K_M.bin"

# Default choice: Llama-2 13B chat, q4_K_S quantization (7.37G).
url = "https://huggingface.co/localmodels/Llama-2-13B-Chat-ggml/blob/main/llama-2-13b-chat.ggmlv3.q4_K_S.bin"  # 7.37G

# Hosts on which the smaller q2_K quantization is used instead of the default.
_use_small_quant = (
    "golay" in platform.node()
    or "okteto" in platform.node()
    or Path("/kaggle").exists()
    # or psutil.cpu_count(logical=False) < 4
    # NOTE(review): the trailing ``or 1`` makes this condition always truthy,
    # so the q2_K model below is always selected. The original note said
    # "run 7b in hf", but the URL is the 13B chat model.
    or 1
)

if _use_small_quant:
    url = "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/blob/main/llama-2-13b-chat.ggmlv3.q2_K.bin"


# Llama-2 chat prompt: a system message inside <<SYS>> tags followed by the
# user's question, all wrapped in a single [INST] ... [/INST] turn.
# ``{question}`` is filled in with str.format by the caller.
prompt_template = (
    "[INST] <<SYS>>\n"
    "You are a cute kitten and I am your master.\n"
    "<</SYS>>\n"
    "\n"
    "{question} [/INST]\n"
)


# Number of physical CPU cores minus one, but never less than 1.
# psutil.cpu_count(logical=False) returns None when the physical core count
# cannot be determined; the previous ``None - 1`` crashed at import time.
_physical_cores = psutil.cpu_count(logical=False)
cpu_count: int = max((_physical_cores or 1) - 1, 1)
logger.debug(f"{cpu_count=}")  # log the chosen core count

LLM = None  # replaced below once the model weights are available locally

# Download the GGML weights referenced by ``url``, then load them with
# ctransformers as a llama-type model. dl_hf_model returns the local path and
# a file size; the size is logged with a "G" suffix, presumably gigabytes.
try:
    model_loc, file_size = dl_hf_model(url)
except Exception as exc_:
    # Without weights the app cannot serve anything: log and exit non-zero.
    logger.error(exc_)
    raise SystemExit(1) from exc_
else:
    LLM = AutoModelForCausalLM.from_pretrained(
        model_loc,
        model_type="llama",
        # threads=cpu_count,
    )
    logger.info(f"done load llm {model_loc=} {file_size=}G")

# Use the Asia/Shanghai time zone for local-time calls in this process.
os.environ["TZ"] = "Asia/Shanghai"
try:
    # time.tzset() re-reads TZ; it exists only on Unix.
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    # Non-Unix (e.g. Windows): time.tzset is missing, so TZ is not re-applied.
    # Only this specific failure is handled; anything else should surface.
    logger.warning("Windows, cant run time.tzset()")

# Disabled code: a SimpleNamespace holding shared streaming state. It is kept
# as an inert string literal, so it is never executed (the SimpleNamespace
# import is commented out at the top of the file).
_ = """
ns = SimpleNamespace(
    response="",
    generator=(_ for _ in []),
)
# """

@dataclass
class GenerationConfig:
    """Sampling options for one model call.

    ``generate`` expands an instance with ``asdict`` into keyword arguments
    of the model call, so field names must match that call's parameters.
    Field order is significant for positional construction and for the
    order of ``asdict`` output.
    """

    # -- sampling --
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    repetition_penalty: float = 1.0
    # -- length and reproducibility --
    max_new_tokens: int = 512
    seed: int = 42
    # -- call behaviour --
    reset: bool = False
    stream: bool = True
    # Not forwarded yet, kept for reference:
    # threads: int = cpu_count
    # stop: list[str] = field(default_factory=lambda: [stop_string])


def generate(
    question: str,
    llm=LLM,
    config: "GenerationConfig | None" = None,
):
    """Run model inference on ``question`` wrapped in ``prompt_template``.

    Args:
        question: Raw user text substituted for ``{question}`` in the template.
        llm: Callable model; defaults to the module-level ``LLM`` loaded at
            import time.
        config: Sampling options. When omitted, a fresh ``GenerationConfig()``
            is created per call, so no default instance is shared between calls.

    Returns:
        Whatever ``llm`` returns: a generator of text chunks when
        ``config.stream`` is true, otherwise the full completion.
    """
    if config is None:
        config = GenerationConfig()

    prompt = prompt_template.format(question=question)

    # Every config field becomes a keyword argument of the model call.
    return llm(prompt, **asdict(config))


logger.debug(f"{asdict(GenerationConfig())=}")  # log the default sampling options


def user(user_message, history):
    """Record ``user_message`` as a new, not-yet-answered chat turn.

    ``history`` is mutated in place and returned. The message itself is
    echoed back, so the input textbox keeps its text.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    return user_message, history


def user1(user_message, history):
    """Record ``user_message`` as a new pending turn and blank the input box.

    ``history`` is mutated in place and returned. An empty string is returned
    instead of the message, so the bound textbox is cleared.
    """
    history.extend([[user_message, None]])
    return "", history


def bot_(history):
    """Demo responder that needs no model.

    Types "<user message>: <random canned reply>" into the latest turn one
    character at a time (20 ms apart), yielding ``history`` after every
    character. The final yield leaves only the canned reply.
    """
    canned_replies = ["How are you?", "I love you", "I'm very hungry"]
    question = history[-1][0]
    reply = random.choice(canned_replies)
    typed = question + ": " + reply

    history[-1][1] = ""
    for ch in typed:
        history[-1][1] += ch
        time.sleep(0.02)
        yield history

    history[-1][1] = reply
    yield history


def bot(history):
    """Stream the model's answer to the latest user turn into ``history``.

    ``history[-1][1]`` is updated and ``history`` is yielded after each
    generated chunk. While streaming, the reply is prefixed with the
    time-to-first-chunk. The final yield replaces the reply with the full
    text plus an elapsed-time and seconds-per-character footer.
    """
    user_message = history[-1][0]
    response = []

    logger.debug(f"{user_message=}")

    with about_time() as atime:  # type: ignore # measures total generation time
        first_chunk = True
        prefix = ""
        then = time.time()

        logger.debug("about to generate")

        # reset=True: per predict_api's note, this resets the model's history/cache.
        config = GenerationConfig(reset=True)
        for elm in generate(user_message, config=config):
            if first_chunk:
                logger.debug("in the loop")
                prefix = f"({time.time() - then:.2f}s) "
                first_chunk = False
                print(prefix, end="", flush=True)
                logger.debug(f"{prefix=}")
            print(elm, end="", flush=True)

            response.append(elm)
            history[-1][1] = prefix + "".join(response)
            yield history

    text = "".join(response)
    # max(..., 1) avoids ZeroDivisionError when the model produced no text.
    per_char = atime.duration / max(len(text), 1)  # type: ignore
    stats = f"(time elapsed: {atime.duration_human}, {per_char:.2f}s/char)"  # type: ignore

    history[-1][1] = text + f"\n{stats}"
    yield history


def predict_api(prompt):
    """Return a non-streaming completion for ``prompt``.

    The UI registers this under ``api_name="api"``. Generation errors are not
    propagated: they are logged with their traceback, and the text
    ``f"{exc=}"`` is returned in place of a completion.
    """
    logger.debug(f"{prompt=}")
    config = GenerationConfig(
        temperature=0.7,
        top_k=10,
        top_p=0.9,
        repetition_penalty=1.0,
        max_new_tokens=512,  # adjust as needed
        seed=42,
        reset=True,  # reset history (cache)
        stream=False,  # return the whole completion at once
        # threads=cpu_count,
        # stop=prompt_prefix[1:2],
    )

    try:
        response = generate(prompt, config=config)
        logger.debug(f"api: {response=}")
    except Exception as exc:
        # logger.exception keeps the traceback that logger.error(exc) dropped.
        logger.exception(exc)
        response = f"{exc=}"

    return response


# Custom CSS for the Gradio page. Comments must stay OUTSIDE this string:
# its contents go to the browser verbatim. A stray "# ..." at the start used
# to turn the first selector invalid, which silently dropped the
# .importantButton rule.
css = """
    .importantButton {
        background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
        border: none !important;
    }
    .importantButton:hover {
        background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
        border: none !important;
    }
    .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
    .xsmall {font-size: x-small;}
"""

# Sample English passage (not referenced elsewhere in the visible code).
etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """

# Example prompts offered under the input box; each inner list is one example row.
examples_list = [
    ["Hi, what are you doing?"],
    ["Hello."],
]
logger.info("start block")

# Build the Gradio chat UI around the loaded model.
with gr.Blocks(
    title=f"{Path(model_loc).name}",
    theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
    css=css,
) as block:
    # Collapsible header with the model file name and a short usage note.
    # The heading tag is now closed with </h5>; the old </h4> left the
    # description text inside the heading.
    with gr.Accordion("🎈 Info", open=False):
        gr.Markdown(
            f"""<h5><center>{Path(model_loc).name}</center></h5>
            超级小猫使用LLaMA-2-13b-chat,调用16G的CPU运行,速度比较慢,请见谅。模型数据主要为英文,建议使用英文进行问答""",
            elem_classes="xsmall",
        )

    chatbot = gr.Chatbot(height=500)  # conversation display

    # Input area: a wide message box plus a narrow column of buttons.
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                lines=6,
                max_lines=30,
                show_copy_button=True,
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("发送", elem_classes="xsmall")  # "Send"
                stop = gr.Button("停止", visible=True)  # "Stop"
                clear = gr.Button("清除历史会话", visible=True)  # "Clear history"

    # Clickable example prompts that fill the message box.
    with gr.Accordion("Example Inputs", open=True):
        examples = gr.Examples(
            examples=examples_list,
            inputs=[msg],
            examples_per_page=40,
        )

    # Disclaimer panel (the duplicated "已尽" and a stray space are removed).
    with gr.Accordion("Disclaimer", open=False):
        gr.Markdown(
            "免责声明:超级小猫(POWERED BY LLAMA 2) 可能会产生与事实不符的输出,不应依赖它来产生"
            "事实准确的信息。超级小猫(POWERED BY LLAMA 2) 是在各种公共数据集上进行训练的;虽然已尽力"
            "清理预训练数据,但该模型仍有可能产生不良内容,"
            "有偏见或其他冒犯性的输出。",
            elem_classes=["disclaimer"],
        )

    # Enter in the textbox: record the turn (user keeps the typed text),
    # then stream the model reply with bot.
    msg_submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)
    # Send button: same flow, but user1 clears the textbox.
    submit_click_event = submit.click(
        fn=user1,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)
    # Stop button: cancel the bot steps started by either trigger above.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )
    # Clear button: set the chatbot value to None (empties the conversation).
    clear.click(lambda: None, None, chatbot, queue=False)

    # Hidden (visible=False) panel; its button registers predict_api as the
    # "api" endpoint.
    with gr.Accordion("For Chat/Translation API", open=False, visible=False):
        input_text = gr.Text()
        api_btn = gr.Button("Go", variant="primary")
        out_text = gr.Text()

    api_btn.click(
        predict_api,
        input_text,
        out_text,
        api_name="api",
    )

    # block.load(update_buff, [], buff, every=1)
    # block.load(update_buff, [buff_var], [buff_var, buff], every=1)

# Queue sizing notes, apparently from earlier experiments on different
# hardware tiers (settings tried / machine size vs. model size):
# concurrency_count=5, max_size=20
# max_size=36, concurrency_count=14
# CPU cpu_count=2 16G, model 7G
# CPU UPGRADE cpu_count=8 32G, model 7G

# Disabled memory/core-based concurrency heuristic (the original note says
# "does not work"). It is kept as an inert string literal and never executed.
_ = """  
# _ = int(psutil.virtual_memory().total / 10**9 // file_size - 1)
# concurrency_count = max(_, 1)
if psutil.cpu_count(logical=False) >= 8:
    # concurrency_count = max(int(32 / file_size) - 1, 1) 
else:
    # concurrency_count = max(int(16 / file_size) - 1, 1)
# """

# Process one queued event at a time, presumably because a single CPU-bound
# model instance (LLM) is shared by all requests.
concurrency_count = 1
logger.info(f"{concurrency_count=}")

# Enable the request queue (at most 5 queued events) and start the server.
# NOTE(review): Blocks.queue(concurrency_count=...) exists in Gradio 3.x but
# was removed in Gradio 4 — confirm the pinned Gradio version.
block.queue(concurrency_count=concurrency_count, max_size=5).launch(debug=True)