File size: 3,586 Bytes
1d58561
da278a5
 
 
 
a1b31ed
 
 
 
 
 
 
da278a5
 
 
 
 
 
 
 
a1b31ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8e37ed
da278a5
a1b31ed
 
 
da278a5
a1b31ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d58561
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1b31ed
1d58561
 
 
 
 
 
a1b31ed
1d58561
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# eventlet must monkey-patch the stdlib (socket/select) BEFORE anything else
# imports networking modules, otherwise Flask-SocketIO's eventlet async mode
# would block the hub on un-patched sockets.
import eventlet
eventlet.monkey_patch(socket=True, select=True)

import eventlet.wsgi

from flask import Flask, render_template
from flask_socketio import SocketIO
# NOTE(review): MultiBeamTextStreamer is not in mainline transformers releases —
# presumably a fork or feature branch; confirm the pinned transformers version.
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch
import time

# Flask application and Socket.IO server (eventlet worker). CORS is fully open
# and both loggers are enabled — development settings; tighten for production.
app = Flask(__name__)
socketio = SocketIO(
    app,
    ping_timeout=60,
    async_mode='eventlet',
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True
)
# Initialize model and tokenizer
# Loaded once at import time; device_map="auto" places weights on the best
# available device (GPU if present), torch_dtype="auto" keeps the checkpoint's
# native precision.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)


class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Streams beam-search progress to connected clients over Socket.IO.

    Every beam update is pushed as a ``beam_update`` event; a completed
    beam emits ``beam_finished``. An optional per-update delay (given in
    milliseconds) slows the stream down for display purposes.
    """

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        # Latest text observed for each beam index, all starting empty.
        self.beam_texts = dict.fromkeys(range(num_beams), "")
        # Delay between consecutive updates, in milliseconds.
        self.sleep_time = sleep_time

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Record the beam's newest text and push it to the client."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            # Milliseconds -> seconds; also yields control to the eventlet hub.
            eventlet.sleep(self.sleep_time / 1000)
        # Emit right away; the callback forces a cooperative yield so the
        # event actually leaves before generation continues.
        payload = {'beam_idx': beam_idx, 'text': new_text}
        socketio.emit('beam_update', payload, callback=lambda: eventlet.sleep(0))

    def on_beam_finished(self, final_text: str):
        """Tell the client that one beam has produced its final text."""
        socketio.emit('beam_finished', {'text': final_text})


@app.route('/')
def index():
    """Serve the single-page client UI."""
    page = render_template('index.html')
    return page


@socketio.on('generate')
def handle_generation(data):
    """Run beam-search generation for a client prompt, streaming results.

    Expected keys in ``data``:
        prompt (str, required) — user text to respond to.
        num_beams (int, default 5) — beam width and number of streams.
        max_tokens (int, default 512) — generation length cap.
        sleep_time (int ms, default 0) — artificial delay between updates.

    Emits 'generation_started' first, per-beam events via the streamer,
    'generation_error' on any failure, and always 'generation_completed'.
    """
    # Emit a generation start event
    socketio.emit('generation_started')

    # Validate required input instead of letting data['prompt'] raise
    # KeyError for a malformed payload and leave the client hanging.
    prompt = data.get('prompt')
    if not prompt:
        socketio.emit('generation_error', {'error': "Missing 'prompt'"})
        socketio.emit('generation_completed')
        return

    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)  # Get sleep time from frontend

    # Create messages format
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]

    # Apply chat template (plain text, with the generation prompt appended)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Prepare inputs on the same device as the model
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize streamer with sleep time
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Generate with beam search. The return value is intentionally
        # discarded — all output reaches the client through the streamer
        # callbacks — so output_scores/return_dict_in_generate were dropped:
        # they only forced the model to retain per-step scores nobody read.
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                early_stopping=True,
                streamer=streamer
            )
    except Exception as e:
        # Boundary handler: report failures to the client rather than
        # crashing the socket worker.
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        # Always tell the client the run is over, success or not.
        socketio.emit('generation_completed')