import torch
from transformers import PretrainedConfig


class STDiT2Config(PretrainedConfig):
    model_type = "stdit2"

    def __init__(
        self,
        input_size=(None, None, None),      # latent (T, H, W); None means resolved at runtime
        input_sq_size=32,                   # reference square size used to scale positional embeddings
        in_channels=4,                      # VAE latent channels
        patch_size=(1, 2, 2),               # (temporal, height, width) patch size
        hidden_size=1152,                   # transformer embedding dimension
        depth=28,                           # number of transformer blocks
        num_heads=16,                       # attention heads per block
        mlp_ratio=4.0,                      # MLP hidden dim = mlp_ratio * hidden_size
        class_dropout_prob=0.1,             # caption dropout for classifier-free guidance
        pred_sigma=True,                    # also predict the diffusion variance
        drop_path=0.0,                      # stochastic depth rate
        no_temporal_pos_emb=False,          # disable the temporal positional embedding
        caption_channels=4096,              # caption (text-encoder) embedding channels
        model_max_length=120,               # maximum caption length in tokens
        freeze=None,                        # optional parameter-freezing mode
        qk_norm=False,                      # normalize queries/keys before attention
        enable_flash_attn=False,            # use FlashAttention kernels
        enable_layernorm_kernel=False,      # use a fused LayerNorm kernel
        enable_sequence_parallelism=False,  # shard the sequence dimension across devices
        **kwargs,
    ):
        self.input_size = input_size
        self.input_sq_size = input_sq_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length
        self.freeze = freeze
        self.qk_norm = qk_norm
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism
        super().__init__(**kwargs)
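

# Minimal usage sketch (not part of the original file): build the config with a few
# of the defaults shown above, inspect a field, and serialize it via the standard
# transformers PretrainedConfig API. The output directory name is illustrative.
if __name__ == "__main__":
    config = STDiT2Config(depth=28, hidden_size=1152, num_heads=16)
    print(config.model_type)   # "stdit2"
    print(config.hidden_size)  # 1152
    config.save_pretrained("stdit2_config")  # writes config.json to this directory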