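"""Adapter wrappers around LatentDiffusion for conditional video/image generation.

- T2VAdapterDepth: depth-conditioned text-to-video model with a frozen MiDaS-style
  depth estimator feeding a trainable adapter.
- T2IAdapterStyleAS / T2VFintoneStyleAS: style-conditioned models whose adapter hooks
  into the UNet cross-attention; the T2V variant additionally exposes helpers for
  fine-tuning only the TemporalTransformer (temporal attention) blocks.
"""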
import os
import random

import torch
from einops import rearrange, repeat

from utils.utils import instantiate_from_config
from lvdm.models.ddpm3d import LatentDiffusion
from lvdm.models.samplers.ddim import DDIMSampler
from lvdm.modules.attention import TemporalTransformer


class T2VAdapterDepth(LatentDiffusion):
    def __init__(self, depth_stage_config, adapter_config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.depth_stage = instantiate_from_config(depth_stage_config)
        self.adapter = instantiate_from_config(adapter_config)
        self.condtype = adapter_config.cond_name
        if 'pretrained' in adapter_config:
            self.load_pretrained_adapter(adapter_config.pretrained)
        # freeze the depth estimator; only the adapter is trainable
        for param in self.depth_stage.parameters():
            param.requires_grad = False
    def prepare_midas_input(self, x):
        # x: (b, c, h, w)
        # resize to the depth estimator's expected input size
        # (with (h, w) taken from the input, this is currently a pass-through)
        h, w = x.shape[-2:]
        x_midas = torch.nn.functional.interpolate(x, size=(h, w), mode='bilinear')
        return x_midas
    @torch.no_grad()
    def get_batch_depth(self, x, target_size):
        # x: (b, c, t, h, w)
        # estimate per-frame depth, resize to target_size, and normalize to [-1, 1]
        b, c, t, h, w = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x_midas = self.prepare_midas_input(x)
        cond_depth = self.depth_stage(x_midas)
        cond_depth = torch.nn.functional.interpolate(cond_depth, size=target_size, mode='bilinear')
        depth_min = torch.amin(cond_depth, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(cond_depth, dim=[1, 2, 3], keepdim=True)
        cond_depth = (cond_depth - depth_min) / (depth_max - depth_min + 1e-7)
        cond_depth = 2. * cond_depth - 1.
        cond_depth = rearrange(cond_depth, '(b t) c h w -> b c t h w', b=b, t=t)
        return cond_depth
    def load_pretrained_adapter(self, adapter_ckpt):
        # load a pretrained adapter checkpoint, falling back to a partial load
        # when some parameter shapes do not match the current adapter
        print(">>> Load pretrained adapter checkpoint.")
        state_dict = torch.load(adapter_ckpt, map_location="cpu")
        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
        try:
            self.adapter.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # strict load failed: drop checkpoint entries with mismatched shapes
            model_state_dict = self.adapter.state_dict()
            n_unmatched = 0
            for n, p in model_state_dict.items():
                if n in state_dict and p.shape != state_dict[n].shape:
                    state_dict.pop(n)
                    n_unmatched += 1
            model_state_dict.update(state_dict)
            self.adapter.load_state_dict(model_state_dict)
            print(f"Pretrained adapter is NOT complete: {n_unmatched} parameters skipped due to shape mismatch.")


class T2IAdapterStyleAS(LatentDiffusion):
    def __init__(self, style_stage_config, adapter_config, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.adapter = instantiate_from_config(adapter_config)
        self.condtype = adapter_config.cond_name
        self.style_stage_model = instantiate_from_config(style_stage_config)
        # hook the style adapter into the cross-attention layers of the UNet
        self.adapter.create_cross_attention_adapter(self.model.diffusion_model)
        if 'pretrained' in adapter_config:
            self.load_pretrained_adapter(adapter_config.pretrained)
        # freeze the style stage model
        for param in self.style_stage_model.parameters():
            param.requires_grad = False
    def load_pretrained_adapter(self, pretrained):
        state_dict = torch.load(pretrained, map_location="cpu")
        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
        self.adapter.load_state_dict(state_dict, strict=False)
        print('>>> adapter checkpoint loaded.')

    @torch.no_grad()
    def get_batch_style(self, batch_x):
        # batch_x: (b, c, h, w) reference images; returns embeddings from the frozen style encoder
        b, c, h, w = batch_x.shape
        cond_style = self.style_stage_model(batch_x)
        return cond_style
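

# --- Hedged usage sketch (not part of the original module) ---------------------
# Shows how style conditioning might be extracted from reference images, assuming
# `model` is an instantiated T2IAdapterStyleAS and `style_images` is a (b, c, h, w)
# tensor preprocessed to whatever the style encoder expects.
def _example_style_conditioning(model, style_images):
    cond_style = model.get_batch_style(style_images)  # embeddings from the frozen style encoder
    # these embeddings are consumed by the cross-attention adapter created in __init__
    return cond_style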


class T2VFintoneStyleAS(T2IAdapterStyleAS):
    def _get_temp_attn_parameters(self):
        # collect the parameters of every TemporalTransformer block in the UNet
        temp_attn_params = []

        def register_recr(net_, name):
            if isinstance(net_, TemporalTransformer):
                temp_attn_params.extend(net_.parameters())
            else:
                for sub_name, net in net_.named_children():
                    register_recr(net, f"{name}.{sub_name}")

        for name, net in self.model.diffusion_model.named_children():
            register_recr(net, name)
        return temp_attn_params
    def _get_temp_attn_state_dict(self):
        # gather a {module_name: state_dict} mapping for all TemporalTransformer blocks
        temp_attn_state_dict = {}

        def register_recr(net_, name):
            if isinstance(net_, TemporalTransformer):
                temp_attn_state_dict[name] = net_.state_dict()
            else:
                for sub_name, net in net_.named_children():
                    register_recr(net, f"{name}.{sub_name}")

        for name, net in self.model.diffusion_model.named_children():
            register_recr(net, name)
        return temp_attn_state_dict
    def _load_temp_attn_state_dict(self, temp_attn_state_dict):
        # restore the TemporalTransformer blocks from a mapping produced by _get_temp_attn_state_dict
        def register_recr(net_, name):
            if isinstance(net_, TemporalTransformer):
                net_.load_state_dict(temp_attn_state_dict[name], strict=True)
            else:
                for sub_name, net in net_.named_children():
                    register_recr(net, f"{name}.{sub_name}")

        for name, net in self.model.diffusion_model.named_children():
            register_recr(net, name)
    def load_pretrained_temporal(self, pretrained):
        temp_attn_ckpt = torch.load(pretrained, map_location="cpu")
        if "state_dict" in list(temp_attn_ckpt.keys()):
            temp_attn_ckpt = temp_attn_ckpt["state_dict"]
        self._load_temp_attn_state_dict(temp_attn_ckpt)
        print('>>> Temporal Attention checkpoint loaded.')
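

# --- Hedged usage sketch (not part of the original module) ---------------------
# Outlines how only the temporal-attention blocks of T2VFintoneStyleAS might be
# fine-tuned and checkpointed; the optimizer settings and save path are hypothetical.
def _example_temporal_finetune_setup(model, save_path="temp_attn.ckpt"):
    params = model._get_temp_attn_parameters()        # trainable TemporalTransformer parameters
    optimizer = torch.optim.AdamW(params, lr=1e-5)    # hypothetical optimizer choice
    # ... training loop over video batches would go here ...
    # saved in the {"state_dict": ...} layout expected by load_pretrained_temporal
    torch.save({"state_dict": model._get_temp_attn_state_dict()}, save_path)
    return optimizer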