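"""Gradio chat demo that streams replies from SmallDoge/Doge-20M-checkpoint using TextIteratorStreamer."""
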
import gradio as gr
import threading
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-20M-checkpoint")
model = AutoModelForCausalLM.from_pretrained("SmallDoge/Doge-20M-checkpoint", trust_remote_code=True)

# Generation configuration
generation_config = GenerationConfig(
    max_new_tokens=100,
    use_cache=True,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.0
)

def generate_response(conversation):
    """
    Given a conversation (a list of dicts with roles "user"/"assistant" and their
    contents), prepare the prompt, start generation in a background thread, and
    yield the accumulated response as new tokens are streamed.
    """
    # Prepare inputs using the tokenizer's chat template; add_generation_prompt
    # appends the assistant turn marker so the model continues as the assistant.
    inputs = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    # Streaming iterator; skip_prompt=True omits the prompt tokens from the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # model.generate blocks until it finishes, so run it in a background thread
    # while the streamer is consumed below.
    thread = threading.Thread(
        target=model.generate,
        kwargs={
            "inputs": inputs,
            "tokenizer": tokenizer,
            "generation_config": generation_config,
            "streamer": streamer
        }
    )
    thread.start()

    # Yield output tokens as they are generated
    full_response = ""
    for token in streamer:
        full_response += token
        yield full_response

def chat(history):
    """
    Chat callback for Gradio.

    - `history` is a list of (user_message, assistant_response) pairs; the last
      pair was just added by user() and still has an empty assistant reply.
    - Rebuild the full conversation (as a list of role/content dicts) from the
      earlier history, append the pending user message, and stream the model's
      reply via generate_response(), updating the last pair as tokens arrive.
    """
    user_input = history[-1][0]

    # Rebuild conversation from the completed turns for the model prompt
    conversation = []
    for user_msg, bot_msg in history[:-1]:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": user_input})

    # Stream the reply, replacing only the last (pending) pair in the history
    for streamed_reply in generate_response(conversation):
        yield history[:-1] + [(user_input, streamed_reply)]

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Chat with SmallDoge/Doge-20M-checkpoint")
    chatbot = gr.Chatbot()  # displays the conversation as a list of (user, assistant) pairs
    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Type your message here...")        
        clear = gr.Button("Clear")
    
    # When the user submits a message, first append it to the chat history with
    # an empty assistant reply and clear the textbox...
    def user(message, history):
        return "", history + [(message, "")]

    # ...then stream the model response using the chat() generator above
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False) \
       .then(chat, chatbot, chatbot)

    clear.click(lambda: None, None, chatbot, queue=False)

# Enable queue for streaming responses and launch the app
demo.queue()
demo.launch()