# StyleCrafter/lvdm/models/ddpm3d_cond.py
import torch
from einops import rearrange

from utils.utils import instantiate_from_config
from lvdm.models.ddpm3d import LatentDiffusion
from lvdm.modules.attention import TemporalTransformer


class T2VAdapterDepth(LatentDiffusion):
    def __init__(self, depth_stage_config, adapter_config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.depth_stage = instantiate_from_config(depth_stage_config)
        self.adapter = instantiate_from_config(adapter_config)
        self.condtype = adapter_config.cond_name
        if 'pretrained' in adapter_config:
            self.load_pretrained_adapter(adapter_config.pretrained)
        # freeze the depth estimator; only the adapter is trained
        for param in self.depth_stage.parameters():
            param.requires_grad = False
    def prepare_midas_input(self, x):
        # x: (b, c, h, w)
        # resize hook for the MiDaS depth estimator; interpolating to the
        # input's own (h, w) is currently an identity op
        h, w = x.shape[-2:]
        x_midas = torch.nn.functional.interpolate(x, size=(h, w), mode='bilinear')
        return x_midas
    @torch.no_grad()
    def get_batch_depth(self, x, target_size):
        # x: (b, c, t, h, w)
        # estimate depth per frame, resize to target_size, and normalize to [-1, 1]
        b, c, t, h, w = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x_midas = self.prepare_midas_input(x)
        cond_depth = self.depth_stage(x_midas)
        cond_depth = torch.nn.functional.interpolate(cond_depth, size=target_size, mode='bilinear')
        depth_min = torch.amin(cond_depth, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(cond_depth, dim=[1, 2, 3], keepdim=True)
        cond_depth = (cond_depth - depth_min) / (depth_max - depth_min + 1e-7)
        cond_depth = 2. * cond_depth - 1.
        cond_depth = rearrange(cond_depth, '(b t) c h w -> b c t h w', b=b, t=t)
        return cond_depth
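    # Usage sketch (illustrative, not from the original file): the `video`
    # tensor and the (32, 32) target below are hypothetical; target_size is
    # normally the latent resolution expected by the adapter.
    #   video = torch.randn(2, 3, 16, 256, 256)              # (b, c, t, h, w)
    #   cond_depth = model.get_batch_depth(video, (32, 32))  # values in [-1, 1]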
    def load_pretrained_adapter(self, adapter_ckpt):
        # load a pretrained adapter; fall back to partial loading on shape mismatch
        print(">>> Load pretrained adapter checkpoint.")
        try:
            state_dict = torch.load(adapter_ckpt, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
            self.adapter.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            state_dict = torch.load(adapter_ckpt, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
            model_state_dict = self.adapter.state_dict()
            n_unmatched = 0
            for n, p in model_state_dict.items():
                if n in state_dict and p.shape != state_dict[n].shape:
                    state_dict.pop(n)
                    n_unmatched += 1
            model_state_dict.update(state_dict)
            self.adapter.load_state_dict(model_state_dict)
            print(f"Pretrained adapter IS NOT complete [{n_unmatched} units have unmatched shape].")


class T2IAdapterStyleAS(LatentDiffusion):
    def __init__(self, style_stage_config, adapter_config, *args, **kwargs):
        super(T2IAdapterStyleAS, self).__init__(*args, **kwargs)
        self.adapter = instantiate_from_config(adapter_config)
        self.condtype = adapter_config.cond_name
        self.style_stage_model = instantiate_from_config(style_stage_config)
        # hook the adapter's cross-attention layers into the denoising UNet
        self.adapter.create_cross_attention_adapter(self.model.diffusion_model)
        if 'pretrained' in adapter_config:
            self.load_pretrained_adapter(adapter_config.pretrained)
        # freeze the style stage model
        for param in self.style_stage_model.parameters():
            param.requires_grad = False
    def load_pretrained_adapter(self, pretrained):
        state_dict = torch.load(pretrained, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        self.adapter.load_state_dict(state_dict, strict=False)
        print('>>> adapter checkpoint loaded.')
    @torch.no_grad()
    def get_batch_style(self, batch_x):
        # batch_x: (b, c, h, w) reference style image(s)
        cond_style = self.style_stage_model(batch_x)
        return cond_style
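    # Usage sketch (illustrative): `style_img` is a hypothetical batch of
    # reference images; the returned embedding is whatever
    # self.style_stage_model produces for conditioning.
    #   style_img = torch.randn(2, 3, 224, 224)
    #   cond_style = model.get_batch_style(style_img)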


class T2VFintoneStyleAS(T2IAdapterStyleAS):
    def _get_temp_attn_parameters(self):
        # collect the parameters of every TemporalTransformer block in the UNet
        temp_attn_params = []

        def register_recr(net_, name):
            if isinstance(net_, TemporalTransformer):
                temp_attn_params.extend(net_.parameters())
            else:
                for sub_name, net in net_.named_children():
                    register_recr(net, f"{name}.{sub_name}")

        for name, net in self.model.diffusion_model.named_children():
            register_recr(net, name)
        return temp_attn_params
    def _get_temp_attn_state_dict(self):
        # gather a state dict keyed by each TemporalTransformer's module path
        temp_attn_state_dict = {}

        def register_recr(net_, name):
            if isinstance(net_, TemporalTransformer):
                temp_attn_state_dict[name] = net_.state_dict()
            else:
                for sub_name, net in net_.named_children():
                    register_recr(net, f"{name}.{sub_name}")

        for name, net in self.model.diffusion_model.named_children():
            register_recr(net, name)
        return temp_attn_state_dict
    def _load_temp_attn_state_dict(self, temp_attn_state_dict):
        # restore each TemporalTransformer from its entry in the state dict
        def register_recr(net_, name):
            if isinstance(net_, TemporalTransformer):
                net_.load_state_dict(temp_attn_state_dict[name], strict=True)
            else:
                for sub_name, net in net_.named_children():
                    register_recr(net, f"{name}.{sub_name}")

        for name, net in self.model.diffusion_model.named_children():
            register_recr(net, name)
    def load_pretrained_temporal(self, pretrained):
        temp_attn_ckpt = torch.load(pretrained, map_location="cpu")
        if "state_dict" in temp_attn_ckpt:
            temp_attn_ckpt = temp_attn_ckpt["state_dict"]
        self._load_temp_attn_state_dict(temp_attn_ckpt)
        print('>>> Temporal Attention checkpoint loaded.')
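    # Round-trip sketch (illustrative; the checkpoint path is hypothetical):
    # persist only the temporal-attention weights, then restore them later.
    #   torch.save({"state_dict": model._get_temp_attn_state_dict()}, "temp_attn.ckpt")
    #   model.load_pretrained_temporal("temp_attn.ckpt")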