LanguageBind committed on
Commit
aa27f76
1 Parent(s): 37a8186

Update opensora/serve/gradio_web_server.py

Browse files
Files changed (1) hide show
  1. opensora/serve/gradio_web_server.py +13 -11
opensora/serve/gradio_web_server.py CHANGED
@@ -24,24 +24,26 @@ from opensora.models.diffusion.latte.modeling_latte import LatteT2V
24
  from opensora.sample.pipeline_videogen import VideoGenPipeline
25
  from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
26
 
 
27
 
28
- @torch.inference_mode()
29
  def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
30
  seed = int(randomize_seed_fn(seed, randomize_seed))
31
  set_env(seed)
32
  video_length = transformer_model.config.video_length if not force_images else 1
33
  height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
34
  num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
35
- videos = videogen_pipeline(prompt,
36
- video_length=video_length,
37
- height=height,
38
- width=width,
39
- num_inference_steps=sample_steps,
40
- guidance_scale=scale,
41
- enable_temporal_attentions=not force_images,
42
- num_images_per_prompt=1,
43
- mask_feature=True,
44
- ).video
 
45
 
46
  torch.cuda.empty_cache()
47
  videos = videos[0]
 
24
  from opensora.sample.pipeline_videogen import VideoGenPipeline
25
  from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
26
 
27
+ import spaces
28
 
29
+ @spaces.GPU
30
  def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
31
  seed = int(randomize_seed_fn(seed, randomize_seed))
32
  set_env(seed)
33
  video_length = transformer_model.config.video_length if not force_images else 1
34
  height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
35
  num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
36
+ with torch.no_grad():
37
+ videos = videogen_pipeline(prompt,
38
+ video_length=video_length,
39
+ height=height,
40
+ width=width,
41
+ num_inference_steps=sample_steps,
42
+ guidance_scale=scale,
43
+ enable_temporal_attentions=not force_images,
44
+ num_images_per_prompt=1,
45
+ mask_feature=True,
46
+ ).video
47
 
48
  torch.cuda.empty_cache()
49
  videos = videos[0]