lisonallen committed
Commit 6dcbefe · 1 Parent(s): cdbfba8

Optimize the model loading process and exception handling to ensure complete video generation

Files changed (1)
  1. app.py +206 -93
app.py CHANGED
@@ -149,13 +149,33 @@ def get_models():
     """Get the models, loading them if they have not been loaded yet"""
     global models
 
+    # Add a model loading lock to prevent concurrent loading
+    model_loading_key = "__model_loading__"
+
     if not models:
-        if IN_HF_SPACE and 'spaces' in globals():
-            print("Loading models with the @spaces.GPU decorator")
-            models = initialize_models()
-        else:
-            print("Loading models directly")
-            load_models()
+        # Check whether the models are currently being loaded
+        if model_loading_key in globals():
+            print("Models are already loading, waiting...")
+            # Wait for model loading to finish
+            import time
+            while not models and model_loading_key in globals():
+                time.sleep(0.5)
+            return models
+
+        try:
+            # Set the loading flag
+            globals()[model_loading_key] = True
+
+            if IN_HF_SPACE and 'spaces' in globals():
+                print("Loading models with the @spaces.GPU decorator")
+                models = initialize_models()
+            else:
+                print("Loading models directly")
+                load_models()
+        finally:
+            # Remove the loading flag whether loading succeeded or not
+            if model_loading_key in globals():
+                del globals()[model_loading_key]
 
     return models
 
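Note: the loading guard added above coordinates callers through a flag stored in globals() plus polling, which is simple but not atomic. A minimal sketch of the same load-once behaviour using threading.Lock (illustrative only, not part of this commit; initialize_models() is assumed to return the populated model dict as in app.py):

    import threading

    _models_lock = threading.Lock()   # hypothetical module-level lock
    models = {}

    def get_models():
        """Sketch: load the models at most once, even with concurrent callers."""
        global models
        if not models:
            with _models_lock:            # only one caller performs the load
                if not models:            # double-checked after acquiring the lock
                    models = initialize_models()
        return models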
@@ -180,6 +200,10 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
     total_latent_sections = int(max(round(total_latent_sections), 1))
 
     job_id = generate_timestamp()
+    last_output_filename = None
+    history_pixels = None
+    history_latents = None
+    total_generated_latent_frames = 0
 
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
 
@@ -273,6 +297,15 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
             latent_padding_size = latent_padding * latent_window_size
 
             if stream.input_queue.top() == 'end':
+                # Make sure the current video is saved before ending
+                if history_pixels is not None and total_generated_latent_frames > 0:
+                    try:
+                        output_filename = os.path.join(outputs_folder, f'{job_id}_final_{total_generated_latent_frames}.mp4')
+                        save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
+                        stream.output_queue.push(('file', output_filename))
+                    except Exception as e:
+                        print(f"Error while saving the final video: {e}")
+
                 stream.output_queue.push(('end', None))
                 return
 
@@ -313,36 +346,47 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
                 stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                 return
 
-            generated_latents = sample_hunyuan(
-                transformer=transformer,
-                sampler='unipc',
-                width=width,
-                height=height,
-                frames=num_frames,
-                real_guidance_scale=cfg,
-                distilled_guidance_scale=gs,
-                guidance_rescale=rs,
-                # shift=3.0,
-                num_inference_steps=steps,
-                generator=rnd,
-                prompt_embeds=llama_vec,
-                prompt_embeds_mask=llama_attention_mask,
-                prompt_poolers=clip_l_pooler,
-                negative_prompt_embeds=llama_vec_n,
-                negative_prompt_embeds_mask=llama_attention_mask_n,
-                negative_prompt_poolers=clip_l_pooler_n,
-                device=gpu,
-                dtype=torch.bfloat16,
-                image_embeddings=image_encoder_last_hidden_state,
-                latent_indices=latent_indices,
-                clean_latents=clean_latents,
-                clean_latent_indices=clean_latent_indices,
-                clean_latents_2x=clean_latents_2x,
-                clean_latent_2x_indices=clean_latent_2x_indices,
-                clean_latents_4x=clean_latents_4x,
-                clean_latent_4x_indices=clean_latent_4x_indices,
-                callback=callback,
-            )
+            try:
+                generated_latents = sample_hunyuan(
+                    transformer=transformer,
+                    sampler='unipc',
+                    width=width,
+                    height=height,
+                    frames=num_frames,
+                    real_guidance_scale=cfg,
+                    distilled_guidance_scale=gs,
+                    guidance_rescale=rs,
+                    # shift=3.0,
+                    num_inference_steps=steps,
+                    generator=rnd,
+                    prompt_embeds=llama_vec,
+                    prompt_embeds_mask=llama_attention_mask,
+                    prompt_poolers=clip_l_pooler,
+                    negative_prompt_embeds=llama_vec_n,
+                    negative_prompt_embeds_mask=llama_attention_mask_n,
+                    negative_prompt_poolers=clip_l_pooler_n,
+                    device=gpu,
+                    dtype=torch.bfloat16,
+                    image_embeddings=image_encoder_last_hidden_state,
+                    latent_indices=latent_indices,
+                    clean_latents=clean_latents,
+                    clean_latent_indices=clean_latent_indices,
+                    clean_latents_2x=clean_latents_2x,
+                    clean_latent_2x_indices=clean_latent_2x_indices,
+                    clean_latents_4x=clean_latents_4x,
+                    clean_latent_4x_indices=clean_latent_4x_indices,
+                    callback=callback,
+                )
+            except Exception as e:
+                print(f"Error during sampling: {e}")
+                traceback.print_exc()
+
+                # If a video has already been generated, return the most recent one
+                if last_output_filename:
+                    stream.output_queue.push(('file', last_output_filename))
+
+                stream.output_queue.push(('end', None))
+                return
 
             if is_last_section:
                 generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
@@ -356,36 +400,57 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
 
             real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
 
-            if history_pixels is None:
-                history_pixels = vae_decode(real_history_latents, vae).cpu()
-            else:
-                section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
-                overlapped_frames = latent_window_size * 4 - 3
+            try:
+                if history_pixels is None:
+                    history_pixels = vae_decode(real_history_latents, vae).cpu()
+                else:
+                    section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
+                    overlapped_frames = latent_window_size * 4 - 3
 
-                current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
-                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
+                    current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
+                    history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
 
-            if not high_vram:
-                unload_complete_models()
+                if not high_vram:
+                    unload_complete_models()
 
-            output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
+                output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
 
-            save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
+                save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
 
-            print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
+                print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
 
-            stream.output_queue.push(('file', output_filename))
+                last_output_filename = output_filename
+                stream.output_queue.push(('file', output_filename))
+            except Exception as e:
+                print(f"Error while decoding or saving the video: {e}")
+                traceback.print_exc()
+
+                # If a video has already been generated, return the most recent one
+                if last_output_filename:
+                    stream.output_queue.push(('file', last_output_filename))
+
+                # Try to continue with the next iteration
+                continue
 
             if is_last_section:
                 break
-    except:
+    except Exception as e:
+        print(f"Error during processing: {e}")
         traceback.print_exc()
 
         if not high_vram:
-            unload_complete_models(
-                text_encoder, text_encoder_2, image_encoder, vae, transformer
-            )
+            try:
+                unload_complete_models(
+                    text_encoder, text_encoder_2, image_encoder, vae, transformer
+                )
+            except Exception:
+                pass
+
+        # If a video has already been generated, return the most recent one
+        if last_output_filename:
+            stream.output_queue.push(('file', last_output_filename))
 
+    # Always make sure an end signal is returned
     stream.output_queue.push(('end', None))
     return
 
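Note: the worker-side changes above all follow one recovery pattern: save an mp4 after every section, remember its path in last_output_filename, and on any failure push that last good file followed by the 'end' sentinel so the consumer loop can still finish. A stripped-down sketch of that pattern (illustrative only; render_section stands in for the sampling/decoding/saving steps, and stream.output_queue.push() is assumed to behave as in app.py):

    import traceback

    def run_sections(stream, sections):
        last_output_filename = None
        try:
            for render_section in sections:
                output_filename = render_section()           # may raise; returns the saved mp4 path
                last_output_filename = output_filename
                stream.output_queue.push(('file', output_filename))
        except Exception:
            traceback.print_exc()
            if last_output_filename:                         # surface the most recent good video
                stream.output_queue.push(('file', last_output_filename))
        stream.output_queue.push(('end', None))              # the consumer always receives an end signal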
@@ -397,28 +462,52 @@ if IN_HF_SPACE and 'spaces' in globals():
         global stream
         assert input_image is not None, 'No input image!'
 
+        # Initialize the UI state
         yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
 
-        stream = AsyncStream()
-
-        async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache)
-
-        output_filename = None
-
-        while True:
-            flag, data = stream.output_queue.next()
-
-            if flag == 'file':
-                output_filename = data
-                yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
-
-            if flag == 'progress':
-                preview, desc, html = data
-                yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
-
-            if flag == 'end':
-                yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
-                break
+        try:
+            stream = AsyncStream()
+
+            # Start the worker asynchronously
+            async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache)
+
+            output_filename = None
+            prev_output_filename = None
+
+            # Keep checking the worker's output
+            while True:
+                try:
+                    flag, data = stream.output_queue.next()
+
+                    if flag == 'file':
+                        output_filename = data
+                        prev_output_filename = output_filename
+                        yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
+
+                    if flag == 'progress':
+                        preview, desc, html = data
+                        yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
+
+                    if flag == 'end':
+                        # If there is a final video file, make sure it is returned
+                        if output_filename is None and prev_output_filename is not None:
+                            output_filename = prev_output_filename
+
+                        yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
+                        break
+                except Exception as e:
+                    print(f"Error while handling the worker output: {e}")
+                    # If there is a final video file, make sure it is returned
+                    if prev_output_filename is not None:
+                        yield prev_output_filename, gr.update(visible=False), gr.update(), f'An error occurred during processing, but a partial video was generated', gr.update(interactive=True), gr.update(interactive=False)
+                    else:
+                        yield None, gr.update(visible=False), gr.update(), f'An error occurred during processing: {str(e)}', gr.update(interactive=True), gr.update(interactive=False)
+                    break
+
+        except Exception as e:
+            print(f"Error while starting processing: {e}")
+            traceback.print_exc()
+            yield None, gr.update(), gr.update(), f'Error while starting processing: {str(e)}', gr.update(interactive=True), gr.update(interactive=False)
 
     process = process_with_gpu
 else:
@@ -426,28 +515,52 @@ else:
         global stream
         assert input_image is not None, 'No input image!'
 
+        # Initialize the UI state
        yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
 
-        stream = AsyncStream()
-
-        async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache)
-
-        output_filename = None
-
-        while True:
-            flag, data = stream.output_queue.next()
-
-            if flag == 'file':
-                output_filename = data
-                yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
-
-            if flag == 'progress':
-                preview, desc, html = data
-                yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
-
-            if flag == 'end':
-                yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
-                break
+        try:
+            stream = AsyncStream()
+
+            # Start the worker asynchronously
+            async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache)
+
+            output_filename = None
+            prev_output_filename = None
+
+            # Keep checking the worker's output
+            while True:
+                try:
+                    flag, data = stream.output_queue.next()
+
+                    if flag == 'file':
+                        output_filename = data
+                        prev_output_filename = output_filename
+                        yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
+
+                    if flag == 'progress':
+                        preview, desc, html = data
+                        yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
+
+                    if flag == 'end':
+                        # If there is a final video file, make sure it is returned
+                        if output_filename is None and prev_output_filename is not None:
+                            output_filename = prev_output_filename
+
+                        yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
+                        break
+                except Exception as e:
+                    print(f"Error while handling the worker output: {e}")
+                    # If there is a final video file, make sure it is returned
+                    if prev_output_filename is not None:
+                        yield prev_output_filename, gr.update(visible=False), gr.update(), f'An error occurred during processing, but a partial video was generated', gr.update(interactive=True), gr.update(interactive=False)
+                    else:
+                        yield None, gr.update(visible=False), gr.update(), f'An error occurred during processing: {str(e)}', gr.update(interactive=True), gr.update(interactive=False)
+                    break
+
+        except Exception as e:
+            print(f"Error while starting processing: {e}")
+            traceback.print_exc()
+            yield None, gr.update(), gr.update(), f'Error while starting processing: {str(e)}', gr.update(interactive=True), gr.update(interactive=False)
 
 
 def end_process():
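Note: both process functions consume the worker's queue with the same loop, which is why the worker must always push ('end', None): the loop only exits on that flag. A condensed sketch of the consumer side (stream.output_queue.next() is assumed to block until the worker pushes, as in app.py; the other names are illustrative):

    def consume_outputs(stream):
        last_file = None
        while True:
            flag, data = stream.output_queue.next()   # blocks until the worker pushes something
            if flag == 'file':
                last_file = data                      # remember the newest (possibly partial) video
            elif flag == 'progress':
                preview, desc, html = data            # would be forwarded to the Gradio UI
            elif flag == 'end':
                return last_file                      # the loop terminates only on the end signal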
 