# NOTE(review): the lines "Spaces:" / "Sleeping" / "Sleeping" below were
# Hugging Face Spaces page-status text captured by the extraction, not part
# of the program; preserved here as a comment so the module parses.
# eventlet's docs require monkey_patch() to run as early as possible —
# *before* importing modules that grab references to socket/threading
# (flask, flask_socketio, torch all do) — so it comes first.
import eventlet
eventlet.monkey_patch()

import time

import torch
from flask import Flask, render_template
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM

app = Flask(__name__)
# ping_timeout=60: keep slow generations from dropping the websocket.
socketio = SocketIO(app, ping_timeout=60)
# Initialize model and tokenizer once at import time so every request
# shares the same weights.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",   # let transformers pick the checkpoint's native dtype
    device_map="auto",    # place weights on GPU if available, else CPU
)
class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Stream beam-search hypotheses to connected clients over Socket.IO.

    Wires ``MultiBeamTextStreamer``'s beam callbacks to websocket events,
    with an optional per-update delay so the frontend can animate beam
    evolution at a readable pace.
    """

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished,
        )
        # Latest text observed for each beam, keyed by beam index.
        self.beam_texts = {i: "" for i in range(num_beams)}
        # Artificial delay (milliseconds) inserted before each update emit.
        self.sleep_time = sleep_time

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Record the newest text for *beam_idx* and push it to clients."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            time.sleep(self.sleep_time / 1000)  # milliseconds -> seconds
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        })

    def on_beam_finished(self, final_text: str):
        """Notify clients that a beam has produced its final text."""
        socketio.emit('beam_finished', {
            'text': final_text
        })
@app.route('/')  # was never registered with Flask — dead code without this
def index():
    """Serve the single-page frontend (templates/index.html)."""
    return render_template('index.html')
# NOTE(review): handler was never registered; event name 'generate' is
# assumed — confirm it matches the frontend's socket.emit(...) call.
@socketio.on('generate')
def handle_generation(data):
    """Run beam-search generation for a prompt, streaming beams to clients.

    Expected keys in *data*: ``prompt`` (required); optional ``num_beams``
    (default 5), ``max_tokens`` (default 512), ``sleep_time`` (ms delay
    between beam updates, default 0).

    Emits ``generation_started`` first, per-beam updates via the streamer,
    ``generation_error`` on failure, and always ``generation_completed``.
    """
    socketio.emit('generation_started')

    prompt = data['prompt']
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)  # ms delay chosen by the frontend

    # Build the chat-formatted prompt expected by the instruct model.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Streamer relays each beam's partial text over the websocket.
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True,
    )

    try:
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
            )
    except Exception as e:
        # Surface the failure to the client instead of dying silently.
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        # Always tell the client generation is over, success or not.
        socketio.emit('generation_completed')
if __name__ == '__main__':
    # Bind on all interfaces; 7860 is presumably the Hugging Face Spaces
    # default port — confirm against the deployment config.
    socketio.run(app, host='0.0.0.0', port=7860)