Moshe Ofer committed · Commit 1d58561 · 1 Parent(s): 3df6a65

Initial commit for Hugging Face Space

.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/multi_beam_text_streamer.iml" filepath="$PROJECT_DIR$/.idea/multi_beam_text_streamer.iml" />
+     </modules>
+   </component>
+ </project>
.idea/multi_beam_text_streamer.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+   <component name="PyDocumentationSettings">
+     <option name="format" value="PLAIN" />
+     <option name="myDocStringFormat" value="Plain" />
+   </component>
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
Dockerfile CHANGED
@@ -26,9 +26,5 @@ CMD ["gunicorn", \
       "--worker-class", "eventlet", \
       "--workers", "1", \
       "--timeout", "300", \
-      "--keep-alive", "120", \
-      "--log-level", "debug", \
-      "--worker-connections", "1000", \
-      "--backlog", "2048", \
       "--bind", "0.0.0.0:7860", \
       "app:app"]
app.py CHANGED
@@ -1,21 +1,13 @@
+ import eventlet
+ eventlet.monkey_patch()
  from flask import Flask, render_template
  from flask_socketio import SocketIO
  from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
  import torch
  import time
- import eventlet
 
- eventlet.monkey_patch()
  app = Flask(__name__)
- socketio = SocketIO(
-     app,
-     ping_timeout=60,
-     ping_interval=25,
-     cors_allowed_origins="*",
-     async_mode='eventlet',
-     logger=True,
-     engineio_logger=True
- )
+ socketio = SocketIO(app, ping_timeout=60)
 
  # Initialize model and tokenizer
  MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -65,61 +57,57 @@ def index():
 
  @socketio.on('generate')
  def handle_generation(data):
-     def generate_async():
-         try:
-             app.logger.info("Generation started with data: %s", data)
-             socketio.emit('generation_started', callback=lambda: eventlet.sleep(0))
-
-             prompt = data['prompt']
-             num_beams = data.get('num_beams', 5)
-             max_new_tokens = data.get('max_tokens', 512)
-             sleep_time = data.get('sleep_time', 0)
-
-             messages = [
-                 {"role": "system", "content": "You are a helpful assistant."},
-                 {"role": "user", "content": prompt}
-             ]
-
-             text = tokenizer.apply_chat_template(
-                 messages,
-                 tokenize=False,
-                 add_generation_prompt=True
-             )
-
-             model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-             streamer = WebSocketBeamStreamer(
-                 tokenizer=tokenizer,
-                 num_beams=num_beams,
-                 sleep_time=sleep_time,
-                 skip_prompt=True
-             )
-
-             with torch.no_grad():
-                 model.generate(
-                     **model_inputs,
-                     num_beams=num_beams,
-                     num_return_sequences=num_beams,
-                     max_new_tokens=max_new_tokens,
-                     output_scores=True,
-                     return_dict_in_generate=True,
-                     early_stopping=True,
-                     streamer=streamer
-                 )
-
-         except Exception as e:
-             app.logger.error("Generation error: %s", str(e), exc_info=True)
-             socketio.emit('generation_error', {'error': str(e)})
-         finally:
-             socketio.emit('generation_completed')
-
-     eventlet.spawn(generate_async)
+     # Emit a generation start event
+     socketio.emit('generation_started')
+
+     prompt = data['prompt']
+     num_beams = data.get('num_beams', 5)
+     max_new_tokens = data.get('max_tokens', 512)
+     sleep_time = data.get('sleep_time', 0)  # Get sleep time from frontend
+
+     # Create messages format
+     messages = [
+         {"role": "system", "content": "You are a helpful assistant."},
+         {"role": "user", "content": prompt}
+     ]
+
+     # Apply chat template
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     # Prepare inputs
+     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+     # Initialize streamer with sleep time
+     streamer = WebSocketBeamStreamer(
+         tokenizer=tokenizer,
+         num_beams=num_beams,
+         sleep_time=sleep_time,
+         skip_prompt=True
+     )
+
+     try:
+         # Generate with beam search
+         with torch.no_grad():
+             model.generate(
+                 **model_inputs,
+                 num_beams=num_beams,
+                 num_return_sequences=num_beams,
+                 max_new_tokens=max_new_tokens,
+                 output_scores=True,
+                 return_dict_in_generate=True,
+                 early_stopping=True,
+                 streamer=streamer
+             )
+     except Exception as e:
+         socketio.emit('generation_error', {'error': str(e)})
+     finally:
+         # Emit generation completed event
+         socketio.emit('generation_completed')
 
 
  if __name__ == '__main__':
-     socketio.run(
-         app,
-         host='0.0.0.0',
-         port=7860,
-         debug=True,
-         use_reloader=False
-     )
+     socketio.run(app, host='0.0.0.0', port=7860)
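
The handler builds a chat prompt, runs beam search, and streams every candidate beam to the browser through WebSocketBeamStreamer, which is defined in the part of app.py this diff does not show. A minimal sketch of what that class presumably looks like, assuming MultiBeamTextStreamer exposes overridable per-beam hooks (the names on_beam_update and on_beam_finished, the super() signature, and the Socket.IO event names are assumptions, not taken from this commit):

import time
from transformers import MultiBeamTextStreamer

class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Sketch: forwards per-beam text to the browser over Socket.IO.
    Uses the module-level `socketio` object created in app.py."""

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        # Constructor arguments mirror the call site in the diff above.
        super().__init__(tokenizer, num_beams=num_beams, skip_prompt=skip_prompt)
        self.sleep_time = sleep_time

    def on_beam_update(self, beam_idx, new_text):
        # Push the current text of one beam to the client (hook name assumed).
        socketio.emit('beam_update', {'beam_idx': beam_idx, 'text': new_text})
        if self.sleep_time:
            time.sleep(self.sleep_time)  # throttle updates so the UI is readable

    def on_beam_finished(self, final_text):
        # Report a beam that beam search has finalized (hook name assumed).
        socketio.emit('beam_finished', {'text': final_text})

With sleep_time > 0 (the value the frontend sends in the 'generate' payload), token delivery is deliberately slowed so the per-beam updates are visible in the demo UI.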