# Page-header residue from the hosting site: uploaded by "liuhuohuo", commit message "init", commit 5af269e.
import torch
import torch.nn as nn
from collections import OrderedDict
from lvdm.basics import (
zero_module,
conv_nd,
avg_pool_nd
)
from einops import rearrange
from lvdm.modules.attention import register_attn_processor, set_attn_processor, DualCrossAttnProcessor, get_attn_processor
from lvdm.modules.attention import DualCrossAttnProcessorAS
from utils.utils import instantiate_from_config
from lvdm.modules.encoders.arch_transformer import Transformer
class StyleTransformer(nn.Module):
    """Condense a token sequence into ``num_tokens`` learned summary tokens.

    A bank of ``num_tokens`` learnable embeddings is placed in front of the
    input sequence, the joint sequence is layer-normalised and passed through
    a ``Transformer``, and only the leading ``num_tokens`` positions are kept,
    layer-normalised again and projected from ``in_dim`` to ``out_dim``.

    Args:
        in_dim: Channel width of the input tokens and of the transformer.
        out_dim: Channel width of the returned tokens.
        num_heads: Number of attention heads handed to ``Transformer``.
        num_tokens: Number of learnable embeddings / returned tokens.
        n_layers: Number of layers handed to ``Transformer``.
    """

    def __init__(self, in_dim=1024, out_dim=1024, num_heads=8, num_tokens=4, n_layers=2):
        super().__init__()
        self.num_tokens = num_tokens
        # Both random parameters share the same 1/sqrt(in_dim) init scale.
        init_std = in_dim ** -0.5

        self.style_emb = nn.Parameter(init_std * torch.randn(1, num_tokens, in_dim))
        self.transformer_blocks = Transformer(width=in_dim, layers=n_layers, heads=num_heads)
        self.ln1 = nn.LayerNorm(in_dim)
        self.ln2 = nn.LayerNorm(in_dim)
        self.proj = nn.Parameter(init_std * torch.randn(in_dim, out_dim))

    def forward(self, x):
        """Return the projected summary tokens for ``x``.

        Args:
            x: Tensor concatenated with the ``(B, num_tokens, in_dim)``
                embeddings along dim 1, i.e. ``(B, N, in_dim)``.

        Returns:
            Tensor of shape ``(B, num_tokens, out_dim)``.
        """
        batch_size = x.shape[0]
        learned = self.style_emb.repeat(batch_size, 1, 1)

        # The learned tokens go first so they can be sliced off the front below.
        tokens = self.ln1(torch.cat([learned, x], dim=1))

        # Swap the first two axes around the transformer call and back again
        # (presumably because Transformer is sequence-first -- TODO confirm).
        tokens = self.transformer_blocks(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        summary = self.ln2(tokens[:, :self.num_tokens, :])
        return summary @ self.proj
class ScaleEncoder(nn.Module):
    """Predict bounded per-token values from a token sequence.

    ``num_tokens`` learnable embeddings are placed in front of the input
    sequence, the joint sequence is layer-normalised and passed through a
    ``Transformer``, and the leading ``num_tokens`` positions are fed through
    a small MLP ending in ``Tanh``, so every output lies in ``(-1, 1)``.

    Args:
        in_dim: Channel width of the input tokens and of the transformer.
        out_dim: Number of values predicted per token.
        num_heads: Number of attention heads handed to ``Transformer``.
        num_tokens: Number of learnable embeddings / returned tokens.
        n_layers: Number of layers handed to ``Transformer``.
    """

    def __init__(self, in_dim=1024, out_dim=1, num_heads=8, num_tokens=16, n_layers=2):
        super().__init__()
        self.num_tokens = num_tokens
        init_std = in_dim ** -0.5

        self.scale_emb = nn.Parameter(init_std * torch.randn(1, num_tokens, in_dim))
        self.transformer_blocks = Transformer(width=in_dim, layers=n_layers, heads=num_heads)
        self.ln1 = nn.LayerNorm(in_dim)
        self.ln2 = nn.LayerNorm(in_dim)

        # Bottleneck head: in_dim -> 32 -> out_dim, squashed by Tanh.
        head_layers = [
            nn.Linear(in_dim, 32),
            nn.GELU(),
            nn.Linear(32, out_dim),
            nn.Tanh(),
        ]
        self.out = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's predictions for the learned token positions.

        Args:
            x: Tensor concatenated with the ``(B, num_tokens, in_dim)``
                embeddings along dim 1, i.e. ``(B, N, in_dim)``.

        Returns:
            Tensor of shape ``(B, num_tokens, out_dim)``.
        """
        batch_size = x.shape[0]
        learned = self.scale_emb.repeat(batch_size, 1, 1)

        # Learned tokens are prepended so they occupy the first positions.
        tokens = self.ln1(torch.cat([learned, x], dim=1))

        # Swap the first two axes around the transformer call and back again
        # (presumably because Transformer is sequence-first -- TODO confirm).
        tokens = self.transformer_blocks(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        leading = self.ln2(tokens[:, :self.num_tokens, :])
        return self.out(leading)
class DropPath(nn.Module):
    r"""Per-sample drop-path without rescaling, with optional forced zero/keep.

    In training mode a random subset of batch entries is multiplied by zero;
    the surviving entries are left unscaled. Entries flagged by ``keep`` are
    never dropped and entries flagged by ``zero`` are always dropped. In eval
    mode the inputs are returned untouched.
    """

    def __init__(self, p):
        super().__init__()
        # Per-sample probability used to draw how many entries get dropped.
        self.p = p

    def forward(self, *args, zero=None, keep=None):
        """Apply one shared drop mask to every tensor in ``args``.

        Args:
            *args: Tensors sharing the same size along dim 0.
            zero: Optional boolean mask over the batch of entries to force
                to zero.
            keep: Optional mask/index over the batch of entries that must
                never be dropped.

        Returns:
            The single masked tensor when one tensor was given, otherwise a
            tuple of masked tensors (the raw ``args`` tuple in eval mode).
        """
        single_input = len(args) == 1
        if not self.training:
            return args[0] if single_input else args

        reference = args[0]
        batch = reference.size(0)

        # How many entries to drop: the number of successes among `batch`
        # independent draws with probability p.
        num_drop = (torch.rand(batch) < self.p).sum()

        # Entries eligible for random dropping: neither protected by `keep`
        # nor already claimed by `zero`.
        eligible = reference.new_ones(batch, dtype=torch.bool)
        for excluded in (keep, zero):
            if excluded is not None:
                eligible[excluded] = False

        # Pick up to `num_drop` eligible entries uniformly at random.
        pool = torch.where(eligible)[0]
        dropped = pool[torch.randperm(len(pool))[:num_drop]]
        if zero is not None:
            forced = torch.where(zero)[0]
            dropped = torch.cat([dropped, forced], dim=0)

        # 1 for surviving entries, 0 for dropped ones.
        multiplier = reference.new_ones(batch)
        multiplier[dropped] = 0.0

        masked = tuple(t * self.broadcast(multiplier, t) for t in args)
        return masked[0] if single_input else masked

    def broadcast(self, src, dst):
        """View ``src`` as ``(B, 1, ..., 1)`` so it broadcasts against ``dst``."""
        assert src.size(0) == dst.size(0)
        trailing_ones = (1, ) * (dst.ndim - 1)
        return src.view((dst.size(0), ) + trailing_ones)
class ImageContext(nn.Module):
    """Expand one embedding vector into ``token_num`` context tokens.

    A two-layer MLP maps ``context_dim -> width -> token_num * context_dim``;
    the result goes through ``DropPath(0.5)`` and is split into ``token_num``
    tokens of ``context_dim`` channels each.

    Args:
        width: Hidden width of the MLP.
        context_dim: Channel width of the input vector and of each output token.
        token_num: Number of tokens produced per input vector.
    """

    def __init__(self, width=1024, context_dim=768, token_num=1):
        super().__init__()
        self.width, self.token_num, self.context_dim = width, token_num, context_dim

        mlp_layers = [
            nn.Linear(context_dim, width),
            nn.SiLU(),
            nn.Linear(width, token_num * context_dim),
        ]
        self.fc = nn.Sequential(*mlp_layers)
        self.drop_path = DropPath(0.5)

    def forward(self, x):
        """Map ``x`` of shape ``(B, context_dim)`` to ``(B, token_num, context_dim)``."""
        flat_tokens = self.fc(x)
        flat_tokens = self.drop_path(flat_tokens)
        # Split the flat channel axis into (token, channel).
        return rearrange(flat_tokens, 'b (n c) -> b n c', n=self.token_num)
class StyleAdapterDualAttnAS(nn.Module):
    """Adapter that owns an image-context model and per-layer attention processors.

    The image-context model and the scale predictor are both built with
    ``instantiate_from_config``. ``create_cross_attention_adapter`` builds one
    ``DualCrossAttnProcessorAS`` per attention processor of a UNet and
    registers them on this module as ``kv_attn_layers``;
    ``set_cross_attention_adapter`` re-installs those registered processors
    on a UNet.

    Args:
        image_context_config: Config passed to ``instantiate_from_config`` to
            build ``image_context_model``.
        scale_predictor_config: Config passed to ``instantiate_from_config``
            to build ``scale_predictor``.
        scale: Forwarded to every ``DualCrossAttnProcessorAS`` as ``scale``.
        use_norm: Forwarded to every ``DualCrossAttnProcessorAS`` as ``use_norm``.
        time_embed_dim: Stored on the instance; not used by this class itself.
        mid_dim: Stored on the instance; not used by this class itself.
    """

    def __init__(self, image_context_config, scale_predictor_config, scale=1.0, use_norm=False, time_embed_dim=1024, mid_dim=32):
        super().__init__()
        self.image_context_model = instantiate_from_config(image_context_config)
        self.scale_predictor = instantiate_from_config(scale_predictor_config)
        self.scale = scale
        self.use_norm = use_norm
        self.time_embed_dim = time_embed_dim
        self.mid_dim = mid_dim

    def create_cross_attention_adapter(self, unet):
        """Build a dual cross-attention processor for every UNet attention layer.

        For each key returned by ``register_attn_processor(unet)`` the
        matching ``to_k`` / ``to_v`` weights are read from the UNet state
        dict and handed to a new ``DualCrossAttnProcessorAS``. The processors
        are installed on ``unet`` and registered on this module as the
        ``kv_attn_layers`` ``nn.ModuleDict``.

        Args:
            unet: The UNet whose attention processors are replaced.
        """
        ori_processor = register_attn_processor(unet)

        # state_dict() rebuilds the whole mapping on every call, so take one
        # snapshot instead of two calls per attention layer. Nothing in the
        # loop modifies the UNet, so the snapshot stays valid throughout.
        unet_state = unet.state_dict()

        dual_attn_processor = {}
        for idx, key in enumerate(ori_processor.keys()):
            # Drop the last 10 characters of the processor key to get the
            # attention-module prefix (presumably a '.processor' suffix --
            # TODO confirm against register_attn_processor).
            attn_prefix = key[:-10]
            kv_state_dicts = {
                'k': {'weight': unet_state[attn_prefix + '.to_k.weight']},
                'v': {'weight': unet_state[attn_prefix + '.to_v.weight']},
            }
            # The to_k weight is read as (inner_dim, context_dim).
            k_weight = kv_state_dicts['k']['weight']
            context_dim = k_weight.shape[1]
            inner_dim = k_weight.shape[0]
            print(key, context_dim, inner_dim)
            dual_attn_processor[key] = DualCrossAttnProcessorAS(
                context_dim=context_dim,
                inner_dim=inner_dim,
                state_dict=kv_state_dicts,
                scale=self.scale,
                use_norm=self.use_norm,
                layer_idx=idx,
            )
        set_attn_processor(unet, dual_attn_processor)

        # nn.ModuleDict keys may not contain '.', so store under '_'-joined names.
        dual_attn_processor = {key.replace('.', '_'): value for key, value in dual_attn_processor.items()}
        self.add_module('kv_attn_layers', nn.ModuleDict(dual_attn_processor))

    def set_cross_attention_adapter(self, unet):
        """Install the processors stored in ``kv_attn_layers`` on ``unet``.

        Requires ``kv_attn_layers`` to exist already (it is created by
        ``create_cross_attention_adapter``). Every key of
        ``get_attn_processor(unet)`` must have a counterpart in
        ``kv_attn_layers`` under the same name with '.' replaced by '_'.

        Args:
            unet: The UNet that receives the stored processors.
        """
        dual_attn_processor = get_attn_processor(unet)
        # Only existing keys are reassigned, so the mapping does not change
        # size while it is being iterated.
        for key in dual_attn_processor.keys():
            module_key = key.replace('.', '_')
            dual_attn_processor[key] = self.kv_attn_layers[module_key]
            print('set', key, module_key)
        set_attn_processor(unet, dual_attn_processor)

    def forward(self, x):
        """Run ``image_context_model`` on ``x`` and return its output.

        Args:
            x: Input of shape ``[B, C]`` (per the original author's note).
        """
        return self.image_context_model(x)