Switti-1024 / models /__init__.py
realantonvoronov
init commit
55ca09f
raw
history blame
2.42 kB
import torch.nn as nn
from .clip import FrozenCLIPEmbedder
from .switti import Switti
from .vqvae import VQVAE
from .pipeline import SwittiPipeline
def build_models(
# Shared args
device,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
# VQVAE args
V=4096,
Cvae=32,
ch=160,
share_quant_resi=4,
# Switti args
depth=16,
rope=True,
rope_theta=10000,
rope_size=128,
use_swiglu_ffn=True,
use_ar=False,
use_crop_cond=True,
attn_l2_norm=True,
init_adaln=0.5,
init_adaln_gamma=1e-5,
init_head=0.02,
init_std=-1, # init_std < 0: automated
drop_rate=0.0,
attn_drop_rate=0.0,
dpr=0,
norm_eps=1e-6,
# pipeline args
text_encoder_path="openai/clip-vit-large-patch14",
text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
) -> tuple[VQVAE, Switti]:
heads = depth
width = depth * 64
if dpr > 0:
dpr = dpr * depth / 24
# disable built-in initialization for speed
for clz in (
nn.Linear,
nn.LayerNorm,
nn.BatchNorm2d,
nn.SyncBatchNorm,
nn.Conv1d,
nn.Conv2d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
):
setattr(clz, "reset_parameters", lambda self: None)
# build models
vae_local = VQVAE(
vocab_size=V,
z_channels=Cvae,
ch=ch,
test_mode=True,
share_quant_resi=share_quant_resi,
v_patch_nums=patch_nums,
).to(device)
switti_wo_ddp = Switti(
depth=depth,
embed_dim=width,
num_heads=heads,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr,
norm_eps=norm_eps,
attn_l2_norm=attn_l2_norm,
patch_nums=patch_nums,
rope=rope,
rope_theta=rope_theta,
rope_size=rope_size,
use_swiglu_ffn=use_swiglu_ffn,
use_ar=use_ar,
use_crop_cond=use_crop_cond,
).to(device)
switti_wo_ddp.init_weights(
init_adaln=init_adaln,
init_adaln_gamma=init_adaln_gamma,
init_head=init_head,
init_std=init_std,
)
text_encoder = FrozenCLIPEmbedder(text_encoder_path)
text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
pipe = SwittiPipeline(switti_wo_ddp, vae_local, text_encoder, text_encoder_2, device)
return vae_local, switti_wo_ddp, pipe