1inkusFace commited on
Commit
a9fb385
·
verified ·
1 Parent(s): 72a5e4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -80,22 +80,22 @@ negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformati
80
 
81
  @spaces.GPU(duration=90)
82
  def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
83
-
84
- random.seed(time.time())
85
- seed = int(random.randrange(4294967294))
86
  if segment==1:
 
 
87
  #Offload.offload(
88
  # pipeline=pipe,
89
  # config=offload_config,
90
  #)
91
  pipe.text_encoder.to("cuda")
92
  pipe.text_encoder_2.to("cuda")
93
- with torch.no_grad():
94
- prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
95
- prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
96
- )
97
  pipe.text_encoder.to("cpu")
98
  pipe.text_encoder_2.to("cpu")
 
99
  transformer_dtype = pipe.transformer.dtype
100
  prompt_embeds = prompt_embeds.to(transformer_dtype)
101
  prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
 
80
 
81
  @spaces.GPU(duration=90)
82
  def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
 
 
 
83
  if segment==1:
84
+ random.seed(time.time())
85
+ seed = int(random.randrange(4294967294))
86
  #Offload.offload(
87
  # pipeline=pipe,
88
  # config=offload_config,
89
  #)
90
  pipe.text_encoder.to("cuda")
91
  pipe.text_encoder_2.to("cuda")
92
+ #with torch.no_grad():
93
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
94
+ prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
95
+ )
96
  pipe.text_encoder.to("cpu")
97
  pipe.text_encoder_2.to("cpu")
98
+ generator = torch.Generator(device='cuda').manual_seed(seed)
99
  transformer_dtype = pipe.transformer.dtype
100
  prompt_embeds = prompt_embeds.to(transformer_dtype)
101
  prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)