# app.py
import gradio as gr
import spaces
from threading import Thread
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
# ------------------------------
# 1. Load the model and tokenizer
# ------------------------------
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# If the tokenizer defines no pad_token_id, fall back to eos_token_id explicitly
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# ------------------------------
# 2. Chat history -> prompt string
# ------------------------------
def preprocess_messages(history):
    """
    Concatenate the chat history into a minimal prompt.
    You can substitute a prompt format or special tokens better suited to this model.
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # Cue the model to continue the conversation as "Assistant:"
    prompt += "Assistant: "
    return prompt
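
# Alternative sketch: many instruction-tuned checkpoints ship a chat template
# in their tokenizer config, in which case apply_chat_template usually formats
# the prompt more faithfully than manual concatenation. Assumption: this
# particular checkpoint's tokenizer actually defines a chat template; verify
# that before swapping this in for preprocess_messages above.
def preprocess_messages_with_template(history):
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # add_generation_prompt=True appends the assistant-turn marker so the
    # model continues as the assistant
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )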
# ------------------------------
# 3. Prediction / inference function
# ------------------------------
@spaces.GPU  # let Hugging Face Spaces allocate a GPU for this call
def predict(history, max_length, top_p, temperature):
""" | |
基于当前的 history 做文本生成。 | |
使用 HF 提供的 TextIteratorStreamer 实现流式生成。 | |
""" | |
prompt = preprocess_messages(history) | |
inputs = tokenizer( | |
prompt, | |
return_tensors="pt", | |
padding=True, # 自动 padding | |
truncation=True, # 超长截断 | |
max_length=2048 # 你可根据显存大小或模型上限做调整 | |
) | |
input_ids = inputs["input_ids"].to(model.device) | |
attention_mask = inputs["attention_mask"].to(model.device) | |
    # Streaming output helper
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,  # number of newly generated tokens
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
    }
    # Run generate in a background thread; the main thread reads new tokens as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Append each newly generated token to history[-1][1]
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
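
# Minimal sanity check outside of Gradio (hypothetical example input): predict
# is a generator, so the streamed histories can be consumed directly.
#
#     for h in predict([["What is 2 + 2?", ""]], 64, 0.8, 0.7):
#         print(h[-1][1])
#
# Each iteration prints the assistant reply accumulated so far.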
# ------------------------------
# 4. Gradio UI
# ------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B Chat Demo</h1>")
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="Type your question here...",
                    label="User Input",
                )
                submitBtn = gr.Button("Submit")
                clearBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=1024,  # raise or lower as needed
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True,
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True,
                )
        # On Submit: first append the new query to the history, then call predict
        def user(query, history):
            return "", history + [[query, ""]]
        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False,  # skip the queue for this quick state update
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot,
        )
        # Clear the chat history and the input box
        def clear_history():
            return [], ""

        clearBtn.click(fn=clear_history, inputs=[], outputs=[chatbot, user_input], queue=False)

    # Optional: enable the queue to serialize concurrent requests.
    # Note: concurrency_count is Gradio 3.x syntax; Gradio 4+ uses
    # demo.queue(default_concurrency_limit=1) instead.
    demo.queue(concurrency_count=1)
    demo.launch()
# ------------------------------
# Entry point
# ------------------------------
if __name__ == "__main__":
    main()