fffiloni committed
Commit 6b5b440 · verified · 1 Parent(s): cc5ea83

Update app.py

Files changed (1): app.py (+13 -6)
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import os
 import torch
-from diffusers import CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
+from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
 from diffusers.utils import export_to_video, load_image
 from datetime import datetime
 
@@ -22,13 +22,20 @@ hf_hub_download(
     local_dir="checkpoints"
 )
 
-pipe = CogVideoXImageToVideoPipeline.from_pretrained(
-    "THUDM/CogVideoX-5b-I2V",
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
+pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b-I2V",
     transformer=CogVideoXTransformer3DModel.from_pretrained(
         "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
     ),
-    torch_dtype=torch.bfloat16)
-
+    vae=pipe.vae,
+    scheduler=pipe.scheduler,
+    tokenizer=pipe.tokenizer,
+    text_encoder=pipe.text_encoder,
+    torch_dtype=torch.bfloat16,
+)
 def infer(prompt, image_path, orbit_type, progress=gr.Progress(track_tqdm=True)):
     lora_path = "checkpoints/"
     if orbit_type == "Left":
@@ -45,7 +52,7 @@ def infer(prompt, image_path, orbit_type, progress=gr.Progress(track_tqdm=True))
     image = load_image(image_path)
     seed = random.randint(0, 2**8 - 1)
 
-    video = pipe(
+    video = pipe_image(
         image,
         prompt,
         num_inference_steps=50, # NOT Changed
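
For context: the commit replaces the standalone CogVideoXImageToVideoPipeline load with a pair of pipelines that share components. The text-to-video CogVideoXPipeline is loaded first, its scheduler swapped for CogVideoXDPMScheduler with timestep_spacing="trailing" (the configuration used in the upstream CogVideoX examples), and the image-to-video pipeline then reuses its VAE, scheduler, tokenizer, and text encoder, loading only the I2V transformer separately, so the shared weights sit in memory once. Below is a minimal, self-contained sketch of the same setup; the device fallback, the input.png path, the prompt, and the fixed seed are illustrative assumptions (in the Space, device is defined elsewhere in app.py and inputs arrive through the infer Gradio handler).

import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.utils import export_to_video, load_image

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: app.py defines this elsewhere

# Text-to-video pipeline loaded first; its VAE, scheduler, tokenizer and
# text encoder are reused below, so the shared weights are held only once.
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to(device)
pipe.scheduler = CogVideoXDPMScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)

# Only the transformer differs between the T2V and I2V checkpoints, so it is
# the only component pulled from the I2V repo.
pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V",
    transformer=CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
    ),
    vae=pipe.vae,
    scheduler=pipe.scheduler,
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    torch_dtype=torch.bfloat16,
).to(device)

# Hypothetical smoke test mirroring the call inside infer().
image = load_image("input.png")  # placeholder input image
video = pipe_image(
    image,
    "a slow orbit around the subject",  # placeholder prompt
    num_inference_steps=50,
    generator=torch.Generator(device=device).manual_seed(42),  # fixed seed for the sketch
).frames[0]
export_to_video(video, "output.mp4", fps=8)

Passing already-instantiated components into from_pretrained is the standard diffusers pattern for sharing weights across pipelines, which is what keeps this dual-pipeline setup within a single GPU's memory.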