szyezhu's picture
Upload 46 files
019d164
from .unet import UNetModel
# Label count handed to UNetModel as ``num_classes`` when create_model is
# called with class_cond=True (otherwise None is passed instead).
NUM_CLASSES: int = 1_000
# Keyword arguments for create_model(); i_DDPM() uses this preset for the
# 'AFHQ' and 'FFHQ' dataset names.
AFHQ_DICT = {
    # Comma-separated resolutions; create_model turns each one into
    # image_size // resolution before building the model.
    "attention_resolutions": "16",
    "class_cond": False,
    "dropout": 0.0,
    "image_size": 256,
    "learn_sigma": True,
    "num_channels": 128,
    "num_head_channels": 64,
    "num_res_blocks": 1,
    "resblock_updown": True,
    "use_fp16": False,
    "use_scale_shift_norm": True,
    "num_heads": 4,
    "num_heads_upsample": -1,
    # Empty string: create_model picks the multipliers from image_size.
    "channel_mult": "",
    "use_checkpoint": False,
    "use_new_attention_order": False,
}
# Keyword arguments for create_model(); i_DDPM() uses this preset for the
# 'IMAGENET' dataset name.
IMAGENET_DICT = {
    # Comma-separated resolutions; create_model turns each one into
    # image_size // resolution before building the model.
    "attention_resolutions": "32,16,8",
    # True makes create_model pass NUM_CLASSES as the model's num_classes.
    "class_cond": True,
    "image_size": 512,
    "learn_sigma": True,
    "num_channels": 256,
    "num_head_channels": 64,
    "num_res_blocks": 2,
    "resblock_updown": True,
    "use_fp16": False,
    "use_scale_shift_norm": True,
    "dropout": 0.0,
    "num_heads": 4,
    "num_heads_upsample": -1,
    # Empty string: create_model picks the multipliers from image_size.
    "channel_mult": "",
    "use_checkpoint": False,
    "use_new_attention_order": False,
}
def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    """Construct a ``UNetModel`` from flat, string-friendly hyperparameters.

    Args:
        image_size: Input resolution. Selects the default channel
            multipliers and is the numerator of each attention rate.
        num_channels: Forwarded to the model as ``model_channels``.
        num_res_blocks: Forwarded unchanged.
        channel_mult: Comma-separated integer multipliers, or ``""`` to use
            the built-in table for image sizes 512, 256, 128 and 64.
        learn_sigma: When true the model gets 6 output channels, else 3.
        class_cond: When true ``num_classes`` is ``NUM_CLASSES``, else
            ``None``.
        attention_resolutions: Comma-separated integers; each ``res`` is
            converted to ``image_size // res`` before being forwarded.
        use_checkpoint, num_heads, num_head_channels, num_heads_upsample,
        use_scale_shift_norm, dropout, resblock_updown, use_fp16,
        use_new_attention_order: Forwarded to ``UNetModel`` unchanged.

    Returns:
        The ``UNetModel`` instance (always built with ``in_channels=3``).

    Raises:
        ValueError: If ``channel_mult`` is ``""`` and ``image_size`` has no
            entry in the default table, or if an entry of ``channel_mult``
            or ``attention_resolutions`` is not an integer literal.
    """
    if channel_mult != "":
        # Explicit spec: "1,2,4" -> (1, 2, 4).
        channel_mult = tuple(int(piece) for piece in channel_mult.split(","))
    else:
        # No spec given: fall back to the per-resolution defaults.
        default_mults = (
            (512, (0.5, 1, 1, 2, 2, 4, 4)),
            (256, (1, 1, 2, 2, 4, 4)),
            (128, (1, 1, 2, 3, 4)),
            (64, (1, 2, 3, 4)),
        )
        for known_size, mults in default_mults:
            if image_size == known_size:
                channel_mult = mults
                break
        else:
            raise ValueError(f"unsupported image size: {image_size}")

    # Each requested resolution becomes the integer ratio image_size // res.
    attention_rates = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )

    # Twice the channels when learn_sigma is set.
    out_channels = 6 if learn_sigma else 3

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_rates,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )
def i_DDPM(dataset_name='AFHQ'):
    """Build the model configured for a named dataset preset.

    Args:
        dataset_name: ``'AFHQ'`` or ``'FFHQ'`` (both use ``AFHQ_DICT``), or
            ``'IMAGENET'`` (uses ``IMAGENET_DICT``). Matching is
            case-sensitive.

    Returns:
        The result of ``create_model`` called with the selected preset.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the presets.
    """
    if dataset_name in ('AFHQ', 'FFHQ'):
        return create_model(**AFHQ_DICT)
    if dataset_name == 'IMAGENET':
        return create_model(**IMAGENET_DICT)
    # Raise rather than print + exit(): the site-provided exit() ends the
    # whole process with a success status and may be absent (python -S).
    # An exception reports the failure and lets the caller handle it.
    raise NotImplementedError(
        f"Not implemented: dataset_name={dataset_name!r} "
        "(expected 'AFHQ', 'FFHQ' or 'IMAGENET')."
    )