VictorKai1996NUS commited on
Commit
61bd18b
1 Parent(s): 795c1b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -100
app.py CHANGED
@@ -1,4 +1,7 @@
1
  import os
 
 
 
2
  import torch
3
  from openai import OpenAI
4
  from time import time
@@ -10,10 +13,6 @@ from videosys import CogVideoConfig, VideoSysEngine
10
  from videosys.models.cogvideo.pipeline import CogVideoPABConfig
11
  import psutil
12
  import GPUtil
13
- import queue
14
- import threading
15
-
16
- os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
17
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
@@ -32,15 +31,6 @@ Other times the user will not want modifications , but instead want a new image
32
  Video descriptions must have the same num of words as examples below. Extra words will be ignored.
33
  """
34
 
35
- # 创建一个全局任务队列
36
- task_queue = queue.Queue()
37
-
38
- # 创建一个锁来保护共享资源
39
- lock = threading.Lock()
40
-
41
- # 创建一个列表来存储所有任务的状态
42
- tasks = []
43
-
44
  def convert_prompt(prompt: str, retry_times: int = 3) -> str:
45
  if not os.environ.get("OPENAI_API_KEY"):
46
  return prompt
@@ -111,68 +101,30 @@ def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
111
  logger.error(f"An error occurred: {str(e)}")
112
  return None
113
 
 
114
  def get_server_status():
115
  cpu_percent = psutil.cpu_percent()
116
  memory = psutil.virtual_memory()
117
  disk = psutil.disk_usage('/')
118
- try:
119
- gpus = GPUtil.getGPUs()
120
- if gpus:
121
- gpu = gpus[0] # 只获取第一个GPU的信息
122
- gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
123
- else:
124
- gpu_memory = "No GPU found"
125
- except:
126
- gpu_memory = "GPU information unavailable"
 
127
 
128
  return {
129
  'cpu': f"{cpu_percent}%",
130
  'memory': f"{memory.percent}%",
131
  'disk': f"{disk.percent}%",
132
- 'gpu_memory': gpu_memory
133
  }
134
 
135
- def task_processor():
136
- while True:
137
- task = task_queue.get()
138
- if task is None:
139
- break
140
-
141
- # 更新任务状态为"运行中"
142
- with lock:
143
- task['status'] = 'running'
144
-
145
- # 执行任务
146
- result = task['function'](*task['args'])
147
-
148
- # 更新任务状态为"完成"
149
- with lock:
150
- task['status'] = 'completed'
151
- task['result'] = result
152
-
153
- task_queue.task_done()
154
-
155
- # 启动任务处理器线程
156
- processor_thread = threading.Thread(target=task_processor)
157
- processor_thread.start()
158
-
159
- def add_task(function, args, task_name):
160
- task = {
161
- 'id': len(tasks),
162
- 'name': task_name,
163
- 'status': 'waiting',
164
- 'function': function,
165
- 'args': args,
166
- 'result': None
167
- }
168
- with lock:
169
- tasks.append(task)
170
- task_queue.put(task)
171
- return task['id']
172
 
173
- def get_task_status():
174
- with lock:
175
- return [{'id': task['id'], 'name': task['name'], 'status': task['status']} for task in tasks]
176
 
177
  css = """
178
  body {
@@ -301,27 +253,56 @@ with gr.Blocks(css=css) as demo:
301
  download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
302
  elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
303
 
304
- with gr.Row():
305
- task_status = gr.Dataframe(
306
- headers=["ID", "Task", "Status"],
307
- label="Task Queue",
308
- interactive=False
309
- )
310
- refresh_tasks_button = gr.Button("Refresh Tasks")
311
 
312
  def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
313
- task_id = add_task(generate, (load_model(), prompt, num_inference_steps, guidance_scale), f"Generate: {prompt[:20]}...")
314
- return get_task_status()
 
 
 
 
 
 
315
 
316
  def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
317
  threshold = [int(i) for i in threshold.split(",")]
318
  gap = int(gap)
319
- task_id = add_task(generate, (load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap), prompt, num_inference_steps, guidance_scale), f"Generate VS: {prompt[:20]}...")
320
- return get_task_status()
 
 
 
 
 
 
321
 
322
  def enhance_prompt_func(prompt):
323
  return convert_prompt(prompt, retry_times=1)
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  def update_server_status():
326
  status = get_server_status()
327
  return (
@@ -331,44 +312,25 @@ with gr.Blocks(css=css) as demo:
331
  status['gpu_memory']
332
  )
333
 
334
- def update_task_status():
335
- return get_task_status()
336
 
337
  generate_button.click(
338
  generate_vanilla,
339
  inputs=[prompt, num_inference_steps, guidance_scale],
340
- outputs=[task_status],
341
- concurrency_limit=1
342
  )
343
 
344
  generate_button_vs.click(
345
  generate_vs,
346
  inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
347
- outputs=[task_status],
348
- concurrency_limit=1
349
  )
350
 
351
- enhance_button.click(
352
- enhance_prompt_func,
353
- inputs=[prompt],
354
- outputs=[prompt],
355
- concurrency_limit=1
356
- )
357
 
358
- refresh_button.click(
359
- update_server_status,
360
- outputs=[cpu_status, memory_status, disk_status, gpu_status],
361
- concurrency_limit=1
362
- )
363
  demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
364
 
365
- refresh_tasks_button.click(
366
- update_task_status,
367
- outputs=[task_status],
368
- concurrency_limit=1
369
- )
370
- demo.load(update_task_status, outputs=[task_status], every=5) # 每5秒自动刷新一次
371
-
372
  if __name__ == "__main__":
373
- demo.queue(max_size=10)
374
- demo.launch(max_threads=4) # 设置最大线程数为4,您可以根据需要调整这个值
 
1
  import os
2
+
3
+ os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
4
+
5
  import torch
6
  from openai import OpenAI
7
  from time import time
 
13
  from videosys.models.cogvideo.pipeline import CogVideoPABConfig
14
  import psutil
15
  import GPUtil
 
 
 
 
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
 
31
  Video descriptions must have the same num of words as examples below. Extra words will be ignored.
32
  """
33
 
 
 
 
 
 
 
 
 
 
34
  def convert_prompt(prompt: str, retry_times: int = 3) -> str:
35
  if not os.environ.get("OPENAI_API_KEY"):
36
  return prompt
 
101
  logger.error(f"An error occurred: {str(e)}")
102
  return None
103
 
104
+
105
  def get_server_status():
106
  cpu_percent = psutil.cpu_percent()
107
  memory = psutil.virtual_memory()
108
  disk = psutil.disk_usage('/')
109
+ gpus = GPUtil.getGPUs()
110
+ gpu_info = []
111
+ for gpu in gpus:
112
+ gpu_info.append({
113
+ 'id': gpu.id,
114
+ 'name': gpu.name,
115
+ 'load': f"{gpu.load*100:.1f}%",
116
+ 'memory_used': f"{gpu.memoryUsed}MB",
117
+ 'memory_total': f"{gpu.memoryTotal}MB"
118
+ })
119
 
120
  return {
121
  'cpu': f"{cpu_percent}%",
122
  'memory': f"{memory.percent}%",
123
  'disk': f"{disk.percent}%",
124
+ 'gpu': gpu_info
125
  }
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
 
 
 
128
 
129
  css = """
130
  body {
 
253
  download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
254
  elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
255
 
256
+
257
+
 
 
 
 
 
258
 
259
  def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
260
+ engine = load_model()
261
+ t = time()
262
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
263
+ elapsed_time = time() - t
264
+ video_update = gr.update(visible=True, value=video_path)
265
+ elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
266
+
267
+ return video_path, video_update, elapsed_time
268
 
269
  def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
270
  threshold = [int(i) for i in threshold.split(",")]
271
  gap = int(gap)
272
+ engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
273
+ t = time()
274
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
275
+ elapsed_time = time() - t
276
+ video_update = gr.update(visible=True, value=video_path)
277
+ elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
278
+
279
+ return video_path, video_update, elapsed_time
280
 
281
  def enhance_prompt_func(prompt):
282
  return convert_prompt(prompt, retry_times=1)
283
 
284
+ def get_server_status():
285
+ cpu_percent = psutil.cpu_percent()
286
+ memory = psutil.virtual_memory()
287
+ disk = psutil.disk_usage('/')
288
+ try:
289
+ gpus = GPUtil.getGPUs()
290
+ if gpus:
291
+ gpu = gpus[0] # 只获取第一个GPU的信息
292
+ gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
293
+ else:
294
+ gpu_memory = "No GPU found"
295
+ except:
296
+ gpu_memory = "GPU information unavailable"
297
+
298
+ return {
299
+ 'cpu': f"{cpu_percent}%",
300
+ 'memory': f"{memory.percent}%",
301
+ 'disk': f"{disk.percent}%",
302
+ 'gpu_memory': gpu_memory
303
+ }
304
+
305
+
306
  def update_server_status():
307
  status = get_server_status()
308
  return (
 
312
  status['gpu_memory']
313
  )
314
 
 
 
315
 
316
  generate_button.click(
317
  generate_vanilla,
318
  inputs=[prompt, num_inference_steps, guidance_scale],
319
+ outputs=[video_output, download_video_button, elapsed_time],
 
320
  )
321
 
322
  generate_button_vs.click(
323
  generate_vs,
324
  inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
325
+ outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
 
326
  )
327
 
328
+ enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
 
 
 
 
 
329
 
330
+
331
+ refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
 
 
 
332
  demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
333
 
 
 
 
 
 
 
 
334
  if __name__ == "__main__":
335
+ demo.queue(max_size=10, default_concurrency_limit=1)
336
+ demo.launch()