jhj0517 committed • Commit 4a16e03 • 1 Parent(s): 1f6f578

add `init_model()` to musepose_inference.py

musepose_inference.py  +46 -36
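This commit moves the model-construction calls out of the video-generation path and into a new `init_model()` helper in which every component is guarded by an `is None` check, so each model is built (and its checkpoint loaded) only when its attribute has not been set yet. Below is a minimal sketch of that lazy-initialization pattern, with placeholder loader callables standing in for the real MusePose model classes, and assuming `release_vram()` clears the attributes back to `None` (only its first lines appear in this diff):

```python
from typing import Callable, Optional

LoaderFn = Callable[[], object]  # placeholder for the real from_pretrained(...) calls


class LazyModels:
    """Sketch of the guard pattern used by MusePoseInference.init_model()."""

    def __init__(self):
        self.vae: Optional[object] = None
        self.unet: Optional[object] = None

    def init_model(self, load_vae: LoaderFn, load_unet: LoaderFn) -> None:
        # Build each component only if it is missing; repeated calls are cheap no-ops.
        if self.vae is None:
            self.vae = load_vae()
        if self.unet is None:
            self.unet = load_unet()

    def release_vram(self) -> None:
        # Dropping the references lets the next init_model() call rebuild them.
        self.vae = None
        self.unet = None
```

With this split, the inference method can call `init_model()` unconditionally at the start of every run and `release_vram()` at the end, paying the load cost again only when the models were actually released.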
musepose_inference.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
 import torch
 from diffusers import AutoencoderKL, DDIMScheduler
 from einops import repeat
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig
 from PIL import Image
 from torchvision import transforms
 from transformers import CLIPVisionModelWithProjection
@@ -94,33 +94,9 @@ class MusePoseInference:
         else:
             weight_dtype = torch.float32

-        self.vae = AutoencoderKL.from_pretrained(
-            self.image_gen_model_paths["pretrained_vae"],
-        ).to("cuda", dtype=weight_dtype)
-
-        self.reference_unet = UNet2DConditionModel.from_pretrained(
-            self.image_gen_model_paths["pretrained_base_model"],
-            subfolder="unet",
-        ).to(dtype=weight_dtype, device="cuda")
-
         inference_config_path = self.inference_config_path
         infer_config = OmegaConf.load(inference_config_path)

-        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-            Path(self.image_gen_model_paths["pretrained_base_model"]),
-            Path(self.musepose_model_paths["motion_module"]),
-            subfolder="unet",
-            unet_additional_kwargs=infer_config.unet_additional_kwargs,
-        ).to(dtype=weight_dtype, device="cuda")
-
-        self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
-            dtype=weight_dtype, device="cuda"
-        )
-
-        self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
-            self.image_gen_model_paths["image_encoder"]
-        ).to(dtype=weight_dtype, device="cuda")
-
         sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
         scheduler = DDIMScheduler(**sched_kwargs)

@@ -128,17 +104,8 @@ class MusePoseInference:

         width, height = W, H

-
-        self.denoising_unet.load_state_dict(
-            torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
-            strict=False,
-        )
-        self.reference_unet.load_state_dict(
-            torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
-        )
-        self.pose_guider.load_state_dict(
-            torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
-        )
+        self.init_model(weight_dtype=weight_dtype, infer_config=infer_config)
+
         self.pipe = Pose2VideoPipeline(
             vae=self.vae,
             image_encoder=self.image_enc,
@@ -225,6 +192,49 @@ class MusePoseInference:
         self.release_vram()
         return output_path, output_path_demo

+    def init_model(self,
+                   weight_dtype: torch.dtype,
+                   infer_config: DictConfig
+                   ):
+        if self.vae is None:
+            self.vae = AutoencoderKL.from_pretrained(
+                self.image_gen_model_paths["pretrained_vae"],
+            ).to("cuda", dtype=weight_dtype)
+
+        if self.reference_unet is None:
+            self.reference_unet = UNet2DConditionModel.from_pretrained(
+                self.image_gen_model_paths["pretrained_base_model"],
+                subfolder="unet",
+            ).to(dtype=weight_dtype, device="cuda")
+            self.reference_unet.load_state_dict(
+                torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
+            )
+
+        if self.denoising_unet is None:
+            self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+                Path(self.image_gen_model_paths["pretrained_base_model"]),
+                Path(self.musepose_model_paths["motion_module"]),
+                subfolder="unet",
+                unet_additional_kwargs=infer_config.unet_additional_kwargs,
+            ).to(dtype=weight_dtype, device="cuda")
+            self.denoising_unet.load_state_dict(
+                torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
+                strict=False,
+            )
+
+        if self.pose_guider is None:
+            self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
+                dtype=weight_dtype, device="cuda"
+            )
+            self.pose_guider.load_state_dict(
+                torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
+            )
+
+        if self.image_enc is None:
+            self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
+                self.image_gen_model_paths["image_encoder"]
+            ).to(dtype=weight_dtype, device="cuda")
+
     def release_vram(self):
         models = [
             'vae', 'reference_unet', 'denoising_unet',