# app.py
import gradio as gr
import spaces
from threading import Thread
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
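# Assumed dependencies (e.g. in requirements.txt): gradio, spaces, transformers, torch,
# and accelerate (needed for device_map="auto" below).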
# ------------------------------
# 1. Load the model and tokenizer
# ------------------------------
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
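# Optional: to reduce GPU memory, the model could also be loaded in half precision, e.g.
# AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)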
# If the tokenizer has no pad_token_id, explicitly fall back to eos_token_id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# ------------------------------
# 2. Chat history -> prompt format
# ------------------------------
def preprocess_messages(history):
    """
    Concatenate the chat history into a minimal prompt.
    You can customize the prompt format or special tokens to better suit this model
    (see the chat-template sketch below).
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # Cue the model to continue as the assistant.
    prompt += "Assistant: "
    return prompt
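# A minimal alternative sketch (not wired into the UI): if the model's tokenizer ships a
# chat template, tokenizer.apply_chat_template can build the prompt instead of the manual
# "User:/Assistant:" format above. Whether this better matches DeepScaleR's training format
# is an assumption worth verifying.
def preprocess_messages_with_template(history):
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # add_generation_prompt appends the template's assistant cue so generation continues from it.
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)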
# ------------------------------
# 3. Prediction / inference function
# ------------------------------
@spaces.GPU()  # Ask Hugging Face Spaces to run this function on a GPU
def predict(history, max_length, top_p, temperature):
"""
基于当前的 history 做文本生成。
使用 HF 提供的 TextIteratorStreamer 实现流式生成。
"""
prompt = preprocess_messages(history)
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True, # 自动 padding
truncation=True, # 超长截断
max_length=2048 # 你可根据显存大小或模型上限做调整
)
input_ids = inputs["input_ids"].to(model.device)
attention_mask = inputs["attention_mask"].to(model.device)
# 流式输出器
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=60,
skip_prompt=True,
skip_special_tokens=True
)
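    # Note: timeout=60 is the longest the streamer waits for the next token before raising,
    # so a stalled generation does not hang the UI indefinitely.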
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,  # number of newly generated tokens
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
    }
    # Run generate in a background thread while the main thread reads new tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Append each newly generated token to history[-1][1].
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
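    # Make sure the background generation thread has finished before returning.
    t.join()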
# ------------------------------
# 4. Gradio UI
# ------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B Chat Demo</h1>")
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="Please enter your question...",
                    label="User Input"
                )
                submitBtn = gr.Button("Submit")
                clearBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=1024,  # adjust up or down as needed
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True
                )
        # When the user clicks Submit, first append the input to the history,
        # then call predict to generate a response.
        def user(query, history):
            return "", history + [[query, ""]]
        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False  # do not queue this step
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot
        )
        # Clear the chat history and the input box.
        def clear_history():
            return [], ""
        clearBtn.click(fn=clear_history, inputs=[], outputs=[chatbot, user_input], queue=False)
    # Optional: enable the queue to avoid concurrent requests clashing.
    # (On Gradio 3.x this kwarg is concurrency_count; on 4.x it is default_concurrency_limit.)
    demo.queue(concurrency_count=1)
    demo.launch()
# ------------------------------
# Entry point
# ------------------------------
if __name__ == "__main__":
    main()
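# To try this locally (assuming the dependencies listed above are installed):
#   python app.py
# then open the local URL that Gradio prints in a browser.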