Spaces:
Sleeping
Sleeping
File size: 13,717 Bytes
a95340f 2fb7367 a95340f 350fb3a a95340f e66eacf a95340f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 |
# pylint: disable=line-too-long, broad-exception-caught, invalid-name, missing-function-docstring, too-many-instance-attributes, missing-class-docstring
# ruff: noqa: E501
import os # 导入os模块
import platform # 导入platform模块
import random # 导入random模块
import time # 导入time模块
from dataclasses import asdict, dataclass # 从dataclasses模块中导入asdict和dataclass
from pathlib import Path # 从pathlib模块中导入Path类
# from types import SimpleNamespace # 从types模块中导入SimpleNamespace类,但未使用
import gradio as gr #导入gradio模块并起别名gr
import psutil #导入psutil模块
import getpass #导入 getpass模块
from about_time import about_time # 从about_time模块中导入about_time函数
from ctransformers import AutoModelForCausalLM # 从ctransformers模块中导入AutoModelForCausalLM类
from dl_hf_model import dl_hf_model # 从dl_hf_model模块中导入dl_hf_model函数
from loguru import logger # 从loguru模块中导入logger
# GGML quantisations of Wizard-Vicuna-7B-Uncensored (reference list; not read
# anywhere else in this file).
filename_list = [
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q2_K.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q3_K_L.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q3_K_M.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q3_K_S.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_0.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_1.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_K_M.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_K_S.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q5_0.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q5_1.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q5_K_M.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q5_K_S.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q6_K.bin",
    "Wizard-Vicuna-7B-Uncensored.ggmlv3.q8_0.bin",
]
# Legacy link to the q4_K_M file above; not used for downloading.
URL = "https://huggingface.co/TheBloke/Wizard-Vicuna-7B-Uncensored-GGML/raw/main/Wizard-Vicuna-7B-Uncensored.ggmlv3.q4_K_M.bin"  # 4.05G

# Checkpoints tried earlier. They used to be assigned to ``url`` one after
# another (each overwriting the previous one, one of them a malformed
# concatenation of two URLs), so only the last assignment ever mattered:
# url = "https://huggingface.co/savvamadar/ggml-gpt4all-j-v1.3-groovy/blob/main/ggml-gpt4all-j-v1.3-groovy.bin"
# url = "https://huggingface.co/TheBloke/Llama-2-13B-GGML/blob/main/llama-2-13b.ggmlv3.q4_K_S.bin"  # 7.37G
# url = "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/blob/main/llama-2-13b-chat.ggmlv3.q3_K_L.bin"  # 6.93G
# url = "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/blob/main/llama-2-13b-chat.ggmlv3.q4_K_M.bin"  # 7.87G
url = "https://huggingface.co/localmodels/Llama-2-13B-Chat-ggml/blob/main/llama-2-13b-chat.ggmlv3.q4_K_S.bin"  # 7.37G

# Flag for hosts where the smaller checkpoint should be used.
# NOTE: the trailing ``or 1`` makes this always truthy, so the q2_K
# checkpoint below is what actually gets downloaded.
_ = (
    "golay" in platform.node()
    or "okteto" in platform.node()
    or Path("/kaggle").exists()
    # or psutil.cpu_count(logical=False) < 4
    or 1  # force the smaller checkpoint unconditionally
)
if _:
    url = "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/blob/main/llama-2-13b-chat.ggmlv3.q2_K.bin"
    # Smaller 7B alternatives:
    # url = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q2_K.bin"  # 2.87G
    # url = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q4_K_M.bin"

# Llama-2-chat prompt: system prompt inside <<SYS>> ... <</SYS>>, the user's
# question inside [INST] ... [/INST]; ``{question}`` is filled by ``generate``.
prompt_template = """[INST] <<SYS>>
You are a cute kitten and I am your owner.
<</SYS>>
{question} [/INST]
"""
# Physical core count minus one (presumably to leave a core for the web
# server), never below 1. psutil.cpu_count(logical=False) returns None when the
# count cannot be determined; treating that as 1 avoids a TypeError at import.
# NOTE: currently only logged -- every ``threads=cpu_count`` use is commented out.
_ = (psutil.cpu_count(logical=False) or 1) - 1
cpu_count: int = max(int(_), 1)
logger.debug(f"{cpu_count=}")
LLM = None  # stays None until the checkpoint below has been loaded
try:
    # dl_hf_model yields the local path of the checkpoint and its size
    # (the size is logged with a "G" suffix below).
    model_loc, file_size = dl_hf_model(url)
except Exception as exc_:
    # Without weights there is nothing to serve: log the cause and exit(1).
    logger.error(exc_)
    raise SystemExit(1) from exc_
else:
    # Only reached when the download succeeded.
    LLM = AutoModelForCausalLM.from_pretrained(
        model_loc,
        model_type="llama",
        # threads=cpu_count,
    )
logger.info(f"done load llm {model_loc=} {file_size=}G")
# Switch the process time zone to Asia/Shanghai.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    # time.tzset() is Unix-only; on Windows the attribute does not exist and
    # the TZ variable is not re-read by the time module.
    logger.warning("Windows, cant run time.tzset()")
@dataclass  # sampling parameters for one model call
class GenerationConfig:
    """Keyword arguments for a single call of the loaded model.

    ``generate`` expands an instance with ``asdict`` and passes every field as a
    keyword argument to the model call, so field names must match the
    parameters that call accepts. Field order is also the positional
    constructor order.
    """
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    repetition_penalty: float = 1.0
    max_new_tokens: int = 512
    seed: int = 42
    reset: bool = False  # predict_api comments this as "reset history (cache)"; bot also sets it True
    stream: bool = True  # when True the model call returns a generator of text chunks (see generate's docstring)
    # threads: int = cpu_count
    # stop: list[str] = field(default_factory=lambda: [stop_string])
def generate(
    question: str,
    llm=LLM,
    config: "GenerationConfig | None" = None,
):
    """Run model inference, will return a Generator if streaming is true.

    Args:
        question: raw user text, substituted into ``prompt_template``.
        llm: callable model; defaults to the module-level ``LLM`` loaded at
            import time.
        config: generation parameters. When omitted (or ``None``) a fresh
            ``GenerationConfig()`` is built for this call, instead of sharing
            one mutable default instance across all calls.

    Returns:
        Whatever the model call returns for these keyword arguments: a
        generator of text chunks when ``config.stream`` is true.
    """
    if config is None:
        config = GenerationConfig()
    prompt = prompt_template.format(question=question)
    return llm(prompt, **asdict(config))


# Log the default generation settings once at import time.
logger.debug(f"{asdict(GenerationConfig())=}")
def user(user_message, history):
    """Record a new user turn, leaving the message in the textbox.

    Appends ``[user_message, None]`` to ``history`` in place; the ``None``
    slot is later filled by the bot. Returns the unchanged message and the
    same (mutated) history list.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    return user_message, history
def user1(user_message, history):
    """Record a new user turn and clear the textbox.

    Same as ``user`` except that the first returned value is ``""``, which
    empties the message box. ``history`` is extended in place and returned.
    """
    history.extend([[user_message, None]])
    cleared_box = ""
    return cleared_box, history
def bot_(history):
    """Fake streaming bot for exercising the UI without the model.

    Picks a canned reply at random and "types" ``"<user message>: <reply>"``
    into the last history entry one character per 0.02 s. It then replaces
    that text with the bare reply. The same ``history`` list is yielded after
    every update.
    """
    last_turn = history[-1]
    reply = random.choice(("How are you?", "I love you", "I'm very hungry"))
    full_text = last_turn[0] + ": " + reply
    last_turn[1] = ""
    for ch in full_text:
        last_turn[1] += ch
        time.sleep(0.02)
        yield history
    last_turn[1] = reply
    yield history
def bot(history):
    """Stream the model's answer to the latest user turn into ``history``.

    ``history`` is a list of ``[user_message, bot_message]`` pairs whose last
    entry was appended by ``user``/``user1``. While tokens arrive, the last
    bot slot holds ``"(<first-token latency>s) " + text so far`` and the
    history is yielded after each chunk. Finally the slot is replaced by the
    plain text plus a timing line, and yielded once more.
    """
    user_message = history[-1][0]
    response = []
    logger.debug(f"{user_message=}")
    with about_time() as atime:  # type: ignore # measures total generation time
        first_chunk = True
        prefix = ""
        then = time.time()
        logger.debug("about to generate")
        config = GenerationConfig(reset=True)
        for elm in generate(user_message, config=config):
            if first_chunk:
                # Latency until the first chunk, shown in front of the reply.
                logger.debug("in the loop")
                prefix = f"({time.time() - then:.2f}s) "
                first_chunk = False
                print(prefix, end="", flush=True)
                logger.debug(f"{prefix=}")
            print(elm, end="", flush=True)
            response.append(elm)
            history[-1][1] = prefix + "".join(response)
            yield history

    text = "".join(response)
    # An empty generation used to raise ZeroDivisionError here.
    per_char = atime.duration / len(text) if text else 0.0  # type: ignore
    stats = (
        f"(time elapsed: {atime.duration_human}, "  # type: ignore
        f"{per_char:.2f}s/char)"
    )
    history[-1][1] = text + f"\n{stats}"
    yield history
def predict_api(prompt):
    """Non-streaming generation endpoint (wired to the hidden "api" button).

    Calls ``generate`` with ``stream=False`` and a fresh cache (``reset=True``).
    If anything fails, the error is logged and ``f"{exc=}"`` (the exception's
    repr) is returned instead, so the endpoint always answers.
    """
    logger.debug(f"{prompt=}")
    try:
        config = GenerationConfig(
            temperature=0.7,
            top_k=10,
            top_p=0.9,
            repetition_penalty=1.0,
            max_new_tokens=512,
            seed=42,
            reset=True,
            stream=False,
        )
        response = generate(prompt, config=config)
        logger.debug(f"api: {response=}")
    except Exception as exc:
        logger.error(exc)
        response = f"{exc=}"
    return response
# Custom CSS for gr.Blocks. Comments must stay OUTSIDE the string: a "#"
# inside it becomes CSS text in front of the first selector, which makes that
# selector invalid and the browser drops the whole .importantButton rule.
css = """
.importantButton {
    background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
    border: none !important;
}
.importantButton:hover {
    background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
    border: none !important;
}
.disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
.xsmall {font-size: x-small;}
"""

# Sample English passage (not referenced elsewhere in this file).
etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """

# Rows for gr.Examples; each inner list feeds the message textbox.
examples_list = [
    ["Hi, what are you doing?"],
    ["Hello."],
]
logger.info("start block")
# Build the chat UI.
with gr.Blocks(
    title=f"{Path(model_loc).name}",
    theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
    css=css,
) as block:
    # Collapsible panel with the checkpoint name and a usage note.
    with gr.Accordion("🎈 Info", open=False):
        gr.Markdown(
            f"<h5><center>{Path(model_loc).name}</center></h5>\n"
            "超级小猫使用LLaMA-2-13b-chat,调用16G的CPU运行,速度比较慢,请见谅。模型数据主要为英文,建议使用英文进行问答",
            elem_classes="xsmall",
        )

    chatbot = gr.Chatbot(height=500)  # conversation display

    # Input area: message box on the left, action buttons on the right.
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                lines=6,
                max_lines=30,
                show_copy_button=True,
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("发送", elem_classes="xsmall")  # send
                stop = gr.Button("停止", visible=True)  # stop generation
                clear = gr.Button("清除历史会话", visible=True)  # clear history

    with gr.Accordion("Example Inputs", open=True):
        examples = gr.Examples(
            examples=examples_list,
            inputs=[msg],
            examples_per_page=40,
        )

    with gr.Accordion("Disclaimer", open=False):
        gr.Markdown(
            "免责声明:超级小猫(POWERED BY LLAMA 2) 可能会产生与事实不符的输出,"
            "不应依赖它来产生事实准确的信息。"
            "超级小猫(POWERED BY LLAMA 2) 是在各种公共数据集上进行训练的;"
            "虽然已尽力清理预训练数据,但该模型仍有可能产生不良内容,"
            "有偏见或其他冒犯性的输出",
            elem_classes=["disclaimer"],
        )

    # Pressing Enter: record the turn (message stays in the box), then stream.
    msg_submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)

    # Clicking send: record the turn and clear the box, then stream.
    submit_click_event = submit.click(
        fn=user1,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        show_progress="full",
    ).then(bot, chatbot, chatbot, queue=True)

    # The stop button cancels whichever generation chain is running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[msg_submit_event, submit_click_event],
        queue=False,
    )

    # Resetting the chatbot value to None empties the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

    # Hidden widgets that expose predict_api under api_name="api".
    with gr.Accordion("For Chat/Translation API", open=False, visible=False):
        input_text = gr.Text()
        api_btn = gr.Button("Go", variant="primary")
        out_text = gr.Text()

    api_btn.click(
        predict_api,
        input_text,
        out_text,
        api_name="api",
    )
# Queue configuration and server launch.
# Capacity notes from earlier runs on Space hardware tiers:
#   CPU:         cpu_count=2, 16G RAM, model ~7G
#   CPU UPGRADE: cpu_count=8, 32G RAM, model ~7G
# A RAM-based sizing heuristic was tried and, per the original note, did not
# work:
#   concurrency_count = max(int(psutil.virtual_memory().total / 10**9 // file_size - 1), 1)
# so requests are handled one at a time.
concurrency_count = 1
logger.info(f"{concurrency_count=}")
# max_size=5 caps the number of pending requests in the queue.
block.queue(concurrency_count=concurrency_count, max_size=5).launch(debug=True)
|