import eventlet

# Monkey-patch the stdlib *before* any other imports so sockets, sleeps,
# etc. become cooperative under eventlet's green threads.
eventlet.monkey_patch()

from flask import Flask, render_template
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch

app = Flask(__name__)
socketio = SocketIO(app, ping_timeout=60)

# Model and tokenizer are loaded once at import time; all clients share them.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)

# Upper bounds for client-supplied generation parameters. These values
# arrive over the websocket and must not be trusted unclamped.
MAX_BEAMS = 10
MAX_NEW_TOKENS_LIMIT = 2048


class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Stream beam-search updates to websocket clients, optionally throttled.

    Each partial-hypothesis update is pushed as a ``beam_update`` event and
    every finished hypothesis as a ``beam_finished`` event.

    NOTE(review): ``socketio.emit`` without a room broadcasts to *all*
    connected clients; with more than one user the beam streams interleave.
    Confirm whether per-client rooms (``to=request.sid``) are needed.
    """

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        """
        Args:
            tokenizer: Tokenizer used to decode generated ids to text.
            num_beams: Number of beams tracked by the search.
            sleep_time: Delay in *milliseconds* inserted after each beam
                update, to slow the stream down for visualization.
            skip_prompt: If True, the prompt tokens are not streamed.
        """
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished,
        )
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.sleep_time = sleep_time  # milliseconds

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Record the latest text for a beam and push it to the client(s)."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            # eventlet.sleep (not time.sleep) yields to the event loop so
            # websocket traffic keeps flowing while we throttle.
            eventlet.sleep(self.sleep_time / 1000)  # ms -> s
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        })

    def on_beam_finished(self, final_text: str):
        """Notify the client(s) that a beam produced a final hypothesis."""
        socketio.emit('beam_finished', {
            'text': final_text
        })


@app.route('/')
def index():
    """Serve the single-page frontend."""
    return render_template('index.html')


@socketio.on('generate')
def handle_generation(data):
    """Handle a ``generate`` websocket event: run beam search, stream results.

    Expected ``data`` keys: ``prompt`` (required), ``num_beams``,
    ``max_tokens``, ``sleep_time`` (optional, clamped to sane bounds).
    Emits ``generation_started`` / ``generation_error`` /
    ``generation_completed`` around the run.
    """
    socketio.emit('generation_started')

    # Validate the prompt up front instead of raising KeyError mid-handler,
    # which would kill the greenlet without notifying the client.
    prompt = data.get('prompt')
    if not prompt:
        socketio.emit('generation_error', {'error': 'Prompt must not be empty.'})
        socketio.emit('generation_completed')
        return

    # Coerce and clamp untrusted client parameters.
    try:
        num_beams = max(1, min(int(data.get('num_beams', 5)), MAX_BEAMS))
        max_new_tokens = max(1, min(int(data.get('max_tokens', 512)), MAX_NEW_TOKENS_LIMIT))
        sleep_time = max(0, float(data.get('sleep_time', 0)))
    except (TypeError, ValueError):
        socketio.emit('generation_error', {'error': 'Invalid generation parameters.'})
        socketio.emit('generation_completed')
        return

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]

    # Render the chat template to a plain string; tokenization happens below.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Inference only — no_grad avoids building autograd state.
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer
            )
    except Exception as e:
        # Boundary handler: surface the failure to the client rather than
        # letting the greenlet die silently.
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        # Always tell the client the run is over, success or failure.
        socketio.emit('generation_completed')


if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=7860)