import threading

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-20M-checkpoint")
model = AutoModelForCausalLM.from_pretrained("SmallDoge/Doge-20M-checkpoint", trust_remote_code=True)
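import torch

# Optional device placement (an assumption, not part of the original app): the 20M
# model is small, but generation is still faster on a GPU when one is available.
# Prompt tensors are moved to the same device inside generate_response() below.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")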
# Generation configuration
generation_config = GenerationConfig(
    max_new_tokens=100,
    use_cache=True,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.0,
)
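# A deterministic variant (an assumption about debugging needs; unused by the app
# below): greedy decoding ignores temperature and top_p, so outputs are
# reproducible when tweaking the prompt or chat template.
greedy_config = GenerationConfig(max_new_tokens=100, use_cache=True, do_sample=False)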
def generate_response(conversation):
    """
    Given a conversation (a list of dicts with roles "user"/"assistant" and their
    contents), prepare the prompt, start generation in a background thread, and
    yield the accumulated streamed output token by token.
    """
    # Prepare inputs using the tokenizer's chat template; add_generation_prompt=True
    # appends the template's assistant prefix so the model continues as the assistant
    inputs = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)  # no-op on CPU, required if the model is on GPU

    # Create the streaming iterator. Note: skip_prompt=True omits the prompt from the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a separate thread so this function can consume the streamer
    thread = threading.Thread(
        target=model.generate,
        kwargs={
            "inputs": inputs,
            "tokenizer": tokenizer,
            "generation_config": generation_config,
            "streamer": streamer,
        },
    )
    thread.start()

    # Yield the growing response as tokens are generated
    full_response = ""
    for token in streamer:
        full_response += token
        yield full_response
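# Optional helper (an assumption, not part of the original app): drives
# generate_response() directly, which is handy for checking the model and chat
# template without launching the Gradio UI.
def smoke_test(prompt="Hello!"):
    conversation = [{"role": "user", "content": prompt}]
    for partial in generate_response(conversation):
        print(partial)  # prints the accumulated reply after each new chunk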
def chat(history):
    """
    Chat callback for Gradio.
    - `history` is a list of (user_message, assistant_response) pairs; the last
      pair is the one user() just appended, with an empty assistant slot.
    - We rebuild the full conversation (as a list of role/content dicts) from the
      completed turns, then append the latest user input.
    - We then call generate_response() to stream the model's reply, updating only
      the last history entry as tokens arrive.
    """
    # The textbox is cleared by user() before this runs, so the latest user
    # message is read from the history itself
    user_input = history[-1][0]

    # Rebuild the conversation from the completed turns for the model prompt
    conversation = []
    for user_msg, bot_msg in history[:-1]:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": user_input})

    # Stream the reply; only the last (user, assistant) pair changes on each yield
    for streamed_reply in generate_response(conversation):
        yield history[:-1] + [(user_input, streamed_reply)]
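# For reference: after one completed exchange plus a new message, the conversation
# passed to generate_response() looks like
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello! How can I help?"},
#    {"role": "user", "content": "Tell me a joke."}]
# (the message texts here are illustrative only).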
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Chat with SmallDoge/Doge-20M-checkpoint")
    chatbot = gr.Chatbot()  # displays the conversation as a list of (user, assistant) pairs
    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Type your message here...")
        clear = gr.Button("Clear")

    # When the user submits a message, first clear the textbox and append the
    # message to the history with an empty assistant reply...
    def user(message, history):
        return "", history + [(message, "")]

    # ...then stream the model response using the chat() generator, which reads
    # the latest user message from the history (the textbox is already empty here)
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False) \
        .then(chat, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

# Enable the queue, which streaming responses require, and launch the app
demo.queue()
demo.launch()
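# To expose the app beyond localhost (an assumption about deployment needs),
# Gradio's launch() accepts, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860)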