Spaces:
Sleeping
Sleeping
import os, random | |
from einops import rearrange, repeat | |
import torch | |
from utils.utils import instantiate_from_config | |
from lvdm.models.ddpm3d import LatentDiffusion | |
from lvdm.models.samplers.ddim import DDIMSampler | |
from lvdm.modules.attention import TemporalTransformer | |
class T2VAdapterDepth(LatentDiffusion): | |
def __init__(self, depth_stage_config, adapter_config, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
self.depth_stage = instantiate_from_config(depth_stage_config) | |
self.adapter = instantiate_from_config(adapter_config) | |
self.condtype = adapter_config.cond_name | |
if 'pretrained' in adapter_config: | |
self.load_pretrained_adapter(adapter_config.pretrained) | |
for param in self.depth_stage.parameters(): | |
param.requires_grad = False | |
def prepare_midas_input(self, x): | |
# x: (b, c, h, w) | |
h, w = x.shape[-2:] | |
x_midas = torch.nn.functional.interpolate(x, size=(h, w), mode='bilinear') | |
return x_midas | |
def get_batch_depth(self, x, target_size): | |
# x: (b, c, t, h, w) | |
# get depth image, reshape to target_size and normalize to [-1, 1] | |
b, c, t, h, w = x.shape | |
x = rearrange(x, 'b c t h w -> (b t) c h w') | |
x_midas = self.prepare_midas_input(x) | |
cond_depth = self.depth_stage(x_midas) | |
cond_depth = torch.nn.functional.interpolate(cond_depth, size=target_size, mode='bilinear') | |
depth_min, depth_max = torch.amin(cond_depth, dim=[1, 2, 3], keepdim=True), torch.amax(cond_depth, dim=[1, 2, 3], keepdim=True) | |
cond_depth = (cond_depth - depth_min) / (depth_max - depth_min + 1e-7) | |
cond_depth = 2. * cond_depth - 1. | |
cond_depth = rearrange(cond_depth, '(b t) c h w -> b c t h w', b=b, t=t) | |
return cond_depth | |
def load_pretrained_adapter(self, adapter_ckpt): | |
# load pretrained adapter | |
print(">>> Load pretrained adapter checkpoint.") | |
try: | |
state_dict = torch.load(adapter_ckpt, map_location="cpu") | |
if "state_dict" in list(state_dict.keys()): | |
state_dict = state_dict["state_dict"] | |
self.adapter.load_state_dict(state_dict, strict=True) | |
except: | |
state_dict = torch.load(adapter_ckpt, map_location=f"cpu") | |
if "state_dict" in list(state_dict.keys()): | |
state_dict = state_dict["state_dict"] | |
model_state_dict = self.adapter.state_dict() | |
n_unmatched = 0 | |
for n, p in model_state_dict.items(): | |
if p.shape != state_dict[n].shape: | |
state_dict.pop(n) | |
n_unmatched += 1 | |
model_state_dict.update(state_dict) | |
self.adapter.load_state_dict(model_state_dict) | |
print(f"Pretrained adapter IS NOT complete [{n_unmatched} units have unmatched shape].") | |
class T2IAdapterStyleAS(LatentDiffusion): | |
def __init__(self, style_stage_config, adapter_config, *args, **kwargs): | |
super(T2IAdapterStyleAS, self).__init__(*args, **kwargs) | |
self.adapter = instantiate_from_config(adapter_config) | |
self.condtype = adapter_config.cond_name | |
## adapter loading / saving paths | |
self.style_stage_model = instantiate_from_config(style_stage_config) | |
self.adapter.create_cross_attention_adapter(self.model.diffusion_model) | |
if 'pretrained' in adapter_config: | |
self.load_pretrained_adapter(adapter_config.pretrained) | |
# freeze the style stage model | |
for param in self.style_stage_model.parameters(): | |
param.requires_grad = False | |
def load_pretrained_adapter(self, pretrained): | |
state_dict = torch.load(pretrained, map_location=f"cpu") | |
if "state_dict" in list(state_dict.keys()): | |
state_dict = state_dict["state_dict"] | |
self.adapter.load_state_dict(state_dict, strict=False) | |
print('>>> adapter checkpoint loaded.') | |
def get_batch_style(self, batch_x): | |
b, c, h, w = batch_x.shape | |
cond_style = self.style_stage_model(batch_x) | |
return cond_style | |
class T2VFintoneStyleAS(T2IAdapterStyleAS): | |
def _get_temp_attn_parameters(self): | |
temp_attn_params = [] | |
def register_recr(net_, name): | |
if isinstance(net_, TemporalTransformer): | |
temp_attn_params.extend(net_.parameters()) | |
else: | |
for sub_name, net in net_.named_children(): | |
register_recr(net, f"{name}.{sub_name}") | |
for name, net in self.model.diffusion_model.named_children(): | |
register_recr(net, name) | |
return temp_attn_params | |
def _get_temp_attn_state_dict(self): | |
temp_attn_state_dict = {} | |
def register_recr(net_, name): | |
if isinstance(net_, TemporalTransformer): | |
temp_attn_state_dict[name] = net_.state_dict() | |
else: | |
for sub_name, net in net_.named_children(): | |
register_recr(net, f"{name}.{sub_name}") | |
for name, net in self.model.diffusion_model.named_children(): | |
register_recr(net, name) | |
return temp_attn_state_dict | |
def _load_temp_attn_state_dict(self, temp_attn_state_dict): | |
def register_recr(net_, name): | |
if isinstance(net_, TemporalTransformer): | |
net_.load_state_dict(temp_attn_state_dict[name], strict=True) | |
else: | |
for sub_name, net in net_.named_children(): | |
register_recr(net, f"{name}.{sub_name}") | |
for name, net in self.model.diffusion_model.named_children(): | |
register_recr(net, name) | |
def load_pretrained_temporal(self, pretrained): | |
temp_attn_ckpt = torch.load(pretrained, map_location=f"cpu") | |
if "state_dict" in list(temp_attn_ckpt.keys()): | |
temp_attn_ckpt = temp_attn_ckpt["state_dict"] | |
self._load_temp_attn_state_dict(temp_attn_ckpt) | |
print('>>> Temporal Attention checkpoint loaded.') |