import eventlet

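# Patch the standard library before the remaining imports so that socket and
# select calls cooperate with eventlet's green threads.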
eventlet.monkey_patch(socket=True, select=True)

import eventlet.wsgi

from flask import Flask, render_template
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch

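# A Flask app plus a Socket.IO server pinned to the eventlet async mode;
# logger/engineio_logger are verbose and worth disabling outside development.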
app = Flask(__name__)
socketio = SocketIO(
    app,
    ping_timeout=60,
    async_mode='eventlet',
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True
)

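# Load the tokenizer and model once at startup; device_map="auto" puts the
# weights on a GPU when one is available and falls back to CPU otherwise.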
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)


class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Custom streamer that sends updates through websockets with adjustable speed."""

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
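        # Latest text per beam, kept so client-side state can be rebuilt.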
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.sleep_time = sleep_time

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Send beam updates through the websocket, optionally throttled."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            eventlet.sleep(self.sleep_time / 1000)

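        # Emit the update, then explicitly yield to the eventlet loop so the
        # frame is flushed to the client right away instead of queueing behind
        # the still-running generate() call.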
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        })
        eventlet.sleep(0)

    def on_beam_finished(self, final_text: str):
        """Send completion notification through the websocket."""
        socketio.emit('beam_finished', {
            'text': final_text
        })


@app.route('/')
def index():
    return render_template('index.html')


@socketio.on('generate')
def handle_generation(data):
    socketio.emit('generation_started')

    prompt = data['prompt']
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]

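    # Render the chat turns into the model's prompt format; tokenize=False
    # keeps the result a plain string for the batched tokenizer call below.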
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

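    # Beam search with one returned sequence per beam; the streamer pushes
    # partial hypotheses to the client as they are generated.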
    try:
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer
            )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        socketio.emit('generation_completed')
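

if __name__ == '__main__':
    # Assumed entry point: socketio.run() selects the eventlet server because
    # async_mode='eventlet' is set above; it listens on port 5000 by default.
    socketio.run(app)

# Minimal client sketch for exercising the server, assuming the python-socketio
# client package is installed; the event names match the handlers above.
#
#   import socketio
#   sio = socketio.Client()
#   sio.on('beam_update', lambda data: print(data['beam_idx'], data['text']))
#   sio.connect('http://localhost:5000')
#   sio.emit('generate', {'prompt': 'Hello!', 'num_beams': 3})
#   sio.wait()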