Spaces:
Sleeping
Sleeping
File size: 3,586 Bytes
1d58561 da278a5 a1b31ed da278a5 a1b31ed b8e37ed da278a5 a1b31ed da278a5 a1b31ed 1d58561 a1b31ed 1d58561 a1b31ed 1d58561 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import eventlet
eventlet.monkey_patch(socket=True, select=True)
import eventlet.wsgi
from flask import Flask, render_template
from flask_socketio import SocketIO
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch
import time
# Flask app plus a Socket.IO server on the eventlet async backend
# (eventlet was monkey-patched at the top of the file, so blocking
# socket/select calls cooperate with the event loop).
app = Flask(__name__)
socketio = SocketIO(
app,
ping_timeout=60,  # tolerate long-running generations without dropping the client
async_mode='eventlet',
cors_allowed_origins="*",  # NOTE(review): wide-open CORS — fine for a demo, tighten for production
logger=True,
engineio_logger=True
)
# Initialize model and tokenizer (downloaded/cached at import time; this
# makes module import slow but keeps request handlers fast).
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",  # use the checkpoint's native dtype
device_map="auto"  # place weights on GPU automatically when available
)
class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Streams beam-search progress to websocket clients.

    Wraps ``MultiBeamTextStreamer`` and forwards every beam update and
    beam completion through the module-level Socket.IO server.  An
    optional per-update delay (milliseconds) lets the frontend animate
    generation at a chosen pace.
    """

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        # Latest text observed for each beam, keyed by beam index.
        self.beam_texts = dict.fromkeys(range(num_beams), "")
        # Delay inserted before each emit, in milliseconds.
        self.sleep_time = sleep_time

    def on_beam_update(self, beam_idx: int, new_text: str):
        """Record a beam's latest text and push it to the frontend."""
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            # Yield to the eventlet hub for the configured delay so the
            # frontend receives updates at a human-watchable rate.
            eventlet.sleep(self.sleep_time / 1000)  # ms -> seconds
        # The no-op-sleep callback forces a cooperative yield after the
        # emit so the frame is flushed immediately.
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        }, callback=lambda: eventlet.sleep(0))

    def on_beam_finished(self, final_text: str):
        """Notify the frontend that a beam produced its final text."""
        socketio.emit('beam_finished', {
            'text': final_text
        })
@app.route('/')
def index():
    """Serve the single-page UI for the beam-search demo."""
    page = 'index.html'
    return render_template(page)
@socketio.on('generate')
def handle_generation(data):
    """Handle a 'generate' websocket event: run beam-search generation
    and stream per-beam updates back to the client.

    Expected keys in ``data``:
        'prompt'     (required) user prompt text
        'num_beams'  (optional, default 5) beam width
        'max_tokens' (optional, default 512) max new tokens to generate
        'sleep_time' (optional, default 0) ms delay between beam updates

    Emits: 'generation_started', per-beam events via WebSocketBeamStreamer,
    'generation_error' on failure, and 'generation_completed' always.
    """
    # Emit a generation start event
    socketio.emit('generation_started')

    # Validate input up front: the original code read data['prompt']
    # outside the try block, so a missing/empty prompt raised KeyError
    # after 'generation_started' was emitted but before any error or
    # completion event — leaving the client hanging.
    prompt = data.get('prompt')
    if not prompt:
        socketio.emit('generation_error', {'error': "Missing 'prompt' in request"})
        socketio.emit('generation_completed')
        return

    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)  # Get sleep time from frontend

    # Wrap the prompt in the model's chat template.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize and move inputs onto the model's device.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Streamer pushes incremental beam texts to the client with the
    # requested pacing.
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Inference only — disable autograd bookkeeping.
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer
            )
    except Exception as e:
        # Top-level boundary: report the failure to the client instead
        # of letting the handler crash silently.
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        # Always tell the client the request is finished.
        socketio.emit('generation_completed')
|