fffiloni commited on
Commit
9dab6c2
·
verified ·
1 Parent(s): a78ae85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
  import os
3
  import torch
4
- from diffusers import CogVideoXImageToVideoPipeline
5
  from diffusers.utils import export_to_video, load_image
 
 
6
  from datetime import datetime
7
 
8
  from huggingface_hub import hf_hub_download
@@ -22,15 +24,25 @@ hf_hub_download(
22
  local_dir="checkpoints"
23
  )
24
 
25
- pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
 
 
 
 
 
 
26
  lora_path = "your lora path"
27
  lora_rank = 256
 
28
  def infer(prompt, image_path, orbit_type, progress=gr.Progress(track_tqdm=True)):
29
  lora_path = "checkpoints/"
 
30
  if orbit_type == "Left":
31
  weight_name = "orbit_left_lora_weights.safetensors"
 
32
  elif orbit_type == "Up":
33
  weight_name = "orbit_up_lora_weights.safetensors"
 
34
  lora_rank = 256
35
  pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name="test_1")
36
  pipe.fuse_lora(lora_scale=1 / lora_rank)
 
1
  import gradio as gr
2
  import os
3
  import torch
4
+ from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel
5
  from diffusers.utils import export_to_video, load_image
6
+ from transformers import T5EncoderModel, T5Tokenizer
7
+
8
  from datetime import datetime
9
 
10
  from huggingface_hub import hf_hub_download
 
24
  local_dir="checkpoints"
25
  )
26
 
27
+
28
+ model_id = "THUDM/CogVideoX-5b-I2V"
29
+ transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
30
+ text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16)
31
+ vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16)
32
+ tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
33
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, torch_dtype=torch.float16)
34
  lora_path = "your lora path"
35
  lora_rank = 256
36
+
37
  def infer(prompt, image_path, orbit_type, progress=gr.Progress(track_tqdm=True)):
38
  lora_path = "checkpoints/"
39
+ adapter_name = None
40
  if orbit_type == "Left":
41
  weight_name = "orbit_left_lora_weights.safetensors"
42
+ adapter_name = "orbit_left_lora_weights"
43
  elif orbit_type == "Up":
44
  weight_name = "orbit_up_lora_weights.safetensors"
45
+ adapter_name = "orbit_up_lora_weights"
46
  lora_rank = 256
47
  pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name="test_1")
48
  pipe.fuse_lora(lora_scale=1 / lora_rank)