1inkusFace committed on
Commit 6849eaf · verified · 1 Parent(s): ed78ea9

Update app.py

Files changed (1)
  1. app.py +313 -79
app.py CHANGED
@@ -13,7 +13,6 @@ from diffusers.utils import load_image
 from PIL import Image
 
 import torch
-from torchvision import transforms
 
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
@@ -43,90 +42,325 @@ def init_predictor():
             compiler_transformer=False,
         )
     )
-
-@spaces.GPU(duration=120)
-def generate_video(prompt, image, size, steps, frames, guidance_scale, progress=gr.Progress(track_tqdm=True) ):
-    print(f"image:{type(image)}")
+
+@spaces.GPU(duration=60)
+def generate_video(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
+
     random.seed(time.time())
     seed = int(random.randrange(4294967294))
-    img = load_image(image=image)
-    img.resize((size,size), Image.LANCZOS)
-    kwargs = {
-        "image": img,
-        "prompt": prompt,
-        "height": size,
-        "width": size,
-        "num_frames": frames,
-        "num_inference_steps": steps,
-        "seed": seed,
-        "guidance_scale": guidance_scale,
-        "embedded_guidance_scale": 1.0,
-        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
-        "cfg_for": False,
-    }
-    assert image is not None, "please input image"
-
-    '''
-    preprocess = transforms.Compose(
-        [
-            transforms.ToTensor(),  # Converts PIL [0, 255] (H, W, C) to Tensor [0, 1] (C, H, W)
-            transforms.Normalize([0.5], [0.5]),  # Normalizes Tensor [0, 1] to [-1, 1]
-            # Use [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] for RGB
-        ]
-    )
-    image_tensor = preprocess(img)
-    # 3. Add the batch dimension (B=1)
-    # Resulting shape: [1, C, H, W]
-    image_tensor = image_tensor.unsqueeze(0)
-    kwargs["image"] = image_tensor
-    '''
+    if segment==1:
+
+
+        prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
+            prompt=prompt, prompt_2=prompt, device=device
+        )
+        transformer_pooled_projections = pooled_prompt_embeds
+        transformer_pooled_projections = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+        pipe.scheduler.set_timesteps(num_inference_steps, device=torch.device('cuda'))
+        timesteps = pipe.scheduler.timesteps
+        all_timesteps_cpu = timesteps.cpu()
+        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
+        segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
+
+        num_channels_latents = pipe.transformer.config.in_channels
+        latents = pipe.prepare_latents(
+            batch_size=1, num_channels_latents=pipe.transformer.config.in_channels, height=height, width=width, num_frames=frames,
+            dtype=torch.float32, device=device, generator=generator, latents=None,
+        )
+        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+
+        kwargs = {
+            "prompt": prompt,
+            "height": size,
+            "width": size,
+            "num_frames": frames,
+            "num_inference_steps": steps,
+            "seed": seed,
+            "guidance_scale": guidance_scale,
+            "embedded_guidance_scale": 1.0,
+            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
+            "cfg_for": False,
+        }
+        assert image is not None, "please input image"
+        img = load_image(image=image)
+        img.resize((size,size), Image.LANCZOS)
+        kwargs["image"] = img
+    else:
+        state_file = f"rv_L_{segment-1}_{seed}.pt"
+        state = torch.load(state_file, weights_only=False)
+        generator = torch.Generator(device='cuda').manual_seed(seed)
+
+
+    current_latents = latents
+
+    for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
+
+        latent_model_input = latents.to(transformer_dtype)
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+        with torch.no_grad():
+            noise_pred = self.transformer(
+                hidden_states=latent_model_input,
+                timestep=timestep,
+                encoder_hidden_states=prompt_embeds,
+                encoder_attention_mask=prompt_attention_mask,
+                pooled_projections=pooled_prompt_embeds,
+                guidance=guidance,
+                attention_kwargs=attention_kwargs,
+                return_dict=False,
+            )[0]
+
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+    else:
+        video = latents
+        return latents
+
+    intermediate_latents_cpu = current_latents.detach().cpu()
+
+    if segment==8:
+        latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+        video = self.vae.decode(latents, return_dict=False)[0]
+        video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+        return HunyuanVideoPipelineOutput(frames=video)
 
-    output = predictor.inference(kwargs)
-    save_dir = f"./"
-    video_out_file = f"{save_dir}/{seed}.mp4"
-    print(f"generate video, local path: {video_out_file}")
-    export_to_video(output, video_out_file, fps=24)
-    return video_out_file
+        save_dir = f"./"
+        video_out_file = f"{save_dir}/{seed}.mp4"
+        print(f"generate video, local path: {video_out_file}")
+        export_to_video(output, video_out_file, fps=24)
+
+        return video_out_file, seed
+    else:
+        original_prompt_embeds_cpu = prompt_embeds.cpu()
+        original_negative_prompt_embeds_cpu = negative_prompt_embeds.cpu()
+        original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
+        original_negative_pooled_prompt_embeds_cpu = negative_pooled_prompt_embeds.cpu()
+        original_add_time_ids_cpu = add_time_ids.cpu()
+        timesteps = pipe.scheduler.timesteps
+        all_timesteps_cpu = timesteps.cpu()  # Move to CPU
+        state = {
+            "intermediate_latents": intermediate_latents_cpu,
+            "all_timesteps": all_timesteps_cpu,  # Save full list generated by scheduler
+            "prompt_embeds": original_prompt_embeds_cpu,  # Save ORIGINAL embeds
+            "negative_prompt_embeds": original_negative_prompt_embeds_cpu,
+            "pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
+            "negative_pooled_prompt_embeds": original_negative_pooled_prompt_embeds_cpu,
+            "add_time_ids": original_add_time_ids_cpu,  # Save ORIGINAL time IDs
+            "guidance_scale": guidance_scale,
+            "timesteps_split": timesteps_split_for_state,
+            "seed": seed,
+            "prompt": prompt,  # Save originals for reference/verification
+            "negative_prompt": negative_prompt,
+            "height": height,  # Save dimensions used
+            "width": width
+        }
+        state_file = f"SkyReel_{segment}_{seed}.pt"
+        torch.save(state, state_file)
+        return None, seed
+
 
+
 with gr.Blocks() as demo:
-    with gr.Row():
-        image = gr.Image(label="Upload Image", type="filepath")
-        prompt = gr.Textbox(label="Input Prompt")
-        size = gr.Slider(
-            label="Size",
-            minimum=256,
-            maximum=1024,
-            step=16,
-            value=368,
-        )
-        frames = gr.Slider(
-            label="Number of Frames",
-            minimum=16,
-            maximum=256,
-            step=8,
-            value=64,
-        )
-        steps = gr.Slider(
-            label="Number of Steps",
-            minimum=1,
-            maximum=96,
-            step=1,
-            value=25,
-        )
-        guidance_scale = gr.Slider(
-            label="Guidance Scale",
-            minimum=1.0,
-            maximum=16.0,
-            step=.1,
-            value=6.0,
-        )
-        submit_button = gr.Button("Generate Video")
-        output_video = gr.Video(label="Generated Video")
-        submit_button.click(
-            fn=generate_video,
-            inputs=[prompt, image, size, steps, frames, guidance_scale],
-            outputs=[output_video],
+    with gr.Row():
+        image = gr.Image(label="Upload Image", type="filepath")
+        prompt = gr.Textbox(label="Input Prompt")
+        size = gr.Slider(
+            label="Size",
+            minimum=256,
+            maximum=1024,
+            step=16,
+            value=368,
+        )
+        frames = gr.Slider(
+            label="Number of Frames",
+            minimum=16,
+            maximum=256,
+            step=8,
+            value=64,
+        )
+        steps = gr.Slider(
+            label="Number of Steps",
+            minimum=1,
+            maximum=96,
+            step=1,
+            value=25,
+        )
+        guidance_scale = gr.Slider(
+            label="Guidance Scale",
+            minimum=1.0,
+            maximum=16.0,
+            step=.1,
+            value=6.0,
+        )
+        submit_button = gr.Button("Generate Video")
+        output_video = gr.Video(label="Generated Video")
+        range_sliders = []
+        for i in range(8):
+            slider = gr.Slider(
+                minimum=1,
+                maximum=250,
+                value=[i * (num_inference_steps.value // 8)],
+                step=1,
+                label=f"Range {i + 1}",
         )
+            range_sliders.append(slider)
+    num_inference_steps.change(
+        update_ranges,
+        inputs=num_inference_steps,
+        outputs=range_sliders,
+    )
+
+    gr.Examples(
+        examples=examples,
+        inputs=prompt,
+        cache_examples=False
+    )
+    use_negative_prompt.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=use_negative_prompt,
+        outputs=negative_prompt,
+        api_name=False,
+    )
+    gr.on(
+        triggers=[
+            run_button_1.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_2.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_3.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_4.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_5.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_6.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_7.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+    gr.on(
+        triggers=[
+            run_button_8.click,
+        ],
+        fn=generate,
+        inputs=[
+            gr.Number(value=4),
+            image,
+            prompt,
+            size,
+            guidance_scale,
+            num_inference_steps,
+            frames,
+            seed,
+        ],
+        outputs=[result, seed],
+    )
+
 
 if __name__ == "__main__":
     init_predictor()
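
The core change in this commit is that `generate_video` no longer runs the whole `predictor.inference(kwargs)` call in one go; instead it splits the scheduler's timestep list into eight slices with `np.array_split` and runs only one slice per `@spaces.GPU(duration=60)` call. A minimal sketch of that segment-wise loop, assuming a diffusers-style scheduler and a caller-supplied `denoise_step` callable (the helper name and signature here are illustrative, not part of app.py):

```python
import numpy as np
import torch

def run_segment(scheduler, latents, denoise_step, num_inference_steps,
                segment, num_segments=8, device="cuda"):
    """Run only the `segment`-th (1-based) slice of the denoising schedule."""
    # Build the full schedule once, then keep only this call's slice.
    scheduler.set_timesteps(num_inference_steps, device=device)
    chunks = np.array_split(scheduler.timesteps.cpu().numpy(), num_segments)
    segment_timesteps = torch.from_numpy(chunks[segment - 1]).to(device)
    for t in segment_timesteps:
        with torch.no_grad():
            noise_pred = denoise_step(latents, t)  # e.g. a transformer forward pass
        # Schedulers that track an internal step index typically locate it from `t`,
        # so starting mid-schedule can work as long as the full schedule was set above.
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latents
```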
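Between segments the new code parks the intermediate latents and text embeddings on the CPU and serializes them with `torch.save`, then reloads them at the start of the next segment. A hedged sketch of that round trip; the field names mirror the diff's `state` dict, while the helper names and the blanket `.to(device)` restore are assumptions. Note that, as committed, the save path uses the prefix `SkyReel_{segment}_{seed}.pt` while the load path expects `rv_L_{segment-1}_{seed}.pt`, so the two prefixes would need to agree for a multi-segment run to actually resume.

```python
import torch

def save_segment_state(path, latents, prompt_embeds, pooled_prompt_embeds,
                       prompt_attention_mask, seed):
    # Everything goes to CPU so the file can be written after the GPU slice ends.
    torch.save(
        {
            "intermediate_latents": latents.detach().cpu(),
            "prompt_embeds": prompt_embeds.cpu(),
            "pooled_prompt_embeds": pooled_prompt_embeds.cpu(),
            "prompt_attention_mask": prompt_attention_mask.cpu(),
            "seed": seed,
        },
        path,
    )

def load_segment_state(path, device="cuda"):
    # weights_only=False matches the diff; the file is trusted local output.
    state = torch.load(path, map_location="cpu", weights_only=False)
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in state.items()
    }
```

The seed is part of the filename, which is why the new `generate_video` returns `seed` alongside the video path: later segments have to receive the same seed back to find the state file written by the previous call.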
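The eight `gr.on(...)` blocks wired to `run_button_1` through `run_button_8` are identical except for the trigger, and each passes a fixed `gr.Number(value=4)` as the segment input. A sketch of the same wiring as a loop; it assumes the buttons, `generate`, `result`, and the other components already exist inside the `gr.Blocks()` context, and passing each button's own index as the segment value is an assumption about the intended behaviour rather than what the commit does:

```python
import gradio as gr

run_buttons = [run_button_1, run_button_2, run_button_3, run_button_4,
               run_button_5, run_button_6, run_button_7, run_button_8]

for index, button in enumerate(run_buttons, start=1):
    gr.on(
        triggers=[button.click],
        fn=generate,
        inputs=[
            gr.Number(value=index, visible=False),  # segment number for this button
            image, prompt, size, guidance_scale,
            num_inference_steps, frames, seed,
        ],
        outputs=[result, seed],
    )
```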