diff --git a/README.md b/README.md
index f8076d4f2d4be907354555d3469fa97bd98697d2..5833f05cd1f7d45e4fb34dca078577a12b783221 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
---
title: Prompt-Free Diffusion
emoji: 👀
-colorFrom: orange
+colorFrom: red
colorTo: blue
sdk: gradio
sdk_version: 3.32.0
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..94f62be11142199b82bb1c3d96cbae7ab1c6be50
--- /dev/null
+++ b/app.py
@@ -0,0 +1,494 @@
+################################################################################
+# Copyright (C) 2023 Xingqian Xu - All Rights Reserved #
+# #
+# Please visit Prompt-Free-Diffusion's arXiv paper for more details, link at #
+# arxiv.org/abs/2305.16223 #
+# #
+################################################################################
+
+import gradio as gr
+import os.path as osp
+from PIL import Image
+import numpy as np
+import time
+
+import torch
+import torchvision.transforms as tvtrans
+from lib.cfg_helper import model_cfg_bank
+from lib.model_zoo import get_model
+
+from collections import OrderedDict
+from lib.model_zoo.ddim import DDIMSampler
+
+n_sample_image = 1
+
+controlnet_path = OrderedDict([
+ ['canny' , ('canny' , 'pretrained/controlnet/control_sd15_canny_slimmed.safetensors')],
+ ['canny_v11p' , ('canny' , 'pretrained/controlnet/control_v11p_sd15_canny_slimmed.safetensors')],
+ ['depth' , ('depth' , 'pretrained/controlnet/control_sd15_depth_slimmed.safetensors')],
+ ['hed' , ('hed' , 'pretrained/controlnet/control_sd15_hed_slimmed.safetensors')],
+ ['mlsd' , ('mlsd' , 'pretrained/controlnet/control_sd15_mlsd_slimmed.safetensors')],
+ ['mlsd_v11p' , ('mlsd' , 'pretrained/controlnet/control_v11p_sd15_mlsd_slimmed.safetensors')],
+ ['normal' , ('normal' , 'pretrained/controlnet/control_sd15_normal_slimmed.safetensors')],
+ ['openpose' , ('openpose', 'pretrained/controlnet/control_sd15_openpose_slimmed.safetensors')],
+ ['openpose_v11p' , ('openpose', 'pretrained/controlnet/control_v11p_sd15_openpose_slimmed.safetensors')],
+ ['scribble' , ('scribble', 'pretrained/controlnet/control_sd15_scribble_slimmed.safetensors')],
+ ['softedge_v11p' , ('scribble', 'pretrained/controlnet/control_v11p_sd15_softedge_slimmed.safetensors')],
+ ['seg' , ('none' , 'pretrained/controlnet/control_sd15_seg_slimmed.safetensors')],
+ ['lineart_v11p' , ('none' , 'pretrained/controlnet/control_v11p_sd15_lineart_slimmed.safetensors')],
+ ['lineart_anime_v11p', ('none' , 'pretrained/controlnet/control_v11p_sd15s2_lineart_anime_slimmed.safetensors')],
+])
+
+preprocess_method = [
+ 'canny' ,
+ 'depth' ,
+ 'hed' ,
+ 'mlsd' ,
+ 'normal' ,
+ 'openpose' ,
+ 'openpose_withface' ,
+ 'openpose_withfacehand',
+ 'scribble' ,
+ 'none' ,
+]
+
+diffuser_path = OrderedDict([
+ ['SD-v1.5' , 'pretrained/pfd/diffuser/SD-v1-5.safetensors'],
+ ['OpenJouney-v4' , 'pretrained/pfd/diffuser/OpenJouney-v4.safetensors'],
+ ['Deliberate-v2.0' , 'pretrained/pfd/diffuser/Deliberate-v2-0.safetensors'],
+ ['RealisticVision-v2.0', 'pretrained/pfd/diffuser/RealisticVision-v2-0.safetensors'],
+ ['Anything-v4' , 'pretrained/pfd/diffuser/Anything-v4.safetensors'],
+ ['Oam-v3' , 'pretrained/pfd/diffuser/AbyssOrangeMix-v3.safetensors'],
+ ['Oam-v2' , 'pretrained/pfd/diffuser/AbyssOrangeMix-v2.safetensors'],
+])
+
+ctxencoder_path = OrderedDict([
+ ['SeeCoder' , 'pretrained/pfd/seecoder/seecoder-v1-0.safetensors'],
+ ['SeeCoder-PA' , 'pretrained/pfd/seecoder/seecoder-pa-v1-0.safetensors'],
+ ['SeeCoder-Anime', 'pretrained/pfd/seecoder/seecoder-anime-v1-0.safetensors'],
+])
+
+##########
+# helper #
+##########
+
+def highlight_print(info):
+ print('')
+ print(''.join(['#']*(len(info)+4)))
+ print('# '+info+' #')
+ print(''.join(['#']*(len(info)+4)))
+ print('')
+
+def load_sd_from_file(target):
+ if osp.splitext(target)[-1] == '.ckpt':
+ sd = torch.load(target, map_location='cpu')['state_dict']
+ elif osp.splitext(target)[-1] == '.pth':
+ sd = torch.load(target, map_location='cpu')
+ elif osp.splitext(target)[-1] == '.safetensors':
+ from safetensors.torch import load_file as stload
+ sd = OrderedDict(stload(target, device='cpu'))
+ else:
+ assert False, "File type must be .ckpt or .pth or .safetensors"
+ return sd
+
+########
+# main #
+########
+
+class prompt_free_diffusion(object):
+ def __init__(self,
+ fp16=False,
+ tag_ctx=None,
+ tag_diffuser=None,
+ tag_ctl=None,):
+
+ self.tag_ctx = tag_ctx
+ self.tag_diffuser = tag_diffuser
+ self.tag_ctl = tag_ctl
+ self.strict_sd = True
+
+ cfgm = model_cfg_bank()('pfd_seecoder_with_controlnet')
+ self.net = get_model()(cfgm)
+
+ self.action_load_ctx(tag_ctx)
+ self.action_load_diffuser(tag_diffuser)
+ self.action_load_ctl(tag_ctl)
+
+ if fp16:
+ highlight_print('Running in FP16')
+ self.net.ctx['image'].fp16 = True
+ self.net = self.net.half()
+ self.dtype = torch.float16
+ else:
+ self.dtype = torch.float32
+
+ self.use_cuda = torch.cuda.is_available()
+ if self.use_cuda:
+ self.net.to('cuda')
+
+ self.net.eval()
+ self.sampler = DDIMSampler(self.net)
+
+ self.n_sample_image = n_sample_image
+ self.ddim_steps = 50
+ self.ddim_eta = 0.0
+ self.image_latent_dim = 4
+
+ def load_ctx(self, pretrained):
+ sd = load_sd_from_file(pretrained)
+ sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items() \
+ if ki.find('ctx.')!=0]
+ sd.update(OrderedDict(sd_extra))
+
+ self.net.load_state_dict(sd, strict=True)
+ print('Load context encoder from [{}] strict [{}].'.format(pretrained, True))
+
+ def load_diffuser(self, pretrained):
+ sd = load_sd_from_file(pretrained)
+ if len([ki for ki in sd.keys() if ki.find('diffuser.image.context_blocks.')==0]) == 0:
+ sd = [(
+ ki.replace('diffuser.text.context_blocks.', 'diffuser.image.context_blocks.'), vi)
+ for ki, vi in sd.items()]
+ sd = OrderedDict(sd)
+ sd_extra = [(ki, vi) for ki, vi in self.net.state_dict().items() \
+ if ki.find('diffuser.')!=0]
+ sd.update(OrderedDict(sd_extra))
+ self.net.load_state_dict(sd, strict=True)
+ print('Load diffuser from [{}] strict [{}].'.format(pretrained, True))
+
+ def load_ctl(self, pretrained):
+ sd = load_sd_from_file(pretrained)
+ self.net.ctl.load_state_dict(sd, strict=True)
+ print('Load controlnet from [{}] strict [{}].'.format(pretrained, True))
+
+ def action_load_ctx(self, tag):
+ pretrained = ctxencoder_path[tag]
+ if tag == 'SeeCoder-PA':
+ from lib.model_zoo.seecoder import PPE_MLP
+ pe_layer = \
+ PPE_MLP(freq_num=20, freq_max=None, out_channel=768, mlp_layer=3)
+ if self.dtype == torch.float16:
+ pe_layer = pe_layer.half()
+ if self.use_cuda:
+ pe_layer.to('cuda')
+ pe_layer.eval()
+ self.net.ctx['image'].qtransformer.pe_layer = pe_layer
+ else:
+ self.net.ctx['image'].qtransformer.pe_layer = None
+ if pretrained is not None:
+ self.load_ctx(pretrained)
+ self.tag_ctx = tag
+ return tag
+
+ def action_load_diffuser(self, tag):
+ pretrained = diffuser_path[tag]
+ if pretrained is not None:
+ self.load_diffuser(pretrained)
+ self.tag_diffuser = tag
+ return tag
+
+ def action_load_ctl(self, tag):
+ pretrained = controlnet_path[tag][1]
+ if pretrained is not None:
+ self.load_ctl(pretrained)
+ self.tag_ctl = tag
+ return tag
+
+ def action_autoset_hw(self, imctl):
+ if imctl is None:
+ return 512, 512
+ w, h = imctl.size
+ w = w//64 * 64
+ h = h//64 * 64
+ w = w if w >=512 else 512
+ w = w if w <=1536 else 1536
+ h = h if h >=512 else 512
+ h = h if h <=1536 else 1536
+ return h, w
+
+ def action_autoset_method(self, tag):
+ return controlnet_path[tag][0]
+
+ def action_inference(
+ self, im, imctl, ctl_method, do_preprocess,
+ h, w, ugscale, seed,
+ tag_ctx, tag_diffuser, tag_ctl,):
+
+ if tag_ctx != self.tag_ctx:
+ self.action_load_ctx(tag_ctx)
+ if tag_diffuser != self.tag_diffuser:
+ self.action_load_diffuser(tag_diffuser)
+ if tag_ctl != self.tag_ctl:
+ self.action_load_ctl(tag_ctl)
+
+ n_samples = self.n_sample_image
+
+ sampler = self.sampler
+ device = self.net.device
+
+ w = w//64 * 64
+ h = h//64 * 64
+ if imctl is not None:
+ imctl = imctl.resize([w, h], Image.Resampling.BICUBIC)
+
+ craw = tvtrans.ToTensor()(im)[None].to(device).to(self.dtype)
+ c = self.net.ctx_encode(craw, which='image').repeat(n_samples, 1, 1)
+ u = torch.zeros_like(c)
+
+ if tag_ctx in ["SeeCoder-Anime"]:
+ u = torch.load('assets/anime_ug.pth')[None].to(device).to(self.dtype)
+ pad = c.size(1) - u.size(1)
+ u = torch.cat([u, torch.zeros_like(u[:, 0:1].repeat(1, pad, 1))], axis=1)
+
+ if tag_ctl != 'none':
+ ccraw = tvtrans.ToTensor()(imctl)[None].to(device).to(self.dtype)
+ if do_preprocess:
+ cc = self.net.ctl.preprocess(ccraw, type=ctl_method, size=[h, w])
+ cc = cc.to(self.dtype)
+ else:
+ cc = ccraw
+ else:
+ cc = None
+
+ shape = [n_samples, self.image_latent_dim, h//8, w//8]
+
+ if seed < 0:
+ np.random.seed(int(time.time()))
+ torch.manual_seed(-seed + 100)
+ else:
+ np.random.seed(seed + 100)
+ torch.manual_seed(seed)
+
+ x, _ = sampler.sample(
+ steps=self.ddim_steps,
+ x_info={'type':'image',},
+ c_info={'type':'image', 'conditioning':c, 'unconditional_conditioning':u,
+ 'unconditional_guidance_scale':ugscale,
+ 'control':cc,},
+ shape=shape,
+ verbose=False,
+ eta=self.ddim_eta)
+
+ ccout = [tvtrans.ToPILImage()(i) for i in cc] if cc is not None else []
+ imout = self.net.vae_decode(x, which='image')
+ imout = [tvtrans.ToPILImage()(i) for i in imout]
+ return imout + ccout
+
+pfd_inference = prompt_free_diffusion(
+ fp16=True, tag_ctx = 'SeeCoder', tag_diffuser = 'Deliberate-v2.0', tag_ctl = 'canny',)
+
+#################
+# sub interface #
+#################
+
+cache_examples = True
+
+def get_example():
+ case = [
+ [
+ 'assets/examples/ghibli-input.jpg',
+ 'assets/examples/ghibli-canny.png',
+ 'canny', False,
+ 768, 1024, 1.8, 23,
+ 'SeeCoder', 'Deliberate-v2.0', 'canny', ],
+ [
+ 'assets/examples/astronautridinghouse-input.jpg',
+ 'assets/examples/astronautridinghouse-canny.png',
+ 'canny', False,
+ 512, 768, 2.0, 21,
+ 'SeeCoder', 'Deliberate-v2.0', 'canny', ],
+ [
+ 'assets/examples/grassland-input.jpg',
+ 'assets/examples/grassland-scribble.png',
+ 'scribble', False,
+ 768, 512, 2.0, 41,
+ 'SeeCoder', 'Deliberate-v2.0', 'scribble', ],
+ [
+ 'assets/examples/jeep-input.jpg',
+ 'assets/examples/jeep-depth.png',
+ 'depth', False,
+ 512, 768, 2.0, 30,
+ 'SeeCoder', 'Deliberate-v2.0', 'depth', ],
+ [
+ 'assets/examples/bedroom-input.jpg',
+ 'assets/examples/bedroom-mlsd.png',
+ 'mlsd', False,
+ 512, 512, 2.0, 31,
+ 'SeeCoder', 'Deliberate-v2.0', 'mlsd', ],
+ [
+ 'assets/examples/nightstreet-input.jpg',
+ 'assets/examples/nightstreet-canny.png',
+ 'canny', False,
+ 768, 512, 2.3, 20,
+ 'SeeCoder', 'Deliberate-v2.0', 'canny', ],
+ [
+ 'assets/examples/woodcar-input.jpg',
+ 'assets/examples/woodcar-depth.png',
+ 'depth', False,
+ 768, 512, 2.0, 20,
+ 'SeeCoder', 'Deliberate-v2.0', 'depth', ],
+ [
+ 'assets/examples-anime/miku.jpg',
+ 'assets/examples-anime/miku-canny.png',
+ 'canny', False,
+ 768, 576, 1.5, 22,
+ 'SeeCoder-Anime', 'Anything-v4', 'canny', ],
+ [
+ 'assets/examples-anime/random0.jpg',
+ 'assets/examples-anime/pose.png',
+ 'openpose', False,
+ 768, 1536, 2.0, 41,
+ 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ],
+ [
+ 'assets/examples-anime/random1.jpg',
+ 'assets/examples-anime/pose.png',
+ 'openpose', False,
+ 768, 1536, 2.5, 28,
+ 'SeeCoder-Anime', 'Oam-v2', 'openpose_v11p', ],
+ [
+ 'assets/examples-anime/camping.jpg',
+ 'assets/examples-anime/pose.png',
+ 'openpose', False,
+ 768, 1536, 2.0, 35,
+ 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ],
+ [
+ 'assets/examples-anime/hanfu_girl.jpg',
+ 'assets/examples-anime/pose.png',
+ 'openpose', False,
+ 768, 1536, 2.0, 20,
+ 'SeeCoder-Anime', 'Anything-v4', 'openpose_v11p', ],
+ ]
+ return case
+
+def interface():
+ with gr.Row():
+ with gr.Column():
+ img_input = gr.Image(label='Image Input', type='pil', elem_id='customized_imbox')
+ with gr.Row():
+ out_width = gr.Slider(label="Width" , minimum=512, maximum=1536, value=512, step=64, visible=True)
+ out_height = gr.Slider(label="Height", minimum=512, maximum=1536, value=512, step=64, visible=True)
+ with gr.Row():
+ scl_lvl = gr.Slider(label="CFGScale", minimum=0, maximum=10, value=2, step=0.01, visible=True)
+ seed = gr.Number(20, label="Seed", precision=0)
+ with gr.Row():
+ tag_ctx = gr.Dropdown(label='Context Encoder', choices=[pi for pi in ctxencoder_path.keys()], value='SeeCoder')
+ tag_diffuser = gr.Dropdown(label='Diffuser', choices=[pi for pi in diffuser_path.keys()], value='Deliberate-v2.0')
+ button = gr.Button("Run")
+ with gr.Column():
+ ctl_input = gr.Image(label='Control Input', type='pil', elem_id='customized_imbox')
+ do_preprocess = gr.Checkbox(label='Preprocess', value=False)
+ with gr.Row():
+ ctl_method = gr.Dropdown(label='Preprocess Type', choices=preprocess_method, value='canny')
+ tag_ctl = gr.Dropdown(label='ControlNet', choices=[pi for pi in controlnet_path.keys()], value='canny')
+ with gr.Column():
+ img_output = gr.Gallery(label="Image Result", elem_id='customized_imbox').style(grid=n_sample_image+1)
+
+ tag_ctl.change(
+ pfd_inference.action_autoset_method,
+ inputs = [tag_ctl],
+ outputs = [ctl_method],)
+
+ ctl_input.change(
+ pfd_inference.action_autoset_hw,
+ inputs = [ctl_input],
+ outputs = [out_height, out_width],)
+
+ # tag_ctx.change(
+ # pfd_inference.action_load_ctx,
+ # inputs = [tag_ctx],
+ # outputs = [tag_ctx],)
+
+ # tag_diffuser.change(
+ # pfd_inference.action_load_diffuser,
+ # inputs = [tag_diffuser],
+ # outputs = [tag_diffuser],)
+
+ # tag_ctl.change(
+ # pfd_inference.action_load_ctl,
+ # inputs = [tag_ctl],
+ # outputs = [tag_ctl],)
+
+ button.click(
+ pfd_inference.action_inference,
+ inputs=[img_input, ctl_input, ctl_method, do_preprocess,
+ out_height, out_width, scl_lvl, seed,
+ tag_ctx, tag_diffuser, tag_ctl, ],
+ outputs=[img_output])
+
+ gr.Examples(
+ label='Examples',
+ examples=get_example(),
+ fn=pfd_inference.action_inference,
+ inputs=[img_input, ctl_input, ctl_method, do_preprocess,
+ out_height, out_width, scl_lvl, seed,
+ tag_ctx, tag_diffuser, tag_ctl, ],
+ outputs=[img_output],
+ cache_examples=cache_examples,)
+
+#############
+# Interface #
+#############
+
+css = """
+ #customized_imbox {
+ min-height: 450px;
+ }
+ #customized_imbox>div[data-testid="image"] {
+ min-height: 450px;
+ }
+ #customized_imbox>div[data-testid="image"]>div {
+ min-height: 450px;
+ }
+ #customized_imbox>div[data-testid="image"]>iframe {
+ min-height: 450px;
+ }
+ #customized_imbox>div.unpadded_box {
+ min-height: 450px;
+ }
+ #myinst {
+ font-size: 0.8rem;
+ margin: 0rem;
+ color: #6B7280;
+ }
+ #maskinst {
+ text-align: justify;
+ min-width: 1200px;
+ }
+ #maskinst>img {
+ min-width:399px;
+ max-width:450px;
+ vertical-align: top;
+ display: inline-block;
+ }
+ #maskinst:after {
+ content: "";
+ width: 100%;
+ display: inline-block;
+ }
+"""
+
+if True:
+ with gr.Blocks(css=css) as demo:
+ gr.HTML(
+ """
+
+
+ Prompt-Free Diffusion
+
+
+ """)
+
+ interface()
+
+ # gr.HTML(
+ # """
+ #
+ #
+ # Version: {}
+ #
+ #
+ # """.format(' '+str(pfd_inference.pretrained)))
+
+ # demo.launch(server_name="0.0.0.0", server_port=7992)
+ # demo.launch()
+ demo.launch(debug=True)
diff --git a/configs/model/autokl.yaml b/configs/model/autokl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..40eb51cd75a3cf17bbb32117eff6c4e287621024
--- /dev/null
+++ b/configs/model/autokl.yaml
@@ -0,0 +1,26 @@
+autokl:
+ symbol: autokl
+ find_unused_parameters: false
+
+autokl_v1:
+ super_cfg: autokl
+ type: autoencoderkl
+ args:
+ embed_dim: 4
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig: null
+ pth: pretrained/kl-f8.pth
+
+autokl_v2:
+ super_cfg: autokl_v1
+ pth: pretrained/pfd/vae/sd-v2-0-base-autokl.pth
diff --git a/configs/model/clip.yaml b/configs/model/clip.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..76010081b942631a8ab23b055f71e33d562777af
--- /dev/null
+++ b/configs/model/clip.yaml
@@ -0,0 +1,12 @@
+################
+# clip for sd1 #
+################
+
+clip:
+ symbol: clip
+ args: {}
+
+clip_text_context_encoder_sdv1:
+ super_cfg: clip
+ type: clip_text_context_encoder_sdv1
+ args: {}
diff --git a/configs/model/controlnet.yaml b/configs/model/controlnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..35d6aa85e58c8f8da764479729327cae2d8ac038
--- /dev/null
+++ b/configs/model/controlnet.yaml
@@ -0,0 +1,18 @@
+controlnet:
+ symbol: controlnet
+ type: controlnet
+ find_unused_parameters: false
+ args:
+ image_size: 32 # unused
+ in_channels: 4
+ hint_channels: 3
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
diff --git a/configs/model/openai_unet.yaml b/configs/model/openai_unet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..69e22b594b84cb66d8c574fdfe7f00daf6e81d4b
--- /dev/null
+++ b/configs/model/openai_unet.yaml
@@ -0,0 +1,35 @@
+openai_unet_sd:
+ type: openai_unet
+ args:
+ image_size: null # no use
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: [ 2, 2, 2, 2 ]
+ channel_mult: [ 1, 2, 4, 4 ]
+ # disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+#########
+# v1 2d #
+#########
+
+openai_unet_2d_v1:
+ type: openai_unet_2d_next
+ args:
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: [ 2, 2, 2, 2 ]
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ context_dim: 768
+ use_checkpoint: False
+ parts: [global, data, context]
diff --git a/configs/model/pfd.yaml b/configs/model/pfd.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1db7009b048505f9da3ef280dbec6abc0e0c2123
--- /dev/null
+++ b/configs/model/pfd.yaml
@@ -0,0 +1,33 @@
+pfd_base:
+ symbol: pfd
+ find_unused_parameters: true
+ type: pfd
+ args:
+ beta_linear_start: 0.00085
+ beta_linear_end: 0.012
+ timesteps: 1000
+ use_ema: false
+
+pfd_seecoder:
+ super_cfg: pfd_base
+ args:
+ vae_cfg_list:
+ - [image, MODEL(autokl_v2)]
+ ctx_cfg_list:
+ - [image, MODEL(seecoder)]
+ diffuser_cfg_list:
+ - [image, MODEL(openai_unet_2d_v1)]
+ latent_scale_factor:
+ image: 0.18215
+
+pdf_seecoder_pa:
+ super_cfg: pfd_seecoder
+ args:
+ ctx_cfg_list:
+ - [image, MODEL(seecoder_pa)]
+
+pfd_seecoder_with_controlnet:
+ super_cfg: pfd_seecoder
+ type: pfd_with_control
+ args:
+ ctl_cfg: MODEL(controlnet)
diff --git a/configs/model/seecoder.yaml b/configs/model/seecoder.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1b80834e5860257d5544aad4af93cb8aebae656c
--- /dev/null
+++ b/configs/model/seecoder.yaml
@@ -0,0 +1,62 @@
+seecoder_base:
+ symbol: seecoder
+ args: {}
+
+seecoder:
+ super_cfg: seecoder_base
+ type: seecoder
+ args:
+ imencoder_cfg : MODEL(swin_large)
+ imdecoder_cfg : MODEL(seecoder_decoder)
+ qtransformer_cfg : MODEL(seecoder_query_transformer)
+
+seecoder_pa:
+ super_cfg: seet
+ type: seecoder
+ args:
+ imencoder_cfg : MODEL(swin_large)
+ imdecoder_cfg : MODEL(seecoder_decoder)
+ qtransformer_cfg : MODEL(seecoder_query_transformer_position_aware)
+
+###########
+# decoder #
+###########
+
+seecoder_decoder:
+ super_cfg: seecoder_base
+ type: seecoder_decoder
+ args:
+ inchannels:
+ res3: 384
+ res4: 768
+ res5: 1536
+ trans_input_tags: ['res3', 'res4', 'res5']
+ trans_dim: 768
+ trans_dropout: 0.1
+ trans_nheads: 8
+ trans_feedforward_dim: 1024
+ trans_num_layers: 6
+
+#####################
+# query_transformer #
+#####################
+
+seecoder_query_transformer:
+ super_cfg: seecoder_base
+ type: seecoder_query_transformer
+ args:
+ in_channels : 768
+ hidden_dim: 768
+ num_queries: [4, 144]
+ nheads: 8
+ num_layers: 9
+ feedforward_dim: 2048
+ pre_norm: False
+ num_feature_levels: 3
+ enforce_input_project: False
+ with_fea2d_pos: false
+
+seecoder_query_transformer_position_aware:
+ super_cfg: seecoder_query_transformer
+ args:
+ with_fea2d_pos: true
diff --git a/configs/model/swin.yaml b/configs/model/swin.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..23cc1c08215bdc28dae688d9510a3befa49c3604
--- /dev/null
+++ b/configs/model/swin.yaml
@@ -0,0 +1,32 @@
+swin:
+ symbol: swin
+ args: {}
+
+swin_base:
+ super_cfg: swin
+ type: swin
+ args:
+ embed_dim: 128
+ depths: [ 2, 2, 18, 2 ]
+ num_heads: [ 4, 8, 16, 32 ]
+ window_size: 7
+ ape: False
+ drop_path_rate: 0.3
+ patch_norm: True
+ pretrained: pretrained/swin/swin_base_patch4_window7_224_22k.pth
+ strict_sd: False
+
+swin_large:
+ super_cfg: swin
+ type: swin
+ args:
+ embed_dim: 192
+ depths: [ 2, 2, 18, 2 ]
+ num_heads: [ 6, 12, 24, 48 ]
+ window_size: 12
+ ape: False
+ drop_path_rate: 0.3
+ patch_norm: True
+ pretrained: pretrained/swin/swin_large_patch4_window12_384_22k.pth
+ strict_sd: False
+
diff --git a/lib/__init__.py b/lib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/__pycache__/__init__.cpython-310.pyc b/lib/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..211502528ca75f215d29a990d3a4dd993b31c388
Binary files /dev/null and b/lib/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lib/__pycache__/cfg_helper.cpython-310.pyc b/lib/__pycache__/cfg_helper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf6fa4df3843273e68efe3221534eaddb7c653c8
Binary files /dev/null and b/lib/__pycache__/cfg_helper.cpython-310.pyc differ
diff --git a/lib/__pycache__/cfg_holder.cpython-310.pyc b/lib/__pycache__/cfg_holder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5385322a21a5e653bf0e3845c31c18c6188da9ed
Binary files /dev/null and b/lib/__pycache__/cfg_holder.cpython-310.pyc differ
diff --git a/lib/__pycache__/log_service.cpython-310.pyc b/lib/__pycache__/log_service.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..714a94f852f21eed2a1279953dd32b8808b5b3b9
Binary files /dev/null and b/lib/__pycache__/log_service.cpython-310.pyc differ
diff --git a/lib/__pycache__/sync.cpython-310.pyc b/lib/__pycache__/sync.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3680bdef8838aa0c349bc3375a5a329efee73f6a
Binary files /dev/null and b/lib/__pycache__/sync.cpython-310.pyc differ
diff --git a/lib/cfg_helper.py b/lib/cfg_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e549e35d5be238a8e73eb65ff8625dc2838ab230
--- /dev/null
+++ b/lib/cfg_helper.py
@@ -0,0 +1,666 @@
+import os
+import os.path as osp
+import shutil
+import copy
+import time
+import pprint
+import numpy as np
+import torch
+import matplotlib
+import argparse
+import json
+import yaml
+from easydict import EasyDict as edict
+
+from .model_zoo import get_model
+
+############
+# cfg_bank #
+############
+
+def cfg_solvef(cmd, root):
+ if not isinstance(cmd, str):
+ return cmd
+
+ if cmd.find('SAME')==0:
+ zoom = root
+ p = cmd[len('SAME'):].strip('()').split('.')
+ p = [pi.strip() for pi in p]
+ for pi in p:
+ try:
+ pi = int(pi)
+ except:
+ pass
+
+ try:
+ zoom = zoom[pi]
+ except:
+ return cmd
+ return cfg_solvef(zoom, root)
+
+ if cmd.find('SEARCH')==0:
+ zoom = root
+ p = cmd[len('SEARCH'):].strip('()').split('.')
+ p = [pi.strip() for pi in p]
+ find = True
+ # Depth first search
+ for pi in p:
+ try:
+ pi = int(pi)
+ except:
+ pass
+
+ try:
+ zoom = zoom[pi]
+ except:
+ find = False
+ break
+
+ if find:
+ return cfg_solvef(zoom, root)
+ else:
+ if isinstance(root, dict):
+ for ri in root:
+ rv = cfg_solvef(cmd, root[ri])
+ if rv != cmd:
+ return rv
+ if isinstance(root, list):
+ for ri in root:
+ rv = cfg_solvef(cmd, ri)
+ if rv != cmd:
+ return rv
+ return cmd
+
+ if cmd.find('MODEL')==0:
+ goto = cmd[len('MODEL'):].strip('()')
+ return model_cfg_bank()(goto)
+
+ if cmd.find('DATASET')==0:
+ goto = cmd[len('DATASET'):].strip('()')
+ return dataset_cfg_bank()(goto)
+
+ return cmd
+
+def cfg_solve(cfg, cfg_root):
+ # The function solve cfg element such that
+ # all sorrogate input are settled.
+ # (i.e. SAME(***) )
+ if isinstance(cfg, list):
+ for i in range(len(cfg)):
+ if isinstance(cfg[i], (list, dict)):
+ cfg[i] = cfg_solve(cfg[i], cfg_root)
+ else:
+ cfg[i] = cfg_solvef(cfg[i], cfg_root)
+ if isinstance(cfg, dict):
+ for k in cfg:
+ if isinstance(cfg[k], (list, dict)):
+ cfg[k] = cfg_solve(cfg[k], cfg_root)
+ else:
+ cfg[k] = cfg_solvef(cfg[k], cfg_root)
+ return cfg
+
+class model_cfg_bank(object):
+ def __init__(self):
+ self.cfg_dir = osp.join('configs', 'model')
+ self.cfg_bank = edict()
+
+ def __call__(self, name):
+ if name not in self.cfg_bank:
+ cfg_path = self.get_yaml_path(name)
+ with open(cfg_path, 'r') as f:
+ cfg_new = yaml.load(
+ f, Loader=yaml.FullLoader)
+ cfg_new = edict(cfg_new)
+ self.cfg_bank.update(cfg_new)
+
+ cfg = self.cfg_bank[name]
+ cfg.name = name
+ if 'super_cfg' not in cfg:
+ cfg = cfg_solve(cfg, cfg)
+ self.cfg_bank[name] = cfg
+ return copy.deepcopy(cfg)
+
+ super_cfg = self.__call__(cfg.super_cfg)
+ # unlike other field,
+ # args will not be replaced but update.
+ if 'args' in cfg:
+ if 'args' in super_cfg:
+ super_cfg.args.update(cfg.args)
+ else:
+ super_cfg.args = cfg.args
+ cfg.pop('args')
+
+ super_cfg.update(cfg)
+ super_cfg.pop('super_cfg')
+ cfg = super_cfg
+ try:
+ delete_args = cfg.pop('delete_args')
+ except:
+ delete_args = []
+
+ for dargs in delete_args:
+ cfg.args.pop(dargs)
+
+ cfg = cfg_solve(cfg, cfg)
+ self.cfg_bank[name] = cfg
+ return copy.deepcopy(cfg)
+
+ def get_yaml_path(self, name):
+ if name.find('openai_unet')==0:
+ return osp.join(
+ self.cfg_dir, 'openai_unet.yaml')
+ elif name.find('clip')==0:
+ return osp.join(
+ self.cfg_dir, 'clip.yaml')
+ elif name.find('autokl')==0:
+ return osp.join(
+ self.cfg_dir, 'autokl.yaml')
+ elif name.find('controlnet')==0:
+ return osp.join(
+ self.cfg_dir, 'controlnet.yaml')
+ elif name.find('swin')==0:
+ return osp.join(
+ self.cfg_dir, 'swin.yaml')
+ elif name.find('pfd')==0:
+ return osp.join(
+ self.cfg_dir, 'pfd.yaml')
+ elif name.find('seecoder')==0:
+ return osp.join(
+ self.cfg_dir, 'seecoder.yaml')
+ else:
+ raise ValueError
+
+class dataset_cfg_bank(object):
+ def __init__(self):
+ self.cfg_dir = osp.join('configs', 'dataset')
+ self.cfg_bank = edict()
+
+ def __call__(self, name):
+ if name not in self.cfg_bank:
+ cfg_path = self.get_yaml_path(name)
+ with open(cfg_path, 'r') as f:
+ cfg_new = yaml.load(
+ f, Loader=yaml.FullLoader)
+ cfg_new = edict(cfg_new)
+ self.cfg_bank.update(cfg_new)
+
+ cfg = self.cfg_bank[name]
+ cfg.name = name
+ if cfg.get('super_cfg', None) is None:
+ cfg = cfg_solve(cfg, cfg)
+ self.cfg_bank[name] = cfg
+ return copy.deepcopy(cfg)
+
+ super_cfg = self.__call__(cfg.super_cfg)
+ super_cfg.update(cfg)
+ cfg = super_cfg
+ cfg.super_cfg = None
+ try:
+ delete = cfg.pop('delete')
+ except:
+ delete = []
+
+ for dargs in delete:
+ cfg.pop(dargs)
+
+ cfg = cfg_solve(cfg, cfg)
+ self.cfg_bank[name] = cfg
+ return copy.deepcopy(cfg)
+
+ def get_yaml_path(self, name):
+ if name.find('cityscapes')==0:
+ return osp.join(
+ self.cfg_dir, 'cityscapes.yaml')
+ elif name.find('div2k')==0:
+ return osp.join(
+ self.cfg_dir, 'div2k.yaml')
+ elif name.find('gandiv2k')==0:
+ return osp.join(
+ self.cfg_dir, 'gandiv2k.yaml')
+ elif name.find('srbenchmark')==0:
+ return osp.join(
+ self.cfg_dir, 'srbenchmark.yaml')
+ elif name.find('imagedir')==0:
+ return osp.join(
+ self.cfg_dir, 'imagedir.yaml')
+ elif name.find('places2')==0:
+ return osp.join(
+ self.cfg_dir, 'places2.yaml')
+ elif name.find('ffhq')==0:
+ return osp.join(
+ self.cfg_dir, 'ffhq.yaml')
+ elif name.find('imcpt')==0:
+ return osp.join(
+ self.cfg_dir, 'imcpt.yaml')
+ elif name.find('texture')==0:
+ return osp.join(
+ self.cfg_dir, 'texture.yaml')
+ elif name.find('openimages')==0:
+ return osp.join(
+ self.cfg_dir, 'openimages.yaml')
+ elif name.find('laion2b')==0:
+ return osp.join(
+ self.cfg_dir, 'laion2b.yaml')
+ elif name.find('laionart')==0:
+ return osp.join(
+ self.cfg_dir, 'laionart.yaml')
+ elif name.find('celeba')==0:
+ return osp.join(
+ self.cfg_dir, 'celeba.yaml')
+ elif name.find('coyo')==0:
+ return osp.join(
+ self.cfg_dir, 'coyo.yaml')
+ elif name.find('pafc')==0:
+ return osp.join(
+ self.cfg_dir, 'pafc.yaml')
+ elif name.find('coco')==0:
+ return osp.join(
+ self.cfg_dir, 'coco.yaml')
+ elif name.find('genai')==0:
+ return osp.join(
+ self.cfg_dir, 'genai.yaml')
+ else:
+ raise ValueError
+
+class experiment_cfg_bank(object):
+ def __init__(self):
+ self.cfg_dir = osp.join('configs', 'experiment')
+ self.cfg_bank = edict()
+
+ def __call__(self, name):
+ if name not in self.cfg_bank:
+ cfg_path = self.get_yaml_path(name)
+ with open(cfg_path, 'r') as f:
+ cfg = yaml.load(
+ f, Loader=yaml.FullLoader)
+ cfg = edict(cfg)
+
+ cfg = cfg_solve(cfg, cfg)
+ cfg = cfg_solve(cfg, cfg)
+ # twice for SEARCH
+ self.cfg_bank[name] = cfg
+ return copy.deepcopy(cfg)
+
+ def get_yaml_path(self, name):
+ return osp.join(
+ self.cfg_dir, name+'.yaml')
+
+def load_cfg_yaml(path):
+ if osp.isfile(path):
+ cfg_path = path
+ elif osp.isfile(osp.join('configs', 'experiment', path)):
+ cfg_path = osp.join('configs', 'experiment', path)
+ elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
+ cfg_path = osp.join('configs', 'experiment', path+'.yaml')
+ else:
+ assert False, 'No such config!'
+
+ with open(cfg_path, 'r') as f:
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
+ cfg = edict(cfg)
+ cfg = cfg_solve(cfg, cfg)
+ cfg = cfg_solve(cfg, cfg)
+ return cfg
+
+##############
+# cfg_helper #
+##############
+
+def get_experiment_id(ref=None):
+ if ref is None:
+ time.sleep(0.5)
+ return int(time.time()*100)
+ else:
+ try:
+ return int(ref)
+ except:
+ pass
+
+ _, ref = osp.split(ref)
+ ref = ref.split('_')[0]
+ try:
+ return int(ref)
+ except:
+ assert False, 'Invalid experiment ID!'
+
+def record_resume_cfg(path):
+ cnt = 0
+ while True:
+ if osp.exists(path+'.{:04d}'.format(cnt)):
+ cnt += 1
+ continue
+ shutil.copyfile(path, path+'.{:04d}'.format(cnt))
+ break
+
+def get_command_line_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--debug', action='store_true', default=False)
+ parser.add_argument('--config', type=str)
+ parser.add_argument('--gpu', nargs='+', type=int)
+
+ parser.add_argument('--node_rank', type=int)
+ parser.add_argument('--node_list', nargs='+', type=str)
+ parser.add_argument('--nodes', type=int)
+ parser.add_argument('--addr', type=str, default='127.0.0.1')
+ parser.add_argument('--port', type=int, default=11233)
+
+ parser.add_argument('--signature', nargs='+', type=str)
+ parser.add_argument('--seed', type=int)
+
+ parser.add_argument('--eval', type=str)
+ parser.add_argument('--eval_subdir', type=str)
+ parser.add_argument('--pretrained', type=str)
+
+ parser.add_argument('--resume_dir', type=str)
+ parser.add_argument('--resume_step', type=int)
+ parser.add_argument('--resume_weight', type=str)
+
+ args = parser.parse_args()
+
+ # Special handling the resume
+ if args.resume_dir is not None:
+ cfg = edict()
+ cfg.env = edict()
+ cfg.env.debug = args.debug
+ cfg.env.resume = edict()
+ cfg.env.resume.dir = args.resume_dir
+ cfg.env.resume.step = args.resume_step
+ cfg.env.resume.weight = args.resume_weight
+ return cfg
+
+ cfg = load_cfg_yaml(args.config)
+ cfg.env.debug = args.debug
+ cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
+ cfg.env.master_addr = args.addr
+ cfg.env.master_port = args.port
+ cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
+
+ if args.node_list is None:
+ cfg.env.node_rank = 0 if args.node_rank is None else args.node_rank
+ cfg.env.nodes = 1 if args.nodes is None else args.nodes
+ else:
+ import socket
+ hostname = socket.gethostname()
+ assert cfg.env.master_addr == args.node_list[0]
+ cfg.env.node_rank = args.node_list.index(hostname)
+ cfg.env.nodes = len(args.node_list)
+ cfg.env.node_list = args.node_list
+
+ istrain = False if args.eval is not None else True
+ isdebug = cfg.env.debug
+
+ if istrain:
+ if isdebug:
+ cfg.env.experiment_id = 999999999999
+ cfg.train.signature = ['debug']
+ else:
+ cfg.env.experiment_id = get_experiment_id()
+ if args.signature is not None:
+ cfg.train.signature = args.signature
+ else:
+ if 'train' in cfg:
+ cfg.pop('train')
+ cfg.env.experiment_id = get_experiment_id(args.eval)
+ if args.signature is not None:
+ cfg.eval.signature = args.signature
+
+ if isdebug and (args.eval is None):
+ cfg.env.experiment_id = 999999999999
+ cfg.eval.signature = ['debug']
+
+ if args.eval_subdir is not None:
+ if isdebug:
+ cfg.eval.eval_subdir = 'debug'
+ else:
+ cfg.eval.eval_subdir = args.eval_subdir
+ if args.pretrained is not None:
+ cfg.eval.pretrained = args.pretrained
+ # The override pretrained over the setting in cfg.model
+
+ if args.seed is not None:
+ cfg.env.rnd_seed = args.seed
+
+ return cfg
+
+def cfg_initiates(cfg):
+ cfge = cfg.env
+ isdebug = cfge.debug
+ isresume = 'resume' in cfge
+ istrain = 'train' in cfg
+ haseval = 'eval' in cfg
+ cfgt = cfg.train if istrain else None
+ cfgv = cfg.eval if haseval else None
+
+ ###############################
+ # get some environment params #
+ ###############################
+
+ cfge.computer = os.uname()
+ cfge.torch_version = str(torch.__version__)
+
+ ##########
+ # resume #
+ ##########
+
+ if isresume:
+ resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
+ record_resume_cfg(resume_cfg_path)
+ with open(resume_cfg_path, 'r') as f:
+ cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
+ cfg_resume = edict(cfg_resume)
+ cfg_resume.env.update(cfge)
+ cfg = cfg_resume
+ cfge = cfg.env
+ log_file = cfg.train.log_file
+
+ print('')
+ print('##########')
+ print('# resume #')
+ print('##########')
+ print('')
+ with open(log_file, 'a') as f:
+ print('', file=f)
+ print('##########', file=f)
+ print('# resume #', file=f)
+ print('##########', file=f)
+ print('', file=f)
+
+ pprint.pprint(cfg)
+ with open(log_file, 'a') as f:
+ pprint.pprint(cfg, f)
+
+ ####################
+ # node distributed #
+ ####################
+
+ if cfg.env.master_addr!='127.0.0.1':
+ os.environ['MASTER_ADDR'] = cfge.master_addr
+ os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
+ if cfg.env.dist_backend=='nccl':
+ os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
+ if cfg.env.dist_backend=='gloo':
+ os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
+
+ #######################
+ # cuda visible device #
+ #######################
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
+ [str(gid) for gid in cfge.gpu_device])
+
+ #####################
+ # return resume cfg #
+ #####################
+
+ if isresume:
+ return cfg
+
+ #############################################
+ # some misc setting that not need in resume #
+ #############################################
+
+ cfgm = cfg.model
+ cfge.gpu_count = len(cfge.gpu_device)
+
+ ##########################################
+ # align batch size and num worker config #
+ ##########################################
+
+ gpu_n = cfge.gpu_count * cfge.nodes
+ def align_batch_size(bs, bs_per_gpu):
+ assert (bs is not None) or (bs_per_gpu is not None)
+ bs = bs_per_gpu * gpu_n if bs is None else bs
+ bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
+ assert (bs == bs_per_gpu * gpu_n)
+ return bs, bs_per_gpu
+
+ if istrain:
+ cfgt.batch_size, cfgt.batch_size_per_gpu = \
+ align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
+ cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
+ align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
+ if haseval:
+ cfgv.batch_size, cfgv.batch_size_per_gpu = \
+ align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
+ cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
+ align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
+
+ ##################
+ # create log dir #
+ ##################
+
+ if istrain:
+ if not isdebug:
+ sig = cfgt.get('signature', [])
+ sig = sig + ['s{}'.format(cfge.rnd_seed)]
+ else:
+ sig = ['debug']
+
+ log_dir = [
+ cfge.log_root_dir,
+ '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
+ '_'.join([str(cfge.experiment_id)] + sig)
+ ]
+ log_dir = osp.join(*log_dir)
+ log_file = osp.join(log_dir, 'train.log')
+ if not osp.exists(log_file):
+ os.makedirs(osp.dirname(log_file))
+ cfgt.log_dir = log_dir
+ cfgt.log_file = log_file
+
+ if haseval:
+ cfgv.log_dir = log_dir
+ cfgv.log_file = log_file
+ else:
+ model_symbol = cfgm.symbol
+ if cfgv.get('dataset', None) is None:
+ dataset_symbol = 'nodataset'
+ else:
+ dataset_symbol = cfgv.dataset.symbol
+
+ log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
+ exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
+ if exp_dir is None:
+ if not isdebug:
+ sig = cfgv.get('signature', []) + ['evalonly']
+ else:
+ sig = ['debug']
+ exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
+
+ eval_subdir = cfgv.get('eval_subdir', None)
+ # override subdir in debug mode (if eval_subdir is set)
+ eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
+
+ if eval_subdir is not None:
+ log_dir = osp.join(log_dir, exp_dir, eval_subdir)
+ else:
+ log_dir = osp.join(log_dir, exp_dir)
+
+ disable_log_override = cfgv.get('disable_log_override', False)
+ if osp.isdir(log_dir):
+ if disable_log_override:
+ assert False, 'Override an exsited log_dir is disabled at [{}]'.format(log_dir)
+ else:
+ os.makedirs(log_dir)
+
+ log_file = osp.join(log_dir, 'eval.log')
+ cfgv.log_dir = log_dir
+ cfgv.log_file = log_file
+
+ ######################
+ # print and save cfg #
+ ######################
+
+ pprint.pprint(cfg)
+ if cfge.node_rank==0:
+ with open(log_file, 'w') as f:
+ pprint.pprint(cfg, f)
+ with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
+ yaml.dump(edict_2_dict(cfg), f)
+ else:
+ with open(osp.join(log_dir, 'config.yaml.{}'.format(cfge.node_rank)), 'w') as f:
+ yaml.dump(edict_2_dict(cfg), f)
+
+ #############
+ # save code #
+ #############
+
+ save_code = False
+ if istrain:
+ save_code = cfgt.get('save_code', False)
+ elif haseval:
+ save_code = cfgv.get('save_code', False)
+ save_code = save_code and (cfge.node_rank==0)
+
+ if save_code:
+ codedir = osp.join(log_dir, 'code')
+ if osp.exists(codedir):
+ shutil.rmtree(codedir)
+ for d in ['configs', 'lib']:
+ fromcodedir = d
+ tocodedir = osp.join(codedir, d)
+ shutil.copytree(
+ fromcodedir, tocodedir,
+ ignore=shutil.ignore_patterns(
+ '*__pycache__*', '*build*'))
+ for codei in os.listdir('.'):
+ if osp.splitext(codei)[1] == 'py':
+ shutil.copy(codei, codedir)
+
+ #######################
+ # set matplotlib mode #
+ #######################
+
+ if 'matplotlib_mode' in cfge:
+ try:
+ matplotlib.use(cfge.matplotlib_mode)
+ except:
+ print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
+
+ return cfg
+
+def edict_2_dict(x):
+ if isinstance(x, dict):
+ xnew = {}
+ for k in x:
+ xnew[k] = edict_2_dict(x[k])
+ return xnew
+ elif isinstance(x, list):
+ xnew = []
+ for i in range(len(x)):
+ xnew.append( edict_2_dict(x[i]) )
+ return xnew
+ else:
+ return x
+
+def search_experiment_folder(root, exid):
+ target = None
+ for fi in os.listdir(root):
+ if not osp.isdir(osp.join(root, fi)):
+ continue
+ if int(fi.split('_')[0]) == exid:
+ if target is not None:
+ return None # duplicated
+ elif target is None:
+ target = fi
+ return target
diff --git a/lib/cfg_holder.py b/lib/cfg_holder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5cf16c4116931aef32a7275a63965a0d5f23ec7
--- /dev/null
+++ b/lib/cfg_holder.py
@@ -0,0 +1,28 @@
+import copy
+
+def singleton(class_):
+ instances = {}
+ def getinstance(*args, **kwargs):
+ if class_ not in instances:
+ instances[class_] = class_(*args, **kwargs)
+ return instances[class_]
+ return getinstance
+
+##############
+# cfg_holder #
+##############
+
+@singleton
+class cfg_unique_holder(object):
+ def __init__(self):
+ self.cfg = None
+ # this is use to track the main codes.
+ self.code = set()
+ def save_cfg(self, cfg):
+ self.cfg = copy.deepcopy(cfg)
+ def add_code(self, code):
+ """
+ A new main code is reached and
+ its name is added.
+ """
+ self.code.add(code)
diff --git a/lib/log_service.py b/lib/log_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53c5bd006fe64370a6215fbbe985fdc53f838f5
--- /dev/null
+++ b/lib/log_service.py
@@ -0,0 +1,165 @@
+import timeit
+import numpy as np
+import os
+import os.path as osp
+import shutil
+import copy
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from .cfg_holder import cfg_unique_holder as cfguh
+from . import sync
+
+def print_log(*console_info):
+ grank, lrank, _ = sync.get_rank('all')
+ if lrank!=0:
+ return
+
+ console_info = [str(i) for i in console_info]
+ console_info = ' '.join(console_info)
+ print(console_info)
+
+ if grank!=0:
+ return
+
+ log_file = None
+ try:
+ log_file = cfguh().cfg.train.log_file
+ except:
+ try:
+ log_file = cfguh().cfg.eval.log_file
+ except:
+ return
+ if log_file is not None:
+ with open(log_file, 'a') as f:
+ f.write(console_info + '\n')
+
+class distributed_log_manager(object):
+ def __init__(self):
+ self.sum = {}
+ self.cnt = {}
+ self.time_check = timeit.default_timer()
+
+ cfgt = cfguh().cfg.train
+ self.ddp = sync.is_ddp()
+ self.grank, self.lrank, _ = sync.get_rank('all')
+ self.gwsize = sync.get_world_size('global')
+
+ use_tensorboard = cfgt.get('log_tensorboard', False) and (self.grank==0)
+
+ self.tb = None
+ if use_tensorboard:
+ import tensorboardX
+ monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
+ self.tb = tensorboardX.SummaryWriter(osp.join(monitoring_dir))
+
+ def accumulate(self, n, **data):
+ if n < 0:
+ raise ValueError
+
+ for itemn, di in data.items():
+ if itemn in self.sum:
+ self.sum[itemn] += di * n
+ self.cnt[itemn] += n
+ else:
+ self.sum[itemn] = di * n
+ self.cnt[itemn] = n
+
+ def get_mean_value_dict(self):
+ value_gather = [
+ self.sum[itemn]/self.cnt[itemn] \
+ for itemn in sorted(self.sum.keys()) ]
+
+ value_gather_tensor = torch.FloatTensor(value_gather).to(self.lrank)
+ if self.ddp:
+ dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
+ value_gather_tensor /= self.gwsize
+
+ mean = {}
+ for idx, itemn in enumerate(sorted(self.sum.keys())):
+ mean[itemn] = value_gather_tensor[idx].item()
+ return mean
+
+ def tensorboard_log(self, step, data, mode='train', **extra):
+ if self.tb is None:
+ return
+ if mode == 'train':
+ self.tb.add_scalar('other/epochn', extra['epochn'], step)
+ if ('lr' in extra) and (extra['lr'] is not None):
+ self.tb.add_scalar('other/lr', extra['lr'], step)
+ for itemn, di in data.items():
+ if itemn.find('loss') == 0:
+ self.tb.add_scalar('loss/'+itemn, di, step)
+ elif itemn == 'Loss':
+ self.tb.add_scalar('Loss', di, step)
+ else:
+ self.tb.add_scalar('other/'+itemn, di, step)
+ elif mode == 'eval':
+ if isinstance(data, dict):
+ for itemn, di in data.items():
+ self.tb.add_scalar('eval/'+itemn, di, step)
+ else:
+ self.tb.add_scalar('eval', data, step)
+ return
+
+ def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
+ console_info = [
+ 'Iter:{}'.format(itern),
+ 'Epoch:{}'.format(epochn),
+ 'Sample:{}'.format(samplen),]
+
+ if lr is not None:
+ console_info += ['LR:{:.4E}'.format(lr)]
+
+ mean = self.get_mean_value_dict()
+
+ tbstep = itern if tbstep is None else tbstep
+ self.tensorboard_log(
+ tbstep, mean, mode='train',
+ itern=itern, epochn=epochn, lr=lr)
+
+ loss = mean.pop('Loss')
+ mean_info = ['Loss:{:.4f}'.format(loss)] + [
+ '{}:{:.4f}'.format(itemn, mean[itemn]) \
+ for itemn in sorted(mean.keys()) \
+ if itemn.find('loss') == 0
+ ]
+ console_info += mean_info
+ console_info.append('Time:{:.2f}s'.format(
+ timeit.default_timer() - self.time_check))
+ return ' , '.join(console_info)
+
+ def clear(self):
+ self.sum = {}
+ self.cnt = {}
+ self.time_check = timeit.default_timer()
+
+ def tensorboard_close(self):
+ if self.tb is not None:
+ self.tb.close()
+
+# ----- also include some small utils -----
+
+def torch_to_numpy(*argv):
+ if len(argv) > 1:
+ data = list(argv)
+ else:
+ data = argv[0]
+
+ if isinstance(data, torch.Tensor):
+ return data.to('cpu').detach().numpy()
+
+ elif isinstance(data, (list, tuple)):
+ out = []
+ for di in data:
+ out.append(torch_to_numpy(di))
+ return out
+
+ elif isinstance(data, dict):
+ out = {}
+ for ni, di in data.items():
+ out[ni] = torch_to_numpy(di)
+ return out
+
+ else:
+ return data
diff --git a/lib/model_zoo/__init__.py b/lib/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c0a57a9e5d66ee79319d7390dedf650ffb05caf
--- /dev/null
+++ b/lib/model_zoo/__init__.py
@@ -0,0 +1,4 @@
+from .common.get_model import get_model
+from .common.get_optimizer import get_optimizer
+from .common.get_scheduler import get_scheduler
+from .common.utils import get_unit
diff --git a/lib/model_zoo/__pycache__/__init__.cpython-310.pyc b/lib/model_zoo/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab5624bce93ef4070463813bfc170d03c3b5a3be
Binary files /dev/null and b/lib/model_zoo/__pycache__/__init__.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/attention.cpython-310.pyc b/lib/model_zoo/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1327f7e1e5c87ec3397f161ec335948cc03bf5d
Binary files /dev/null and b/lib/model_zoo/__pycache__/attention.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/autokl.cpython-310.pyc b/lib/model_zoo/__pycache__/autokl.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51c7f5573f7657590ad14c156630264427d60190
Binary files /dev/null and b/lib/model_zoo/__pycache__/autokl.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/autokl_modules.cpython-310.pyc b/lib/model_zoo/__pycache__/autokl_modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb0f0c590f3da0b0a2e3800a191fe2e4ab85eed0
Binary files /dev/null and b/lib/model_zoo/__pycache__/autokl_modules.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/autokl_utils.cpython-310.pyc b/lib/model_zoo/__pycache__/autokl_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..900214eca03234496bbafa91ce6c9bd1884c2f1e
Binary files /dev/null and b/lib/model_zoo/__pycache__/autokl_utils.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/controlnet.cpython-310.pyc b/lib/model_zoo/__pycache__/controlnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b2cf36508b6b4fe727ee6e69e16a882009abd59
Binary files /dev/null and b/lib/model_zoo/__pycache__/controlnet.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/ddim.cpython-310.pyc b/lib/model_zoo/__pycache__/ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54307e4285dc69c77c92ce68a9cc792e848d1852
Binary files /dev/null and b/lib/model_zoo/__pycache__/ddim.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc b/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc0f188acf6f79a5da30707a8da32096a0fc83ea
Binary files /dev/null and b/lib/model_zoo/__pycache__/diffusion_utils.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/distributions.cpython-310.pyc b/lib/model_zoo/__pycache__/distributions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..571411071c609d7aede73ccc0beff445c65ecb55
Binary files /dev/null and b/lib/model_zoo/__pycache__/distributions.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/ema.cpython-310.pyc b/lib/model_zoo/__pycache__/ema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b99443f39ef6e6b2c370c0923ee5fad938477e1
Binary files /dev/null and b/lib/model_zoo/__pycache__/ema.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc b/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6502ed2ee4493bf0ffd8c48ed7e92c23b47bbc18
Binary files /dev/null and b/lib/model_zoo/__pycache__/openaimodel.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/pfd.cpython-310.pyc b/lib/model_zoo/__pycache__/pfd.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3054ce4d21ce639a8d3680eba95ee7757fe3b873
Binary files /dev/null and b/lib/model_zoo/__pycache__/pfd.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/seecoder.cpython-310.pyc b/lib/model_zoo/__pycache__/seecoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c86d96a83ab033cc11bd64791dbfe123f4d9c65
Binary files /dev/null and b/lib/model_zoo/__pycache__/seecoder.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/seecoder_utils.cpython-310.pyc b/lib/model_zoo/__pycache__/seecoder_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19f8c2b05872b3ca4480d32f1718f769ed303ae5
Binary files /dev/null and b/lib/model_zoo/__pycache__/seecoder_utils.cpython-310.pyc differ
diff --git a/lib/model_zoo/__pycache__/swin.cpython-310.pyc b/lib/model_zoo/__pycache__/swin.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddb75862ecd1b146d9f6b3917438793ea98fba1e
Binary files /dev/null and b/lib/model_zoo/__pycache__/swin.cpython-310.pyc differ
diff --git a/lib/model_zoo/attention.py b/lib/model_zoo/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e9fbb4c5de3dbe086a03140ae07fd9d8d2dee61
--- /dev/null
+++ b/lib/model_zoo/attention.py
@@ -0,0 +1,540 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from .diffusion_utils import checkpoint
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.inner_dim = inner_dim
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+ def forward_next(self, x, context=None, mask=None):
+ assert mask is None, 'not supported yet'
+ x0 = rearrange(x, 'b n c -> n b c')
+ if context is not None:
+ c0 = rearrange(context, 'b n c -> n b c')
+ else:
+ c0 = x0
+ r, _ = F.multi_head_attention_forward(
+ x0, c0, c0,
+ embed_dim_to_check = self.inner_dim,
+ num_heads = self.heads,
+ in_proj_weight = None, in_proj_bias = None,
+ bias_k = None, bias_v = None,
+ add_zero_attn = False, dropout_p = 0,
+ out_proj_weight = self.to_out[0].weight,
+ out_proj_bias = self.to_out[0].bias,
+ use_separate_proj_weight = True,
+ q_proj_weight = self.to_q.weight,
+ k_proj_weight = self.to_k.weight,
+ v_proj_weight = self.to_v.weight,)
+ r = rearrange(r, 'n b c -> b n c')
+ r = self.to_out[1](r) # dropout
+ return r
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+##########################
+# transformer no context #
+##########################
+
+class BasicTransformerBlockNoContext(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
+ super().__init__()
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
+ dropout=dropout, context_dim=None)
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
+ dropout=dropout, context_dim=None)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
+
+ def _forward(self, x):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x)) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+class SpatialTransformerNoContext(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0.,):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlockNoContext(inner_dim, n_heads, d_head, dropout=dropout)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ for block in self.transformer_blocks:
+ x = block(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x = self.proj_out(x)
+ return x + x_in
+
+
+#######################################
+# Spatial Transformer with Two Branch #
+#######################################
+
+class DualSpatialTransformer(nn.Module):
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+
+ # First crossattn
+ self.norm_0 = Normalize(in_channels)
+ self.proj_in_0 = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ self.transformer_blocks_0 = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
+ disable_self_attn=disable_self_attn)
+ for d in range(depth)]
+ )
+ self.proj_out_0 = zero_module(nn.Conv2d(
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+ # Second crossattn
+ self.norm_1 = Normalize(in_channels)
+ self.proj_in_1 = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ self.transformer_blocks_1 = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
+ disable_self_attn=disable_self_attn)
+ for d in range(depth)]
+ )
+ self.proj_out_1 = zero_module(nn.Conv2d(
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+ def forward(self, x, context=None, which=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ if which==0:
+ norm, proj_in, blocks, proj_out = \
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
+ elif which==1:
+ norm, proj_in, blocks, proj_out = \
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
+ else:
+ # assert False, 'DualSpatialTransformer forward with a invalid which branch!'
+ # import numpy.random as npr
+ # rwhich = 0 if npr.rand() < which else 1
+ # context = context[rwhich]
+ # if rwhich==0:
+ # norm, proj_in, blocks, proj_out = \
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
+ # elif rwhich==1:
+ # norm, proj_in, blocks, proj_out = \
+ # self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
+
+ # import numpy.random as npr
+ # rwhich = 0 if npr.rand() < 0.33 else 1
+ # if rwhich==0:
+ # context = context[rwhich]
+ # norm, proj_in, blocks, proj_out = \
+ # self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
+ # else:
+
+ norm, proj_in, blocks, proj_out = \
+ self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
+ x0 = norm(x)
+ x0 = proj_in(x0)
+ x0 = rearrange(x0, 'b c h w -> b (h w) c').contiguous()
+ for block in blocks:
+ x0 = block(x0, context=context[0])
+ x0 = rearrange(x0, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x0 = proj_out(x0)
+
+ norm, proj_in, blocks, proj_out = \
+ self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
+ x1 = norm(x)
+ x1 = proj_in(x1)
+ x1 = rearrange(x1, 'b c h w -> b (h w) c').contiguous()
+ for block in blocks:
+ x1 = block(x1, context=context[1])
+ x1 = rearrange(x1, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x1 = proj_out(x1)
+ return x0*which + x1*(1-which) + x_in
+
+ x = norm(x)
+ x = proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ for block in blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x = proj_out(x)
+ return x + x_in
diff --git a/lib/model_zoo/autokl.py b/lib/model_zoo/autokl.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1825bdf23a18aa98c47550117d23a2f6093d43
--- /dev/null
+++ b/lib/model_zoo/autokl.py
@@ -0,0 +1,166 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from contextlib import contextmanager
+from lib.model_zoo.common.get_model import get_model, register
+
+# from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+from .autokl_modules import Encoder, Decoder
+from .distributions import DiagonalGaussianDistribution
+
+from .autokl_utils import LPIPSWithDiscriminator
+
+@register('autoencoderkl')
+class AutoencoderKL(nn.Module):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,):
+ super().__init__()
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ if lossconfig is not None:
+ self.loss = LPIPSWithDiscriminator(**lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ @torch.no_grad()
+ def encode(self, x, out_posterior=False):
+ return self.encode_trainable(x, out_posterior)
+
+ def encode_trainable(self, x, out_posterior=False):
+ x = x*2-1
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ if out_posterior:
+ return posterior
+ else:
+ return posterior.sample()
+
+ @torch.no_grad()
+ def decode(self, z):
+ dec = self.decode_trainable(z)
+ dec = torch.clamp(dec, 0, 1)
+ return dec
+
+ def decode_trainable(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ dec = (dec+1)/2
+ return dec
+
+ def apply_model(self, input, sample_posterior=True):
+ posterior = self.encode_trainable(input, out_posterior=True)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode_trainable(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def forward(self, x, optimizer_idx, global_step):
+ reconstructions, posterior = self.apply_model(x)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
+ last_layer=self.get_last_layer(), split="train")
+ return aeloss, log_dict_ae
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(x, reconstructions, posterior, optimizer_idx, global_step=global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ return discloss, log_dict_disc
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+@register('autoencoderkl_customnorm')
+class AutoencoderKL_CustomNorm(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
+ self.std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
+
+ def encode_trainable(self, x, out_posterior=False):
+ m = self.mean[None, :, None, None].to(z.device).to(z.dtype)
+ s = self.std[None, :, None, None].to(z.device).to(z.dtype)
+ x = (x-m)/s
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ if out_posterior:
+ return posterior
+ else:
+ return posterior.sample()
+
+ def decode_trainable(self, z):
+ m = self.mean[None, :, None, None].to(z.device).to(z.dtype)
+ s = self.std[None, :, None, None].to(z.device).to(z.dtype)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ dec = (dec+1)/2
+ return dec
diff --git a/lib/model_zoo/autokl_modules.py b/lib/model_zoo/autokl_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..146c5b241feb8b6f46946b29534f6212fab1ad85
--- /dev/null
+++ b/lib/model_zoo/autokl_modules.py
@@ -0,0 +1,835 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+# from .diffusion_utils import instantiate_from_config
+from .attention import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
diff --git a/lib/model_zoo/autokl_utils.py b/lib/model_zoo/autokl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..903fdded9ce0a771090648827322a04c8fccf8f5
--- /dev/null
+++ b/lib/model_zoo/autokl_utils.py
@@ -0,0 +1,400 @@
+import torch
+import torch.nn as nn
+import functools
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True,
+ allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height*width*torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+#################
+# Discriminator #
+#################
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
+
+#########
+# LPIPS #
+#########
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer which does a 1x1 conv """
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = [nn.Dropout(), ] if (use_dropout) else []
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+ self.model = nn.Sequential(*layers)
+
+from collections import namedtuple
+from torchvision import models
+from torchvision.models import VGG16_Weights
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ if pretrained:
+ vgg_pretrained_features = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+def normalize_tensor(x,eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
+ return x/(norm_factor+eps)
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2,3],keepdim=keepdim)
+
+def get_ckpt_path(*args, **kwargs):
+ return 'pretrained/lpips.pth'
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+############
+# The loss #
+############
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
+ return d_loss
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge"):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+
+ log = {"Loss": loss.clone().detach().mean(),
+ "logvar": self.logvar.detach(),
+ "loss_kl": kl_loss.detach().mean(),
+ "loss_nll": nll_loss.detach().mean(),
+ "loss_rec": rec_loss.detach().mean(),
+ "d_weight": d_weight.detach(),
+ "disc_factor": torch.tensor(disc_factor),
+ "loss_g": g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"Loss": d_loss.clone().detach().mean(),
+ "loss_disc": d_loss.clone().detach().mean(),
+ "logits_real": logits_real.detach().mean(),
+ "logits_fake": logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/lib/model_zoo/clip.py b/lib/model_zoo/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..7967b358161bc5477acb018a40f84ae37b3d451f
--- /dev/null
+++ b/lib/model_zoo/clip.py
@@ -0,0 +1,788 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+from lib.model_zoo.common.get_model import register
+
+symbol = 'clip'
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+@register('clip_text_context_encoder_sdv1')
+class CLIPTextContextEncoderSDv1(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True): # clip-vit-base-patch32
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ with torch.no_grad():
+ batch_encoding = self.tokenizer(
+ text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ max_token_n = self.transformer.text_model.embeddings.position_ids.shape[1]
+ positional_ids = torch.arange(max_token_n)[None].to(self.device)
+ outputs = self.transformer(
+ input_ids=tokens,
+ position_ids=positional_ids, )
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+#############################
+# copyed from justin's code #
+#############################
+
+@register('clip_image_context_encoder_justin')
+class CLIPImageContextEncoderJustin(AbstractEncoder):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model='ViT-L/14',
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ from . import clip_justin
+ self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
+ self.device = device
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ # I didn't call this originally, but seems like it was frozen anyway
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def preprocess(self, x):
+ import kornia
+ # Expects inputs in the range -1, 1
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ return self.model.encode_image(self.preprocess(x)).float()
+
+ def encode(self, im):
+ return self(im).unsqueeze(1)
+
+###############
+# for vd next #
+###############
+
+from transformers import CLIPModel
+
+@register('clip_text_context_encoder')
+class CLIPTextContextEncoder(AbstractEncoder):
+ def __init__(self,
+ version="openai/clip-vit-large-patch14",
+ max_length=77,
+ fp16=False, ):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.model = CLIPModel.from_pretrained(version)
+ self.max_length = max_length
+ self.fp16 = fp16
+ self.freeze()
+
+ def get_device(self):
+ # A trick to get device
+ return self.model.text_projection.weight.device
+
+ def freeze(self):
+ self.model = self.model.eval()
+ self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def encode(self, text):
+ batch_encoding = self.tokenizer(
+ text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.get_device())
+ outputs = self.model.text_model(input_ids=tokens)
+ z = self.model.text_projection(outputs.last_hidden_state)
+ z_pooled = self.model.text_projection(outputs.pooler_output)
+ z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
+ return z
+
+from transformers import CLIPProcessor
+
+@register('clip_image_context_encoder')
+class CLIPImageContextEncoder(AbstractEncoder):
+ def __init__(self,
+ version="openai/clip-vit-large-patch14",
+ fp16=False, ):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.processor = CLIPProcessor.from_pretrained(version)
+ self.model = CLIPModel.from_pretrained(version)
+ self.fp16 = fp16
+ self.freeze()
+
+ def get_device(self):
+ # A trick to get device
+ return self.model.text_projection.weight.device
+
+ def freeze(self):
+ self.model = self.model.eval()
+ self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def _encode(self, images):
+ if isinstance(images, torch.Tensor):
+ import torchvision.transforms as tvtrans
+ images = [tvtrans.ToPILImage()(i) for i in images]
+ inputs = self.processor(images=images, return_tensors="pt")
+ pixels = inputs['pixel_values'].half() if self.fp16 else inputs['pixel_values']
+ pixels = pixels.to(self.get_device())
+ outputs = self.model.vision_model(pixel_values=pixels)
+ z = outputs.last_hidden_state
+ z = self.model.vision_model.post_layernorm(z)
+ z = self.model.visual_projection(z)
+ z_pooled = z[:, 0:1]
+ z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
+ return z
+
+ @torch.no_grad()
+ def _encode_wmask(self, images, masks):
+ assert isinstance(masks, torch.Tensor)
+ assert (len(masks.shape)==4) and (masks.shape[1]==1)
+ masks = torch.clamp(masks, 0, 1)
+ masked_images = images*masks
+ masks = masks.float()
+ masks = F.interpolate(masks, [224, 224], mode='bilinear')
+ if masks.sum() == masks.numel():
+ return self._encode(images)
+
+ device = images.device
+ dtype = images.dtype
+ gscale = masks.mean(axis=[1, 2, 3], keepdim=True).flatten(2)
+
+ vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
+ vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
+ mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
+ vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
+ vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
+ vtoken_mask = torch.concat([gscale, vtoken_mask], axis=1)
+
+ import types
+ def customized_embedding_forward(self, pixel_values):
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ embeddings = embeddings*vtoken_mask.to(embeddings.dtype)
+ return embeddings
+
+ old_forward = self.model.vision_model.embeddings.forward
+ self.model.vision_model.embeddings.forward = types.MethodType(
+ customized_embedding_forward, self.model.vision_model.embeddings)
+
+ z = self._encode(images)
+ self.model.vision_model.embeddings.forward = old_forward
+ z = z * vtoken_mask.to(dtype)
+ return z
+
+ # def _encode_wmask(self, images, masks):
+ # assert isinstance(masks, torch.Tensor)
+ # assert (len(masks.shape)==4) and (masks.shape[1]==1)
+ # masks = torch.clamp(masks, 0, 1)
+ # masks = masks.float()
+ # masks = F.interpolate(masks, [224, 224], mode='bilinear')
+ # if masks.sum() == masks.numel():
+ # return self._encode(images)
+
+ # device = images.device
+ # dtype = images.dtype
+
+ # vtoken_kernel_size = self.model.vision_model.embeddings.patch_embedding.kernel_size
+ # vtoken_stride = self.model.vision_model.embeddings.patch_embedding.stride
+ # mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, requires_grad=False).float()
+ # vtoken_mask = torch.nn.functional.conv2d(masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
+ # vtoken_mask = vtoken_mask/np.prod(vtoken_kernel_size)
+
+ # z = self._encode(images)
+ # z[:, 1:, :] = z[:, 1:, :] * vtoken_mask.to(dtype)
+ # z[:, 0, :] = 0
+ # return z
+
+ def encode(self, images, masks=None):
+ if masks is None:
+ return self._encode(images)
+ else:
+ return self._encode_wmask(images, masks)
+
+@register('clip_image_context_encoder_position_agnostic')
+class CLIPImageContextEncoderPA(CLIPImageContextEncoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ import types
+ def customized_embedding_forward(self, pixel_values):
+ batch_size = pixel_values.shape[0]
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ pembeddings = self.position_embedding(self.position_ids)
+ pembeddings = torch.cat([
+ pembeddings[:, 0:1],
+ pembeddings[:, 1: ].mean(dim=1, keepdim=True).repeat(1, 256, 1)], dim=1)
+ embeddings = embeddings + pembeddings
+ return embeddings
+
+ self.model.vision_model.embeddings.forward = types.MethodType(
+ customized_embedding_forward, self.model.vision_model.embeddings)
+
+##############
+# from sd2.0 #
+##############
+
+import open_clip
+import torch.nn.functional as F
+
+@register('openclip_text_context_encoder_sdv2')
+class FrozenOpenCLIPTextEmbedderSDv2(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ #"pooled",
+ "last",
+ "penultimate"
+ ]
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+@register('openclip_text_context_encoder')
+class FrozenOpenCLIPTextEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ def __init__(self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ max_length=77,
+ freeze=True,):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+ self.max_length = max_length
+ self.device = 'cpu'
+ if freeze:
+ self.freeze()
+
+ def to(self, device):
+ self.device = device
+ super().to(device)
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ self.device = self.model.ln_final.weight.device # urgly trick
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.model.transformer(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ x_pool = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection
+ # x_pool_debug = F.normalize(x_pool, dim=-1)
+ x = x @ self.model.text_projection
+ x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+@register('openclip_image_context_encoder')
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ def __init__(self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ freeze=True,):
+ super().__init__()
+ model, _, preprocess = open_clip.create_model_and_transforms(
+ arch, device=torch.device('cpu'), pretrained=version)
+ self.model = model.visual
+ self.device = 'cpu'
+ import torchvision.transforms as tvtrans
+ # we only need resize & normalization
+ preprocess.transforms[0].size = [224, 224] # make it more precise
+ self.preprocess = tvtrans.Compose([
+ preprocess.transforms[0],
+ preprocess.transforms[4],])
+ if freeze:
+ self.freeze()
+
+ def to(self, device):
+ self.device = device
+ super().to(device)
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, image):
+ z = self.preprocess(image)
+ z = self.encode_with_transformer(z)
+ return z
+
+ def encode_with_transformer(self, image):
+ x = self.model.conv1(image)
+ x = x.reshape(x.shape[0], x.shape[1], -1)
+ x = x.permute(0, 2, 1)
+ x = torch.cat([
+ self.model.class_embedding.to(x.dtype)
+ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+ x], dim=1)
+ x = x + self.model.positional_embedding.to(x.dtype)
+ x = self.model.ln_pre(x)
+ x = x.permute(1, 0, 2)
+ x = self.model.transformer(x)
+ x = x.permute(1, 0, 2)
+
+ x = self.model.ln_post(x)
+ if self.model.proj is not None:
+ x = x @ self.model.proj
+
+ x_pool = x[:, 0, :]
+ # x_pool_debug = self.model(image)
+ # x_pooln_debug = F.normalize(x_pool_debug, dim=-1)
+ x = x / x_pool.norm(dim=1, keepdim=True).unsqueeze(1)
+ return x
+
+ def _encode(self, image):
+ return self(image)
+
+ def _encode_wmask(self, images, masks):
+ z = self._encode(images)
+ device = z.device
+ vtoken_kernel_size = self.model.conv1.kernel_size
+ vtoken_stride = self.model.conv1.stride
+ mask_kernal = torch.ones([1, 1, *vtoken_kernel_size], device=device, dtype=z.dtype, requires_grad=False)
+ mask_kernal /= np.prod(vtoken_kernel_size)
+
+ assert isinstance(masks, torch.Tensor)
+ assert (len(masks.shape)==4) and (masks.shape[1]==1)
+ masks = torch.clamp(masks, 0, 1)
+ masks = F.interpolate(masks, [224, 224], mode='bilinear')
+
+ vtoken_mask = torch.nn.functional.conv2d(1-masks, mask_kernal, stride=vtoken_stride).flatten(2).transpose(1, 2)
+ z[:, 1:, :] = z[:, 1:, :] * vtoken_mask
+ z[:, 0, :] = 0
+ return z
+
+ def encode(self, images, masks=None):
+ if masks is None:
+ return self._encode(images)
+ else:
+ return self._encode_wmask(images, masks)
+
+############################
+# def customized tokenizer #
+############################
+
+from open_clip import SimpleTokenizer
+
+@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v1')
+class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV1(FrozenOpenCLIPTextEmbedderSDv2):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ def __init__(self, customized_tokens, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if isinstance(customized_tokens, str):
+ customized_tokens = [customized_tokens]
+ self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
+ self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
+ self.embedding_dim = self.model.ln_final.weight.shape[0]
+ self.customized_token_embedding = nn.Embedding(
+ len(customized_tokens), embedding_dim=self.embedding_dim)
+ nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
+
+ def tokenize(self, texts):
+ if isinstance(texts, str):
+ texts = [texts]
+ sot_token = self.tokenizer.encoder[""]
+ eot_token = self.tokenizer.encoder[""]
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
+ maxn = self.num_regular_tokens
+ regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
+ token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
+ customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
+ return regular_tokens, customized_tokens, token_mask
+
+ def pad_to_length(self, tokens, context_length=77, eot_token=None):
+ result = torch.zeros(len(tokens), context_length, dtype=torch.long)
+ eot_token = self.tokenizer.encoder[""] if eot_token is None else eot_token
+ for i, tokens in enumerate(tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, :len(tokens)] = torch.tensor(tokens)
+ return result
+
+ def forward(self, text):
+ self.device = self.model.ln_final.weight.device # urgly trick
+ regular_tokens, customized_tokens, token_mask = self.tokenize(text)
+ regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
+ customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
+ token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
+ z0 = self.encode_with_transformer(regular_tokens)
+ z1 = self.customized_token_embedding(customized_tokens)
+ token_mask = token_mask[:, :, None].type(z0.dtype)
+ z = z0 * (1-token_mask) + z1 * token_mask
+ return z
+
+@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v2')
+class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2(FrozenOpenCLIPTextEmbedderSDv2):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ def __init__(self, customized_tokens, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if isinstance(customized_tokens, str):
+ customized_tokens = [customized_tokens]
+ self.tokenizer = open_clip.SimpleTokenizer(special_tokens=customized_tokens)
+ self.num_regular_tokens = self.model.token_embedding.weight.shape[0]
+ self.embedding_dim = self.model.token_embedding.weight.shape[1]
+ self.customized_token_embedding = nn.Embedding(
+ len(customized_tokens), embedding_dim=self.embedding_dim)
+ nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
+
+ def tokenize(self, texts):
+ if isinstance(texts, str):
+ texts = [texts]
+ sot_token = self.tokenizer.encoder[""]
+ eot_token = self.tokenizer.encoder[""]
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
+ maxn = self.num_regular_tokens
+ regular_tokens = [[ti if ti < maxn else 0 for ti in tokens] for tokens in all_tokens]
+ token_mask = [[0 if ti < maxn else 1 for ti in tokens] for tokens in all_tokens]
+ customized_tokens = [[ti-maxn if ti >= maxn else 0 for ti in tokens] for tokens in all_tokens]
+ return regular_tokens, customized_tokens, token_mask
+
+ def pad_to_length(self, tokens, context_length=77, eot_token=None):
+ result = torch.zeros(len(tokens), context_length, dtype=torch.long)
+ eot_token = self.tokenizer.encoder[""] if eot_token is None else eot_token
+ for i, tokens in enumerate(tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, :len(tokens)] = torch.tensor(tokens)
+ return result
+
+ def forward(self, text):
+ self.device = self.model.token_embedding.weight.device # urgly trick
+ regular_tokens, customized_tokens, token_mask = self.tokenize(text)
+ regular_tokens = self.pad_to_length(regular_tokens).to(self.device)
+ customized_tokens = self.pad_to_length(customized_tokens, eot_token=0).to(self.device)
+ token_mask = self.pad_to_length(token_mask, eot_token=0).to(self.device)
+ z = self.encode_with_transformer(regular_tokens, customized_tokens, token_mask)
+ return z
+
+ def encode_with_transformer(self, token, customized_token, token_mask):
+ x0 = self.model.token_embedding(token)
+ x1 = self.customized_token_embedding(customized_token)
+ token_mask = token_mask[:, :, None].type(x0.dtype)
+ x = x0 * (1-token_mask) + x1 * token_mask
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+class ln_freezed_temp(nn.LayerNorm):
+ def forward(self, x):
+ self.weight.requires_grad = False
+ self.bias.requires_grad = False
+ return super().forward(x)
+
+@register('openclip_text_context_encoder_sdv2_customized_tokenizer_v3')
+class FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV3(FrozenOpenCLIPEmbedderSDv2CustomizedTokenizerV2):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ def __init__(self, customized_tokens, texpand=4, lora_rank=None, lora_bias_trainable=True, *args, **kwargs):
+ super().__init__(customized_tokens, *args, **kwargs)
+ if isinstance(customized_tokens, str):
+ customized_tokens = [customized_tokens]
+ self.texpand = texpand
+ self.customized_token_embedding = nn.Embedding(
+ len(customized_tokens)*texpand, embedding_dim=self.embedding_dim)
+ nn.init.normal_(self.customized_token_embedding.weight, std=0.02)
+
+ if lora_rank is not None:
+ from .lora import freeze_param, freeze_module, to_lora
+ def convert_resattnblock(module):
+ module.ln_1.__class__ = ln_freezed_temp
+ # freeze_module(module.ln_1)
+ module.attn = to_lora(module.attn, lora_rank, lora_bias_trainable)
+ module.ln_2.__class__ = ln_freezed_temp
+ # freeze_module(module.ln_2)
+ module.mlp.c_fc = to_lora(module.mlp.c_fc, lora_rank, lora_bias_trainable)
+ module.mlp.c_proj = to_lora(module.mlp.c_proj, lora_rank, lora_bias_trainable)
+ freeze_param(self.model, 'positional_embedding')
+ freeze_param(self.model, 'text_projection')
+ freeze_param(self.model, 'logit_scale')
+ for idx, resattnblock in enumerate(self.model.transformer.resblocks):
+ convert_resattnblock(resattnblock)
+ freeze_module(self.model.token_embedding)
+ self.model.ln_final.__class__ = ln_freezed_temp
+ # freeze_module(self.model.ln_final)
+
+ def tokenize(self, texts):
+ if isinstance(texts, str):
+ texts = [texts]
+ sot_token = self.tokenizer.encoder[""]
+ eot_token = self.tokenizer.encoder[""]
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
+ maxn = self.num_regular_tokens
+ regular_tokens = [[[ti] if ti < maxn else [0]*self.texpand for ti in tokens] for tokens in all_tokens]
+ token_mask = [[[ 0] if ti < maxn else [1]*self.texpand for ti in tokens] for tokens in all_tokens]
+ custom_tokens = [[[ 0] if ti < maxn else [
+ (ti-maxn)*self.texpand+ii for ii in range(self.texpand)]
+ for ti in tokens] for tokens in all_tokens]
+
+ from itertools import chain
+ regular_tokens = [[i for i in chain(*tokens)] for tokens in regular_tokens]
+ token_mask = [[i for i in chain(*tokens)] for tokens in token_mask]
+ custom_tokens = [[i for i in chain(*tokens)] for tokens in custom_tokens]
+ return regular_tokens, custom_tokens, token_mask
+
+###################
+# clip expandable #
+###################
+
+@register('clip_text_sdv1_customized_embedding')
+class CLIPTextSD1CE(nn.Module):
+ def __init__(
+ self,
+ replace_info="text|elon musk",
+ version="openai/clip-vit-large-patch14",
+ max_length=77):
+ super().__init__()
+
+ self.name = 'clip_text_sdv1_customized_embedding'
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.reset_replace_info(replace_info)
+ self.max_length = max_length
+ self.special_token = ""
+
+ def reset_replace_info(self, replace_info):
+ rtype, rpara = replace_info.split("|")
+ self.replace_type = rtype
+ if rtype == "token_embedding":
+ ce_num = int(rpara)
+ ce_dim = self.transformer.text_model.embeddings.token_embedding.weight.size(1)
+ self.cembedding = nn.Embedding(ce_num, ce_dim)
+ self.cembedding = self.cembedding.to(self.get_device())
+ elif rtype == "context_embedding":
+ ce_num = int(rpara)
+ ce_dim = self.transformer.text_model.encoder.layers[-1].layer_norm2.weight.size(0)
+ self.cembedding = nn.Embedding(ce_num, ce_dim)
+ self.cembedding = self.cembedding.to(self.get_device())
+ else:
+ assert rtype=="text"
+ self.replace_type = "text"
+ self.replace_string = rpara
+ self.cembedding = None
+
+ def get_device(self):
+ return self.transformer.text_model.embeddings.token_embedding.weight.device
+
+ def position_to_mask(self, tokens, positions):
+ mask = torch.zeros_like(tokens)
+ for idxb, idxs, idxe in zip(*positions):
+ mask[idxb, idxs:idxe] = 1
+ return mask
+
+ def forward(self, text):
+ tokens, positions = self.tokenize(text)
+ mask = self.position_to_mask(tokens, positions)
+ max_token_n = tokens.size(1)
+ positional_ids = torch.arange(max_token_n)[None].to(self.get_device())
+
+ if self.replace_what == 'token_embedding':
+ cembeds = self.cembedding(tokens * mask)
+
+ def embedding_customized_forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None,):
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+ inputs_embeds = inputs_embeds * (1-mask.float())[:, :, None]
+ inputs_embeds = inputs_embeds + cembeds
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+ return embeddings
+
+ import types
+ self.transformer.text_model.embeddings.forward = types.MethodType(
+ embedding_customized_forward, self.transformer.text_model.embeddings)
+
+ else:
+ # TODO: Implement
+ assert False
+
+ outputs = self.transformer(
+ input_ids=tokens,
+ position_ids=positional_ids, )
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+ @torch.no_grad()
+ def tokenize(self, text):
+ if isinstance(text, str):
+ text = [text]
+
+ bos_special_text = "<|startoftext|>"
+ text = [ti.replace(self.special_token, bos_special_text) for ti in text]
+
+ batch_encoding = self.tokenizer(
+ text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"]
+
+ bosid = tokens[0, 0]
+ eosid = tokens[0, -1]
+ bs, maxn = tokens.shape
+
+ if self.replace_what in ['token_embedding', 'context_embedding']:
+ newtokens = []
+ ce_num = self.cembedding.weight.size(0)
+ idxi = []; idxstart = []; idxend = [];
+ for idxii, tokeni in enumerate(tokens):
+ newtokeni = []
+ idxjj = 0
+ for ii, tokenii in enumerate(tokeni):
+ if (tokenii == bosid) and (ii != 0):
+ newtokeni.extend([i for i in range(ce_num)])
+ idxi.append(idxii); idxstart.append(idxjj);
+ idxjj += ce_num
+ idxjj_record = idxjj if idxjj<=maxn-1 else maxn-1
+ idxend.append(idxjj_record);
+ else:
+ newtokeni.extend([tokenii])
+ idxjj += 1
+ newtokeni = newtokeni[:maxn]
+ newtokeni[-1] = eosid
+ newtokens.append(newtokeni)
+ return torch.LongTensor(newtokens).to(self.get_device()), (idxi, idxstart, idxend)
+ else:
+ # TODO: Implement
+ assert False
diff --git a/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc b/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8d6508431d8f56c330f61260ea2a5621a7614db
Binary files /dev/null and b/lib/model_zoo/common/__pycache__/get_model.cpython-310.pyc differ
diff --git a/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc b/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caac8c7bc5a27ba85ac9d9c338affb11dbdc38d6
Binary files /dev/null and b/lib/model_zoo/common/__pycache__/get_optimizer.cpython-310.pyc differ
diff --git a/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc b/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9711a5af1b751c87df46aa0a7a7b8bd0f2d2ffb
Binary files /dev/null and b/lib/model_zoo/common/__pycache__/get_scheduler.cpython-310.pyc differ
diff --git a/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc b/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5eb0f8b0d4c906251a7c02d04663a15f9ec5ac9a
Binary files /dev/null and b/lib/model_zoo/common/__pycache__/utils.cpython-310.pyc differ
diff --git a/lib/model_zoo/common/get_model.py b/lib/model_zoo/common/get_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1534fbd75e3f5932d670e8bcef25ac8f9f9bf628
--- /dev/null
+++ b/lib/model_zoo/common/get_model.py
@@ -0,0 +1,124 @@
+from email.policy import strict
+import torch
+import torchvision.models
+import os.path as osp
+import copy
+from ...log_service import print_log
+from .utils import \
+ get_total_param, get_total_param_sum, \
+ get_unit
+
+# def load_state_dict(net, model_path):
+# if isinstance(net, dict):
+# for ni, neti in net.items():
+# paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
+# new_paras = neti.state_dict()
+# new_paras.update(paras)
+# neti.load_state_dict(new_paras)
+# else:
+# paras = torch.load(model_path, map_location=torch.device('cpu'))
+# new_paras = net.state_dict()
+# new_paras.update(paras)
+# net.load_state_dict(new_paras)
+# return
+
+# def save_state_dict(net, path):
+# if isinstance(net, (torch.nn.DataParallel,
+# torch.nn.parallel.DistributedDataParallel)):
+# torch.save(net.module.state_dict(), path)
+# else:
+# torch.save(net.state_dict(), path)
+
+def singleton(class_):
+ instances = {}
+ def getinstance(*args, **kwargs):
+ if class_ not in instances:
+ instances[class_] = class_(*args, **kwargs)
+ return instances[class_]
+ return getinstance
+
+def preprocess_model_args(args):
+ # If args has layer_units, get the corresponding
+ # units.
+ # If args get backbone, get the backbone model.
+ args = copy.deepcopy(args)
+ if 'layer_units' in args:
+ layer_units = [
+ get_unit()(i) for i in args.layer_units
+ ]
+ args.layer_units = layer_units
+ if 'backbone' in args:
+ args.backbone = get_model()(args.backbone)
+ return args
+
+@singleton
+class get_model(object):
+ def __init__(self):
+ self.model = {}
+
+ def register(self, model, name):
+ self.model[name] = model
+
+ def __call__(self, cfg, verbose=True):
+ """
+ Construct model based on the config.
+ """
+ if cfg is None:
+ return None
+
+ t = cfg.type
+
+ # the register is in each file
+ if t.find('pfd')==0:
+ from .. import pfd
+ elif t=='autoencoderkl':
+ from .. import autokl
+ elif (t.find('clip')==0) or (t.find('openclip')==0):
+ from .. import clip
+ elif t.find('openai_unet')==0:
+ from .. import openaimodel
+ elif t.find('controlnet')==0:
+ from .. import controlnet
+ elif t.find('seecoder')==0:
+ from .. import seecoder
+ elif t.find('swin')==0:
+ from .. import swin
+
+ args = preprocess_model_args(cfg.args)
+ net = self.model[t](**args)
+
+ pretrained = cfg.get('pretrained', None)
+ if pretrained is None: # backward compatible
+ pretrained = cfg.get('pth', None)
+ map_location = cfg.get('map_location', 'cpu')
+ strict_sd = cfg.get('strict_sd', True)
+
+ if pretrained is not None:
+ if osp.splitext(pretrained)[1] == '.pth':
+ sd = torch.load(pretrained, map_location=map_location)
+ elif osp.splitext(pretrained)[1] == '.ckpt':
+ sd = torch.load(pretrained, map_location=map_location)['state_dict']
+ elif osp.splitext(pretrained)[1] == '.safetensors':
+ from safetensors.torch import load_file
+ from collections import OrderedDict
+ sd = load_file(pretrained, map_location)
+ sd = OrderedDict(sd)
+ net.load_state_dict(sd, strict=strict_sd)
+ if verbose:
+ print_log('Load model from [{}] strict [{}].'.format(pretrained, strict_sd))
+
+ # display param_num & param_sum
+ if verbose:
+ print_log(
+ 'Load {} with total {} parameters,'
+ '{:.3f} parameter sum.'.format(
+ t,
+ get_total_param(net),
+ get_total_param_sum(net) ))
+ return net
+
+def register(name):
+ def wrapper(class_):
+ get_model().register(class_, name)
+ return class_
+ return wrapper
diff --git a/lib/model_zoo/common/get_optimizer.py b/lib/model_zoo/common/get_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f2820ce6734fe0929963e5ba92c8fd4c4fd6ddd
--- /dev/null
+++ b/lib/model_zoo/common/get_optimizer.py
@@ -0,0 +1,47 @@
+import torch
+import torch.optim as optim
+import numpy as np
+import itertools
+
+def singleton(class_):
+ instances = {}
+ def getinstance(*args, **kwargs):
+ if class_ not in instances:
+ instances[class_] = class_(*args, **kwargs)
+ return instances[class_]
+ return getinstance
+
+class get_optimizer(object):
+ def __init__(self):
+ self.optimizer = {}
+ self.register(optim.SGD, 'sgd')
+ self.register(optim.Adam, 'adam')
+ self.register(optim.AdamW, 'adamw')
+
+ def register(self, optim, name):
+ self.optimizer[name] = optim
+
+ def __call__(self, net, cfg):
+ if cfg is None:
+ return None
+ t = cfg.type
+ if isinstance(net, (torch.nn.DataParallel,
+ torch.nn.parallel.DistributedDataParallel)):
+ netm = net.module
+ else:
+ netm = net
+ pg = getattr(netm, 'parameter_group', None)
+
+ if pg is not None:
+ params = []
+ for group_name, module_or_para in pg.items():
+ if not isinstance(module_or_para, list):
+ module_or_para = [module_or_para]
+
+ grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
+ grouped_params = itertools.chain(*grouped_params)
+ pg_dict = {'params':grouped_params, 'name':group_name}
+ params.append(pg_dict)
+ else:
+ params = net.parameters()
+ return self.optimizer[t](params, lr=0, **cfg.args)
diff --git a/lib/model_zoo/common/get_scheduler.py b/lib/model_zoo/common/get_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7c86e89dd9fcd092836546555b14cb68c7771d
--- /dev/null
+++ b/lib/model_zoo/common/get_scheduler.py
@@ -0,0 +1,262 @@
+import torch
+import torch.optim as optim
+import numpy as np
+import copy
+from ... import sync
+from ...cfg_holder import cfg_unique_holder as cfguh
+
+def singleton(class_):
+ instances = {}
+ def getinstance(*args, **kwargs):
+ if class_ not in instances:
+ instances[class_] = class_(*args, **kwargs)
+ return instances[class_]
+ return getinstance
+
+@singleton
+class get_scheduler(object):
+ def __init__(self):
+ self.lr_scheduler = {}
+
+ def register(self, lrsf, name):
+ self.lr_scheduler[name] = lrsf
+
+ def __call__(self, cfg):
+ if cfg is None:
+ return None
+ if isinstance(cfg, list):
+ schedulers = []
+ for ci in cfg:
+ t = ci.type
+ schedulers.append(
+ self.lr_scheduler[t](**ci.args))
+ if len(schedulers) == 0:
+ raise ValueError
+ else:
+ return compose_scheduler(schedulers)
+ t = cfg.type
+ return self.lr_scheduler[t](**cfg.args)
+
+
+def register(name):
+ def wrapper(class_):
+ get_scheduler().register(class_, name)
+ return class_
+ return wrapper
+
+class template_scheduler(object):
+ def __init__(self, step):
+ self.step = step
+
+ def __getitem__(self, idx):
+ raise ValueError
+
+ def set_lr(self, optim, new_lr, pg_lrscale=None):
+ """
+ Set Each parameter_groups in optim with new_lr
+ New_lr can be find according to the idx.
+ pg_lrscale tells how to scale each pg.
+ """
+ # new_lr = self.__getitem__(idx)
+ pg_lrscale = copy.deepcopy(pg_lrscale)
+ for pg in optim.param_groups:
+ if pg_lrscale is None:
+ pg['lr'] = new_lr
+ else:
+ pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
+ assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
+ "pg_lrscale doesn't match pg"
+
+@register('constant')
+class constant_scheduler(template_scheduler):
+ def __init__(self, lr, step):
+ super().__init__(step)
+ self.lr = lr
+
+ def __getitem__(self, idx):
+ if idx >= self.step:
+ raise ValueError
+ return self.lr
+
+@register('poly')
+class poly_scheduler(template_scheduler):
+ def __init__(self, start_lr, end_lr, power, step):
+ super().__init__(step)
+ self.start_lr = start_lr
+ self.end_lr = end_lr
+ self.power = power
+
+ def __getitem__(self, idx):
+ if idx >= self.step:
+ raise ValueError
+ a, b = self.start_lr, self.end_lr
+ p, n = self.power, self.step
+ return b + (a-b)*((1-idx/n)**p)
+
+@register('linear')
+class linear_scheduler(template_scheduler):
+ def __init__(self, start_lr, end_lr, step):
+ super().__init__(step)
+ self.start_lr = start_lr
+ self.end_lr = end_lr
+
+ def __getitem__(self, idx):
+ if idx >= self.step:
+ raise ValueError
+ a, b, n = self.start_lr, self.end_lr, self.step
+ return b + (a-b)*(1-idx/n)
+
+@register('multistage')
+class constant_scheduler(template_scheduler):
+ def __init__(self, start_lr, milestones, gamma, step):
+ super().__init__(step)
+ self.start_lr = start_lr
+ m = [0] + milestones + [step]
+ lr_iter = start_lr
+ self.lr = []
+ for ms, me in zip(m[0:-1], m[1:]):
+ for _ in range(ms, me):
+ self.lr.append(lr_iter)
+ lr_iter *= gamma
+
+ def __getitem__(self, idx):
+ if idx >= self.step:
+ raise ValueError
+ return self.lr[idx]
+
+class compose_scheduler(template_scheduler):
+ def __init__(self, schedulers):
+ self.schedulers = schedulers
+ self.step = [si.step for si in schedulers]
+ self.step_milestone = []
+ acc = 0
+ for i in self.step:
+ acc += i
+ self.step_milestone.append(acc)
+ self.step = sum(self.step)
+
+ def __getitem__(self, idx):
+ if idx >= self.step:
+ raise ValueError
+ ms = self.step_milestone
+ for idx, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])):
+ if mi <= idx < mj:
+ return self.schedulers[idx-mi]
+ raise ValueError
+
+####################
+# lambda schedular #
+####################
+
+class LambdaWarmUpCosineScheduler(template_scheduler):
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self,
+ base_lr,
+ warm_up_steps,
+ lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ cfgt = cfguh().cfg.train
+ bs = cfgt.batch_size
+ if 'gradacc_every' not in cfgt:
+ print('Warning, gradacc_every is not found in xml, use 1 as default.')
+ acc = cfgt.get('gradacc_every', 1)
+ self.lr_multi = base_lr * bs * acc
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __getitem__(self, idx):
+ return self.schedule(idx) * self.lr_multi
+
+class LambdaWarmUpCosineScheduler2(template_scheduler):
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self,
+ base_lr,
+ warm_up_steps,
+ f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ cfgt = cfguh().cfg.train
+ # bs = cfgt.batch_size
+ # if 'gradacc_every' not in cfgt:
+ # print('Warning, gradacc_every is not found in xml, use 1 as default.')
+ # acc = cfgt.get('gradacc_every', 1)
+ # self.lr_multi = base_lr * bs * acc
+ self.lr_multi = base_lr
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __getitem__(self, idx):
+ return self.schedule(idx) * self.lr_multi
+
+@register('stable_diffusion_linear')
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+ def schedule(self, n):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
\ No newline at end of file
diff --git a/lib/model_zoo/common/utils.py b/lib/model_zoo/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9979e0bc09de2bf3251c651434d7acd2f7305b96
--- /dev/null
+++ b/lib/model_zoo/common/utils.py
@@ -0,0 +1,292 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import copy
+import functools
+import itertools
+
+import matplotlib.pyplot as plt
+
+########
+# unit #
+########
+
+def singleton(class_):
+ instances = {}
+ def getinstance(*args, **kwargs):
+ if class_ not in instances:
+ instances[class_] = class_(*args, **kwargs)
+ return instances[class_]
+ return getinstance
+
+def str2value(v):
+ v = v.strip()
+ try:
+ return int(v)
+ except:
+ pass
+ try:
+ return float(v)
+ except:
+ pass
+ if v in ('True', 'true'):
+ return True
+ elif v in ('False', 'false'):
+ return False
+ else:
+ return v
+
+@singleton
+class get_unit(object):
+ def __init__(self):
+ self.unit = {}
+ self.register('none', None)
+
+ # general convolution
+ self.register('conv' , nn.Conv2d)
+ self.register('bn' , nn.BatchNorm2d)
+ self.register('relu' , nn.ReLU)
+ self.register('relu6' , nn.ReLU6)
+ self.register('lrelu' , nn.LeakyReLU)
+ self.register('dropout' , nn.Dropout)
+ self.register('dropout2d', nn.Dropout2d)
+ self.register('sine', Sine)
+ self.register('relusine', ReLUSine)
+
+ def register(self,
+ name,
+ unitf,):
+
+ self.unit[name] = unitf
+
+ def __call__(self, name):
+ if name is None:
+ return None
+ i = name.find('(')
+ i = len(name) if i==-1 else i
+ t = name[:i]
+ f = self.unit[t]
+ args = name[i:].strip('()')
+ if len(args) == 0:
+ args = {}
+ return f
+ else:
+ args = args.split('=')
+ args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
+ args = list(itertools.chain.from_iterable(args))
+ args = [i.strip() for i in args if len(i)>0]
+ kwargs = {}
+ for k, v in zip(args[::2], args[1::2]):
+ if v[0]=='(' and v[-1]==')':
+ kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
+ elif v[0]=='[' and v[-1]==']':
+ kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
+ else:
+ kwargs[k] = str2value(v)
+ return functools.partial(f, **kwargs)
+
+def register(name):
+ def wrapper(class_):
+ get_unit().register(name, class_)
+ return class_
+ return wrapper
+
+class Sine(object):
+ def __init__(self, freq, gain=1):
+ self.freq = freq
+ self.gain = gain
+ self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
+
+ def __call__(self, x, gain=1):
+ act_gain = self.gain * gain
+ return torch.sin(self.freq * x) * act_gain
+
+ def __repr__(self,):
+ return self.repr
+
+class ReLUSine(nn.Module):
+ def __init(self):
+ super().__init__()
+
+ def forward(self, input):
+ a = torch.sin(30 * input)
+ b = nn.ReLU(inplace=False)(input)
+ return a+b
+
+@register('lrelu_agc')
+# class lrelu_agc(nn.Module):
+class lrelu_agc(object):
+ """
+ The lrelu layer with alpha, gain and clamp
+ """
+ def __init__(self, alpha=0.1, gain=1, clamp=None):
+ # super().__init__()
+ self.alpha = alpha
+ if gain == 'sqrt_2':
+ self.gain = np.sqrt(2)
+ else:
+ self.gain = gain
+ self.clamp = clamp
+ self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
+ alpha, gain, clamp)
+
+ # def forward(self, x, gain=1):
+ def __call__(self, x, gain=1):
+ x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
+ act_gain = self.gain * gain
+ act_clamp = self.clamp * gain if self.clamp is not None else None
+ if act_gain != 1:
+ x = x * act_gain
+ if act_clamp is not None:
+ x = x.clamp(-act_clamp, act_clamp)
+ return x
+
+ def __repr__(self,):
+ return self.repr
+
+####################
+# spatial encoding #
+####################
+
+@register('se')
+class SpatialEncoding(nn.Module):
+ def __init__(self,
+ in_dim,
+ out_dim,
+ sigma = 6,
+ cat_input=True,
+ require_grad=False,):
+
+ super().__init__()
+ assert out_dim % (2*in_dim) == 0, "dimension must be dividable"
+
+ n = out_dim // 2 // in_dim
+ m = 2**np.linspace(0, sigma, n)
+ m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
+ m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
+ self.emb = torch.FloatTensor(m)
+ if require_grad:
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.sigma = sigma
+ self.cat_input = cat_input
+ self.require_grad = require_grad
+
+ def forward(self, x, format='[n x c]'):
+ """
+ Args:
+ x: [n x m1],
+ m1 usually is 2
+ Outputs:
+ y: [n x m2]
+ m2 dimention number
+ """
+ if format == '[bs x c x 2D]':
+ xshape = x.shape
+ x = x.permute(0, 2, 3, 1).contiguous()
+ x = x.view(-1, x.size(-1))
+ elif format == '[n x c]':
+ pass
+ else:
+ raise ValueError
+
+ if not self.require_grad:
+ self.emb = self.emb.to(x.device)
+ y = torch.mm(x, self.emb.T)
+ if self.cat_input:
+ z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
+ else:
+ z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
+
+ if format == '[bs x c x 2D]':
+ z = z.view(xshape[0], xshape[2], xshape[3], -1)
+ z = z.permute(0, 3, 1, 2).contiguous()
+ return z
+
+ def extra_repr(self):
+ outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
+ return outstr
+
+@register('rffe')
+class RFFEncoding(SpatialEncoding):
+ """
+ Random Fourier Features
+ """
+ def __init__(self,
+ in_dim,
+ out_dim,
+ sigma = 6,
+ cat_input=True,
+ require_grad=False,):
+
+ super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
+ n = out_dim // 2
+ m = np.random.normal(0, sigma, size=(n, in_dim))
+ self.emb = torch.FloatTensor(m)
+ if require_grad:
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
+
+ def extra_repr(self):
+ outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
+ return outstr
+
+##########
+# helper #
+##########
+
+def freeze(net):
+ for m in net.modules():
+ if isinstance(m, (
+ nn.BatchNorm2d,
+ nn.SyncBatchNorm,)):
+ # inplace_abn not supported
+ m.eval()
+ for pi in net.parameters():
+ pi.requires_grad = False
+ return net
+
+def common_init(m):
+ if isinstance(m, (
+ nn.Conv2d,
+ nn.ConvTranspose2d,)):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, (
+ nn.BatchNorm2d,
+ nn.SyncBatchNorm,)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ else:
+ pass
+
+def init_module(module):
+ """
+ Args:
+ module: [nn.module] list or nn.module
+ a list of module to be initialized.
+ """
+ if isinstance(module, (list, tuple)):
+ module = list(module)
+ else:
+ module = [module]
+
+ for mi in module:
+ for mii in mi.modules():
+ common_init(mii)
+
+def get_total_param(net):
+ if getattr(net, 'parameters', None) is None:
+ return 0
+ return sum(p.numel() for p in net.parameters())
+
+def get_total_param_sum(net):
+ if getattr(net, 'parameters', None) is None:
+ return 0
+ with torch.no_grad():
+ s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
+ return s
diff --git a/lib/model_zoo/controlnet.py b/lib/model_zoo/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..facfa80401e9bd6ba6b9abc3467838cf0a266052
--- /dev/null
+++ b/lib/model_zoo/controlnet.py
@@ -0,0 +1,503 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import numpy.random as npr
+import copy
+from functools import partial
+from contextlib import contextmanager
+from lib.model_zoo.common.get_model import get_model, register
+from lib.log_service import print_log
+
+from .openaimodel import \
+ TimestepEmbedSequential, conv_nd, zero_module, \
+ ResBlock, AttentionBlock, SpatialTransformer, \
+ Downsample, timestep_embedding
+
+####################
+# preprocess depth #
+####################
+
+# depth_model = None
+
+# def unload_midas_model():
+# global depth_model
+# if depth_model is not None:
+# depth_model = depth_model.cpu()
+
+# def apply_midas(input_image, a=np.pi*2.0, bg_th=0.1, device='cpu'):
+# import cv2
+# from einops import rearrange
+# from .controlnet_annotators.midas import MiDaSInference
+# global depth_model
+# if depth_model is None:
+# depth_model = MiDaSInference(model_type="dpt_hybrid")
+# depth_model = depth_model.to(device)
+
+# assert input_image.ndim == 3
+# image_depth = input_image
+# with torch.no_grad():
+# image_depth = torch.from_numpy(image_depth).float()
+# image_depth = image_depth.to(device)
+# image_depth = image_depth / 127.5 - 1.0
+# image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+# depth = depth_model(image_depth)[0]
+
+# depth_pt = depth.clone()
+# depth_pt -= torch.min(depth_pt)
+# depth_pt /= torch.max(depth_pt)
+# depth_pt = depth_pt.cpu().numpy()
+# depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+# depth_np = depth.cpu().numpy()
+# x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+# y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+# z = np.ones_like(x) * a
+# x[depth_pt < bg_th] = 0
+# y[depth_pt < bg_th] = 0
+# normal = np.stack([x, y, z], axis=2)
+# normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+# normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+# return depth_image, normal_image
+
+
+@register('controlnet')
+class ControlNet(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ hint_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.dims = dims
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+ self.input_hint_block = TimestepEmbedSequential(
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ )
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if disable_self_attentions is not None:
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (num_attention_blocks is None) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.zero_convs.append(self.make_zero_conv(ch))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ self.zero_convs.append(self.make_zero_conv(ch))
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self.middle_block_out = self.make_zero_conv(ch)
+ self._feature_size += ch
+
+ def make_zero_conv(self, channels):
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+ def forward(self, x, hint, timesteps, context, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ t_emb = t_emb.to(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ guided_hint = self.input_hint_block(hint, emb, context)
+
+ outs = []
+
+ h = x
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None:
+ h = module(h, emb, context)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ outs.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ outs.append(self.middle_block_out(h, emb, context))
+
+ return outs
+
+ def get_device(self):
+ return self.time_embed[0].weight.device
+
+ def get_dtype(self):
+ return self.time_embed[0].weight.dtype
+
+ def preprocess(self, x, type='canny', **kwargs):
+ import torchvision.transforms as tvtrans
+ if isinstance(x, str):
+ import PIL.Image
+ device, dtype = self.get_device(), self.get_dtype()
+ x_list = [PIL.Image.open(x)]
+ elif isinstance(x, torch.Tensor):
+ x_list = [tvtrans.ToPILImage()(xi) for xi in x]
+ device, dtype = x.device, x.dtype
+ else:
+ assert False
+
+ if type == 'none' or type is None:
+ return None
+
+ elif type in ['input', 'shuffle_v11e']:
+ y_torch = torch.stack([tvtrans.ToTensor()(xi) for xi in x_list])
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type in ['canny', 'canny_v11p']:
+ low_threshold = kwargs.pop('low_threshold', 100)
+ high_threshold = kwargs.pop('high_threshold', 200)
+ from .controlnet_annotator.canny import apply_canny
+ y_list = [apply_canny(np.array(xi), low_threshold, high_threshold) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type == 'depth':
+ from .controlnet_annotator.midas import apply_midas
+ y_list, _ = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, device=device) for xi in x_list])
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type in ['hed', 'softedge_v11p']:
+ from .controlnet_annotator.hed import apply_hed
+ y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type in ['mlsd', 'mlsd_v11p']:
+ thr_v = kwargs.pop('thr_v', 0.1)
+ thr_d = kwargs.pop('thr_d', 0.1)
+ from .controlnet_annotator.mlsd import apply_mlsd
+ y_list = [apply_mlsd(np.array(xi), thr_v=thr_v, thr_d=thr_d, device=device) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type == 'normal':
+ bg_th = kwargs.pop('bg_th', 0.4)
+ from .controlnet_annotator.midas import apply_midas
+ _, y_list = zip(*[apply_midas(input_image=np.array(xi), a=np.pi*2.0, bg_th=bg_th, device=device) for xi in x_list])
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type in ['openpose', 'openpose_v11p']:
+ from .controlnet_annotator.openpose import OpenposeModel
+ from functools import partial
+ wrapper = OpenposeModel()
+ apply_openpose = partial(
+ wrapper.run_model, include_body=True, include_hand=False, include_face=False,
+ json_pose_callback=None, device=device)
+ y_list = [apply_openpose(np.array(xi)) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type in ['openpose_withface', 'openpose_withface_v11p']:
+ from .controlnet_annotator.openpose import OpenposeModel
+ from functools import partial
+ wrapper = OpenposeModel()
+ apply_openpose = partial(
+ wrapper.run_model, include_body=True, include_hand=False, include_face=True,
+ json_pose_callback=None, device=device)
+ y_list = [apply_openpose(np.array(xi)) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type in ['openpose_withfacehand', 'openpose_withfacehand_v11p']:
+ from .controlnet_annotator.openpose import OpenposeModel
+ from functools import partial
+ wrapper = OpenposeModel()
+ apply_openpose = partial(
+ wrapper.run_model, include_body=True, include_hand=True, include_face=True,
+ json_pose_callback=None, device=device)
+ y_list = [apply_openpose(np.array(xi)) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi.copy()) for yi in y_list])
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif type == 'scribble':
+ method = kwargs.pop('method', 'pidinet')
+
+ import cv2
+ def nms(x, t, s):
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+ y = np.zeros_like(x)
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
+
+ def make_scribble(result):
+ result = nms(result, 127, 3.0)
+ result = cv2.GaussianBlur(result, (0, 0), 3.0)
+ result[result > 4] = 255
+ result[result < 255] = 0
+ return result
+
+ if method == 'hed':
+ from .controlnet_annotator.hed import apply_hed
+ y_list = [apply_hed(np.array(xi), device=device) for xi in x_list]
+ y_list = [make_scribble(yi) for yi in y_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif method == 'pidinet':
+ from .controlnet_annotator.pidinet import apply_pidinet
+ y_list = [apply_pidinet(np.array(xi), device=device) for xi in x_list]
+ y_list = [make_scribble(yi) for yi in y_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ elif method == 'xdog':
+ threshold = kwargs.pop('threshold', 32)
+ def apply_scribble_xdog(img):
+ g1 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 0.5)
+ g2 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 5.0)
+ dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8)
+ result = np.zeros_like(img, dtype=np.uint8)
+ result[2 * (255 - dog) > threshold] = 255
+ return result
+
+ y_list = [apply_scribble_xdog(np.array(xi), device=device) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.repeat(1, 3, 1, 1) # Make is RGB
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ else:
+ raise ValueError
+
+ elif type == 'seg':
+ method = kwargs.pop('method', 'ufade20k')
+ if method == 'ufade20k':
+ from .controlnet_annotator.uniformer import apply_uniformer
+ y_list = [apply_uniformer(np.array(xi), palette='ade20k', device=device) for xi in x_list]
+ y_torch = torch.stack([tvtrans.ToTensor()(yi) for yi in y_list])
+ y_torch = y_torch.to(device).to(torch.float32)
+ return y_torch
+
+ else:
+ raise ValueError
diff --git a/lib/model_zoo/controlnet_annotator/canny/__init__.py b/lib/model_zoo/controlnet_annotator/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace985839d3fc18dd4947f6c38e9f5d5a2625aca
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/canny/__init__.py
@@ -0,0 +1,5 @@
+import cv2
+
+
+def apply_canny(img, low_threshold, high_threshold):
+ return cv2.Canny(img, low_threshold, high_threshold)
diff --git a/lib/model_zoo/controlnet_annotator/hed/__init__.py b/lib/model_zoo/controlnet_annotator/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dcfaeb55887ef0fa0928e817e2ac46de1669923
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/hed/__init__.py
@@ -0,0 +1,134 @@
+# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
+# Please use this implementation in your products
+# This implementation may produce slightly different results from Saining Xie's official implementations,
+# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
+# and in this way it works better for gradio's RGB protocol
+
+import os
+import cv2
+import torch
+import numpy as np
+
+from einops import rearrange
+import os
+
+models_path = 'pretrained/controlnet/preprocess'
+
+def safe_step(x, step=2):
+ y = x.astype(np.float32) * float(step + 1)
+ y = y.astype(np.int32).astype(np.float32) / float(step)
+ return y
+
+class DoubleConvBlock(torch.nn.Module):
+ def __init__(self, input_channel, output_channel, layer_number):
+ super().__init__()
+ self.convs = torch.nn.Sequential()
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+ for i in range(1, layer_number):
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+ def __call__(self, x, down_sampling=False):
+ h = x
+ if down_sampling:
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
+ for conv in self.convs:
+ h = conv(h)
+ h = torch.nn.functional.relu(h)
+ return h, self.projection(h)
+
+
+class ControlNetHED_Apache2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+ def __call__(self, x):
+ h = x - self.norm
+ h, projection1 = self.block1(h)
+ h, projection2 = self.block2(h, down_sampling=True)
+ h, projection3 = self.block3(h, down_sampling=True)
+ h, projection4 = self.block4(h, down_sampling=True)
+ h, projection5 = self.block5(h, down_sampling=True)
+ return projection1, projection2, projection3, projection4, projection5
+
+
+netNetwork = None
+remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
+modeldir = os.path.join(models_path, "hed")
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ from torch.hub import download_url_to_file, get_dir
+ from urllib.parse import urlparse
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+
+def apply_hed(input_image, is_safe=False, device='cpu'):
+ global netNetwork
+ if netNetwork is None:
+ modelpath = os.path.join(modeldir, "ControlNetHED.pth")
+ old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
+ if os.path.exists(old_modelpath):
+ modelpath = old_modelpath
+ elif not os.path.exists(modelpath):
+ load_file_from_url(remote_model_path, model_dir=modeldir)
+ netNetwork = ControlNetHED_Apache2().to(device)
+ netNetwork.load_state_dict(torch.load(modelpath, map_location='cpu'))
+ netNetwork.to(device).float().eval()
+
+ assert input_image.ndim == 3
+ H, W, C = input_image.shape
+ with torch.no_grad():
+ image_hed = torch.from_numpy(input_image.copy()).float().to(device)
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+ edges = netNetwork(image_hed)
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+ edges = np.stack(edges, axis=2)
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+ if is_safe:
+ edge = safe_step(edge)
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+ return edge
+
+
+def unload_hed_model():
+ global netNetwork
+ if netNetwork is not None:
+ netNetwork.cpu()
diff --git a/lib/model_zoo/controlnet_annotator/midas/LICENSE b/lib/model_zoo/controlnet_annotator/midas/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/lib/model_zoo/controlnet_annotator/midas/__init__.py b/lib/model_zoo/controlnet_annotator/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e812da1bdb25fdde5082e33ce5f40e402b379de3
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/__init__.py
@@ -0,0 +1,46 @@
+import cv2
+import numpy as np
+import torch
+
+from einops import rearrange
+from .api import MiDaSInference
+
+model = None
+
+def unload_midas_model():
+ global model
+ if model is not None:
+ model = model.cpu()
+
+def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1, device='cpu'):
+ global model
+ if model is None:
+ model = MiDaSInference(model_type="dpt_hybrid")
+ model = model.to(device)
+
+ assert input_image.ndim == 3
+ image_depth = input_image
+ with torch.no_grad():
+ image_depth = torch.from_numpy(image_depth).float()
+ image_depth = image_depth.to(device)
+ image_depth = image_depth / 127.5 - 1.0
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+ depth = model(image_depth)[0]
+
+ depth_pt = depth.clone()
+ depth_pt -= torch.min(depth_pt)
+ depth_pt /= torch.max(depth_pt)
+ depth_pt = depth_pt.cpu().numpy()
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+ depth_np = depth.cpu().numpy()
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+ z = np.ones_like(x) * a
+ x[depth_pt < bg_th] = 0
+ y[depth_pt < bg_th] = 0
+ normal = np.stack([x, y, z], axis=2)
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1]
+
+ return depth_image, normal_image
diff --git a/lib/model_zoo/controlnet_annotator/midas/api.py b/lib/model_zoo/controlnet_annotator/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f63b253958bc232a813c98e997f7161594a0b88
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/api.py
@@ -0,0 +1,214 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+import os
+models_path = 'pretrained/controlnet/preprocess'
+
+from torchvision.transforms import Compose
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+base_model_path = os.path.join(models_path, "midas")
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+
+ISL_PATHS = {
+ "dpt_large": os.path.join(base_model_path, "dpt_large-midas-2f21e586.pt"),
+ "dpt_hybrid": os.path.join(base_model_path, "dpt_hybrid-midas-501f0c75.pt"),
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+OLD_ISL_PATHS = {
+ "dpt_large": os.path.join(old_modeldir, "dpt_large-midas-2f21e586.pt"),
+ "dpt_hybrid": os.path.join(old_modeldir, "dpt_hybrid-midas-501f0c75.pt"),
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ else:
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ from torch.hub import download_url_to_file, get_dir
+ from urllib.parse import urlparse
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ old_model_path = OLD_ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ if os.path.exists(old_model_path):
+ model_path = old_model_path
+ elif not os.path.exists(model_path):
+ load_file_from_url(remote_model_path, model_dir=base_model_path)
+
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+ with torch.no_grad():
+ prediction = self.model(x)
+ return prediction
+
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/__init__.py b/lib/model_zoo/controlnet_annotator/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/base_model.py b/lib/model_zoo/controlnet_annotator/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/blocks.py b/lib/model_zoo/controlnet_annotator/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/dpt_depth.py b/lib/model_zoo/controlnet_annotator/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+ x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/midas_net.py b/lib/model_zoo/controlnet_annotator/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/midas_net_custom.py b/lib/model_zoo/controlnet_annotator/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+ x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/transforms.py b/lib/model_zoo/controlnet_annotator/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/lib/model_zoo/controlnet_annotator/midas/midas/vit.py b/lib/model_zoo/controlnet_annotator/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/lib/model_zoo/controlnet_annotator/midas/utils.py b/lib/model_zoo/controlnet_annotator/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+ path (str): pathto file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.type)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/lib/model_zoo/controlnet_annotator/mlsd/LICENSE b/lib/model_zoo/controlnet_annotator/mlsd/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/mlsd/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021-present NAVER Corp.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/mlsd/__init__.py b/lib/model_zoo/controlnet_annotator/mlsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9ee57580777c413ad13a0e6888820a1c2770e0
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/mlsd/__init__.py
@@ -0,0 +1,81 @@
+import cv2
+import numpy as np
+import torch
+import os
+
+from einops import rearrange
+from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
+from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
+from .utils import pred_lines
+
+models_path = 'pretrained/controlnet/preprocess'
+
+mlsdmodel = None
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+modeldir = os.path.join(models_path, "mlsd")
+
+def unload_mlsd_model():
+ global mlsdmodel
+ if mlsdmodel is not None:
+ mlsdmodel = mlsdmodel.cpu()
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ from torch.hub import download_url_to_file, get_dir
+ from urllib.parse import urlparse
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+def apply_mlsd(input_image, thr_v, thr_d, device='cpu'):
+ global modelpath, mlsdmodel
+ if mlsdmodel is None:
+ modelpath = os.path.join(modeldir, "mlsd_large_512_fp32.pth")
+ old_modelpath = os.path.join(old_modeldir, "mlsd_large_512_fp32.pth")
+ if os.path.exists(old_modelpath):
+ modelpath = old_modelpath
+ elif not os.path.exists(modelpath):
+ load_file_from_url(remote_model_path, model_dir=modeldir)
+ mlsdmodel = MobileV2_MLSD_Large()
+ mlsdmodel.load_state_dict(torch.load(modelpath), strict=True)
+ mlsdmodel = mlsdmodel.to(device).eval()
+
+ model = mlsdmodel
+ assert input_image.ndim == 3
+ img = input_image
+ img_output = np.zeros_like(img)
+ try:
+ with torch.no_grad():
+ lines = pred_lines(img, model, [img.shape[0], img.shape[1]], thr_v, thr_d)
+ for line in lines:
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
+ except Exception as e:
+ pass
+ return img_output[:, :, 0]
diff --git a/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_large.py b/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_large.py
@@ -0,0 +1,292 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+ super(BlockTypeA, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
+ nn.BatchNorm2d(out_c2),
+ nn.ReLU(inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
+ nn.BatchNorm2d(out_c1),
+ nn.ReLU(inplace=True)
+ )
+ self.upscale = upscale
+
+ def forward(self, a, b):
+ b = self.conv1(b)
+ a = self.conv2(a)
+ if self.upscale:
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+ return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeB, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ x = self.conv1(x) + x
+ x = self.conv2(x)
+ return x
+
+class BlockTypeC(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeC, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ return x
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ self.channel_pad = out_planes - in_planes
+ self.stride = stride
+ #padding = (kernel_size - 1) // 2
+
+ # TFLite uses slightly different padding than PyTorch
+ if stride == 2:
+ padding = 0
+ else:
+ padding = (kernel_size - 1) // 2
+
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+ def forward(self, x):
+ # TFLite uses different padding
+ if self.stride == 2:
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+ #print(x.shape)
+
+ for module in self:
+ if not isinstance(module, nn.MaxPool2d):
+ x = module(x)
+ return x
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, pretrained=True):
+ """
+ MobileNet V2 main class
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ """
+ super(MobileNetV2, self).__init__()
+
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ width_mult = 1.0
+ round_nearest = 8
+
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ #[6, 160, 3, 2],
+ #[6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(4, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+
+ self.features = nn.Sequential(*features)
+ self.fpn_selected = [1, 3, 6, 10, 13]
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+ if pretrained:
+ self._load_pretrained_model()
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ fpn_features = []
+ for i, f in enumerate(self.features):
+ if i > self.fpn_selected[-1]:
+ break
+ x = f(x)
+ if i in self.fpn_selected:
+ fpn_features.append(x)
+
+ c1, c2, c3, c4, c5 = fpn_features
+ return c1, c2, c3, c4, c5
+
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ def _load_pretrained_model(self):
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+ model_dict = {}
+ state_dict = self.state_dict()
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Large(nn.Module):
+ def __init__(self):
+ super(MobileV2_MLSD_Large, self).__init__()
+
+ self.backbone = MobileNetV2(pretrained=False)
+ ## A, B
+ self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
+ out_c1= 64, out_c2=64,
+ upscale=False)
+ self.block16 = BlockTypeB(128, 64)
+
+ ## A, B
+ self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
+ out_c1= 64, out_c2= 64)
+ self.block18 = BlockTypeB(128, 64)
+
+ ## A, B
+ self.block19 = BlockTypeA(in_c1=24, in_c2=64,
+ out_c1=64, out_c2=64)
+ self.block20 = BlockTypeB(128, 64)
+
+ ## A, B, C
+ self.block21 = BlockTypeA(in_c1=16, in_c2=64,
+ out_c1=64, out_c2=64)
+ self.block22 = BlockTypeB(128, 64)
+
+ self.block23 = BlockTypeC(64, 16)
+
+ def forward(self, x):
+ c1, c2, c3, c4, c5 = self.backbone(x)
+
+ x = self.block15(c4, c5)
+ x = self.block16(x)
+
+ x = self.block17(c3, x)
+ x = self.block18(x)
+
+ x = self.block19(c2, x)
+ x = self.block20(x)
+
+ x = self.block21(c1, x)
+ x = self.block22(x)
+ x = self.block23(x)
+ x = x[:, 7:, :, :]
+
+ return x
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_tiny.py b/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/mlsd/models/mbv2_mlsd_tiny.py
@@ -0,0 +1,275 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+ super(BlockTypeA, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
+ nn.BatchNorm2d(out_c2),
+ nn.ReLU(inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
+ nn.BatchNorm2d(out_c1),
+ nn.ReLU(inplace=True)
+ )
+ self.upscale = upscale
+
+ def forward(self, a, b):
+ b = self.conv1(b)
+ a = self.conv2(a)
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+ return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeB, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_c),
+ nn.ReLU()
+ )
+
+ def forward(self, x):
+ x = self.conv1(x) + x
+ x = self.conv2(x)
+ return x
+
+class BlockTypeC(nn.Module):
+ def __init__(self, in_c, out_c):
+ super(BlockTypeC, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+ nn.BatchNorm2d(in_c),
+ nn.ReLU()
+ )
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ return x
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ self.channel_pad = out_planes - in_planes
+ self.stride = stride
+ #padding = (kernel_size - 1) // 2
+
+ # TFLite uses slightly different padding than PyTorch
+ if stride == 2:
+ padding = 0
+ else:
+ padding = (kernel_size - 1) // 2
+
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+ def forward(self, x):
+ # TFLite uses different padding
+ if self.stride == 2:
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+ #print(x.shape)
+
+ for module in self:
+ if not isinstance(module, nn.MaxPool2d):
+ x = module(x)
+ return x
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, pretrained=True):
+ """
+ MobileNet V2 main class
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ block: Module specifying inverted residual building block for mobilenet
+ """
+ super(MobileNetV2, self).__init__()
+
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ width_mult = 1.0
+ round_nearest = 8
+
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ #[6, 96, 3, 1],
+ #[6, 160, 3, 2],
+ #[6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(4, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+ self.features = nn.Sequential(*features)
+
+ self.fpn_selected = [3, 6, 10]
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ #if pretrained:
+ # self._load_pretrained_model()
+
+ def _forward_impl(self, x):
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+ fpn_features = []
+ for i, f in enumerate(self.features):
+ if i > self.fpn_selected[-1]:
+ break
+ x = f(x)
+ if i in self.fpn_selected:
+ fpn_features.append(x)
+
+ c2, c3, c4 = fpn_features
+ return c2, c3, c4
+
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+ def _load_pretrained_model(self):
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+ model_dict = {}
+ state_dict = self.state_dict()
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Tiny(nn.Module):
+ def __init__(self):
+ super(MobileV2_MLSD_Tiny, self).__init__()
+
+ self.backbone = MobileNetV2(pretrained=True)
+
+ self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
+ out_c1= 64, out_c2=64)
+ self.block13 = BlockTypeB(128, 64)
+
+ self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
+ out_c1= 32, out_c2= 32)
+ self.block15 = BlockTypeB(64, 64)
+
+ self.block16 = BlockTypeC(64, 16)
+
+ def forward(self, x):
+ c2, c3, c4 = self.backbone(x)
+
+ x = self.block12(c3, c4)
+ x = self.block13(x)
+ x = self.block14(c2, x)
+ x = self.block15(x)
+ x = self.block16(x)
+ x = x[:, 7:, :, :]
+ #print(x.shape)
+ x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
+
+ return x
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/mlsd/utils.py b/lib/model_zoo/controlnet_annotator/mlsd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dfa5c1c2d6a6f7d96bbe545ca182fefbe6cb3bd
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/mlsd/utils.py
@@ -0,0 +1,582 @@
+'''
+modified by lihaoweicv
+pytorch version
+'''
+
+'''
+M-LSD
+Copyright 2021-present NAVER Corp.
+Apache License v2.0
+'''
+
+import os
+import numpy as np
+import cv2
+import torch
+from torch.nn import functional as F
+
+def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
+ '''
+ tpMap:
+ center: tpMap[1, 0, :, :]
+ displacement: tpMap[1, 1:5, :, :]
+ '''
+ b, c, h, w = tpMap.shape
+ assert b==1, 'only support bsize==1'
+ displacement = tpMap[:, 1:5, :, :][0]
+ center = tpMap[:, 0, :, :]
+ heat = torch.sigmoid(center)
+ hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
+ keep = (hmax == heat).float()
+ heat = heat * keep
+ heat = heat.reshape(-1, )
+
+ scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
+ yy = torch.floor_divide(indices, w).unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ ptss = torch.cat((yy, xx),dim=-1)
+
+ ptss = ptss.detach().cpu().numpy()
+ scores = scores.detach().cpu().numpy()
+ displacement = displacement.detach().cpu().numpy()
+ displacement = displacement.transpose((1,2,0))
+ return ptss, scores, displacement
+
+def get_device(model):
+ # A hack function find the device of network
+ return model.block22.conv2[0].weight.device
+
+def pred_lines(image, model,
+ input_shape=[512, 512],
+ score_thr=0.10,
+ dist_thr=20.0):
+ h, w, _ = image.shape
+ h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+
+ resized_image = resized_image.transpose((2,0,1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float().to(get_device(model))
+ outputs = model(batch_image)
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2]
+ end = vmap[:, :, 2:]
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ segments_list = []
+ for center, score in zip(pts, pts_score):
+ y, x = center
+ distance = dist_map[y, x]
+ if score > score_thr and distance > dist_thr:
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ x_start = x + disp_x_start
+ y_start = y + disp_y_start
+ x_end = x + disp_x_end
+ y_end = y + disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ lines = 2 * np.array(segments_list) # 256 > 512
+ lines[:, 0] = lines[:, 0] * w_ratio
+ lines[:, 1] = lines[:, 1] * h_ratio
+ lines[:, 2] = lines[:, 2] * w_ratio
+ lines[:, 3] = lines[:, 3] * h_ratio
+
+ return lines
+
+
+def pred_squares(image,
+ model,
+ input_shape=[512, 512],
+ params={'score': 0.06,
+ 'outside_ratio': 0.28,
+ 'inside_ratio': 0.45,
+ 'w_overlap': 0.0,
+ 'w_degree': 1.95,
+ 'w_length': 0.0,
+ 'w_area': 1.86,
+ 'w_center': 0.14}):
+ '''
+ shape = [height, width]
+ '''
+ h, w, _ = image.shape
+ original_shape = [h, w]
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+ resized_image = resized_image.transpose((2, 0, 1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float().to(get_device(model))
+ outputs = model(batch_image)
+
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2] # (x, y)
+ end = vmap[:, :, 2:] # (x, y)
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ junc_list = []
+ segments_list = []
+ for junc, score in zip(pts, pts_score):
+ y, x = junc
+ distance = dist_map[y, x]
+ if score > params['score'] and distance > 20.0:
+ junc_list.append([x, y])
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ d_arrow = 1.0
+ x_start = x + d_arrow * disp_x_start
+ y_start = y + d_arrow * disp_y_start
+ x_end = x + d_arrow * disp_x_end
+ y_end = y + d_arrow * disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ segments = np.array(segments_list)
+
+ ####### post processing for squares
+ # 1. get unique lines
+ point = np.array([[0, 0]])
+ point = point[0]
+ start = segments[:, :2]
+ end = segments[:, 2:]
+ diff = start - end
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+ theta[theta < 0.0] += 180
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+
+ d_quant = 1
+ theta_quant = 2
+ hough[:, 0] //= d_quant
+ hough[:, 1] //= theta_quant
+ _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+ yx_indices = hough[indices, :].astype('int32')
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+
+ acc_map_np = acc_map
+ # acc_map = acc_map[None, :, :, None]
+ #
+ # ### fast suppression using tensorflow op
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+ # _, h, w, _ = acc_map.shape
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
+ # yx = tf.concat([y, x], axis=-1)
+
+ ### fast suppression using pytorch op
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+ _,_, h, w = acc_map.shape
+ max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
+ acc_map = acc_map * ( (acc_map == max_acc_map).float() )
+ flatten_acc_map = acc_map.reshape([-1, ])
+
+ scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ yx = torch.cat((yy, xx), dim=-1)
+
+ yx = yx.detach().cpu().numpy()
+
+ topk_values = scores.detach().cpu().numpy()
+ indices = idx_map[yx[:, 0], yx[:, 1]]
+ basis = 5 // 2
+
+ merged_segments = []
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+ y, x = yx_pt
+ if max_indice == -1 or value == 0:
+ continue
+ segment_list = []
+ for y_offset in range(-basis, basis + 1):
+ for x_offset in range(-basis, basis + 1):
+ indice = idx_map[y + y_offset, x + x_offset]
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
+ if indice != -1:
+ segment_list.append(segments[indice])
+ if cnt > 1:
+ check_cnt = 1
+ current_hough = hough[indice]
+ for new_indice, new_hough in enumerate(hough):
+ if (current_hough == new_hough).all() and indice != new_indice:
+ segment_list.append(segments[new_indice])
+ check_cnt += 1
+ if check_cnt == cnt:
+ break
+ group_segments = np.array(segment_list).reshape([-1, 2])
+ sorted_group_segments = np.sort(group_segments, axis=0)
+ x_min, y_min = sorted_group_segments[0, :]
+ x_max, y_max = sorted_group_segments[-1, :]
+
+ deg = theta[max_indice]
+ if deg >= 90:
+ merged_segments.append([x_min, y_max, x_max, y_min])
+ else:
+ merged_segments.append([x_min, y_min, x_max, y_max])
+
+ # 2. get intersections
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
+ start = new_segments[:, :2] # (x1, y1)
+ end = new_segments[:, 2:] # (x2, y2)
+ new_centers = (start + end) / 2.0
+ diff = start - end
+ dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+
+ # ax + by = c
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+ pre_det = a[:, None] * b[None, :]
+ det = pre_det - np.transpose(pre_det)
+
+ pre_inter_y = a[:, None] * c[None, :]
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+ pre_inter_x = c[:, None] * b[None, :]
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+
+ # 3. get corner information
+ # 3.1 get distance
+ '''
+ dist_segments:
+ | dist(0), dist(1), dist(2), ...|
+ dist_inter_to_segment1:
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+ ...
+ dist_inter_to_semgnet2:
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ ...
+ '''
+
+ dist_inter_to_segment1_start = np.sqrt(
+ np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment1_end = np.sqrt(
+ np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_start = np.sqrt(
+ np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_end = np.sqrt(
+ np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+
+ # sort ascending
+ dist_inter_to_segment1 = np.sort(
+ np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+ dist_inter_to_segment2 = np.sort(
+ np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+
+ # 3.2 get degree
+ inter_to_start = new_centers[:, None, :] - inter_pts
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+ inter_to_end = new_centers[None, :, :] - inter_pts
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+
+ '''
+ B -- G
+ | |
+ C -- R
+ B : blue / G: green / C: cyan / R: red
+
+ 0 -- 1
+ | |
+ 3 -- 2
+ '''
+ # rename variables
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+ # sort deg ascending
+ deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+
+ deg_diff_map = np.abs(deg1_map - deg2_map)
+ # we only consider the smallest degree of intersect
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+
+ # define available degree range
+ deg_range = [60, 120]
+
+ corner_dict = {corner_info: [] for corner_info in range(4)}
+ inter_points = []
+ for i in range(inter_pts.shape[0]):
+ for j in range(i + 1, inter_pts.shape[1]):
+ # i, j > line index, always i < j
+ x, y = inter_pts[i, j, :]
+ deg1, deg2 = deg_sort[i, j, :]
+ deg_diff = deg_diff_map[i, j]
+
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+
+ if check_degree and check_distance:
+ corner_info = None
+
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+ corner_info, color_info = 0, 'blue'
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+ corner_info, color_info = 1, 'green'
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+ corner_info, color_info = 2, 'black'
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+ corner_info, color_info = 3, 'cyan'
+ else:
+ corner_info, color_info = 4, 'red' # we don't use it
+ continue
+
+ corner_dict[corner_info].append([x, y, i, j])
+ inter_points.append([x, y])
+
+ square_list = []
+ connect_list = []
+ segments_list = []
+ for corner0 in corner_dict[0]:
+ for corner1 in corner_dict[1]:
+ connect01 = False
+ for corner0_line in corner0[2:]:
+ if corner0_line in corner1[2:]:
+ connect01 = True
+ break
+ if connect01:
+ for corner2 in corner_dict[2]:
+ connect12 = False
+ for corner1_line in corner1[2:]:
+ if corner1_line in corner2[2:]:
+ connect12 = True
+ break
+ if connect12:
+ for corner3 in corner_dict[3]:
+ connect23 = False
+ for corner2_line in corner2[2:]:
+ if corner2_line in corner3[2:]:
+ connect23 = True
+ break
+ if connect23:
+ for corner3_line in corner3[2:]:
+ if corner3_line in corner0[2:]:
+ # SQUARE!!!
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+ square_list:
+ order: 0 > 1 > 2 > 3
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ ...
+ connect_list:
+ order: 01 > 12 > 23 > 30
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ ...
+ segments_list:
+ order: 0 > 1 > 2 > 3
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ ...
+ '''
+ square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+ connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+ segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+
+ def check_outside_inside(segments_info, connect_idx):
+ # return 'outside or inside', min distance, cover_param, peri_param
+ if connect_idx == segments_info[0]:
+ check_dist_mat = dist_inter_to_segment1
+ else:
+ check_dist_mat = dist_inter_to_segment2
+
+ i, j = segments_info
+ min_dist, max_dist = check_dist_mat[i, j, :]
+ connect_dist = dist_segments[connect_idx]
+ if max_dist > connect_dist:
+ return 'outside', min_dist, 0, 1
+ else:
+ return 'inside', min_dist, -1, -1
+
+ top_square = None
+
+ try:
+ map_size = input_shape[0] / 2
+ squares = np.array(square_list).reshape([-1, 4, 2])
+ score_array = []
+ connect_array = np.array(connect_list)
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
+
+ # get degree of corners:
+ squares_rollup = np.roll(squares, 1, axis=1)
+ squares_rolldown = np.roll(squares, -1, axis=1)
+ vec1 = squares_rollup - squares
+ normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+ vec2 = squares_rolldown - squares
+ normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+ inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
+ squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
+
+ # get square score
+ overlap_scores = []
+ degree_scores = []
+ length_scores = []
+
+ for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+
+ # segments: [4, 2]
+ # connects: [4]
+ '''
+
+ ###################################### OVERLAP SCORES
+ cover = 0
+ perimeter = 0
+ # check 0 > 1 > 2 > 3
+ square_length = []
+
+ for start_idx in range(4):
+ end_idx = (start_idx + 1) % 4
+
+ connect_idx = connects[start_idx] # segment idx of segment01
+ start_segments = segments[start_idx]
+ end_segments = segments[end_idx]
+
+ start_point = square[start_idx]
+ end_point = square[end_idx]
+
+ # check whether outside or inside
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+ connect_idx)
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+
+ cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+ perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+
+ square_length.append(
+ dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+
+ overlap_scores.append(cover / perimeter)
+ ######################################
+ ###################################### DEGREE SCORES
+ '''
+ deg0 vs deg2
+ deg1 vs deg3
+ '''
+ deg0, deg1, deg2, deg3 = degree
+ deg_ratio1 = deg0 / deg2
+ if deg_ratio1 > 1.0:
+ deg_ratio1 = 1 / deg_ratio1
+ deg_ratio2 = deg1 / deg3
+ if deg_ratio2 > 1.0:
+ deg_ratio2 = 1 / deg_ratio2
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+ ######################################
+ ###################################### LENGTH SCORES
+ '''
+ len0 vs len2
+ len1 vs len3
+ '''
+ len0, len1, len2, len3 = square_length
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
+
+ ######################################
+
+ overlap_scores = np.array(overlap_scores)
+ overlap_scores /= np.max(overlap_scores)
+
+ degree_scores = np.array(degree_scores)
+ # degree_scores /= np.max(degree_scores)
+
+ length_scores = np.array(length_scores)
+
+ ###################################### AREA SCORES
+ area_scores = np.reshape(squares, [-1, 4, 2])
+ area_x = area_scores[:, :, 0]
+ area_y = area_scores[:, :, 1]
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+ area_scores = 0.5 * np.abs(area_scores + correction)
+ area_scores /= (map_size * map_size) # np.max(area_scores)
+ ######################################
+
+ ###################################### CENTER SCORES
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
+ # squares: [n, 4, 2]
+ square_centers = np.mean(squares, axis=1) # [n, 2]
+ center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+ center_scores = center2center / (map_size / np.sqrt(2.0))
+
+ '''
+ score_w = [overlap, degree, area, center, length]
+ '''
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+ score_array = params['w_overlap'] * overlap_scores \
+ + params['w_degree'] * degree_scores \
+ + params['w_area'] * area_scores \
+ - params['w_center'] * center_scores \
+ + params['w_length'] * length_scores
+
+ best_square = []
+
+ sorted_idx = np.argsort(score_array)[::-1]
+ score_array = score_array[sorted_idx]
+ squares = squares[sorted_idx]
+
+ except Exception as e:
+ pass
+
+ '''return list
+ merged_lines, squares, scores
+ '''
+
+ try:
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+ except:
+ new_segments = []
+
+ try:
+ squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+ squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ squares = []
+ score_array = []
+
+ try:
+ inter_points = np.array(inter_points)
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ inter_points = []
+
+ return new_segments, squares, score_array, inter_points
diff --git a/lib/model_zoo/controlnet_annotator/openpose/LICENSE b/lib/model_zoo/controlnet_annotator/openpose/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/LICENSE
@@ -0,0 +1,108 @@
+OPENPOSE: MULTIPERSON KEYPOINT DETECTION
+SOFTWARE LICENSE AGREEMENT
+ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
+
+BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
+
+This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
+
+RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
+Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
+non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
+
+CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
+
+COPYRIGHT: The Software is owned by Licensor and is protected by United
+States copyright laws and applicable international treaties and/or conventions.
+
+PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
+
+DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
+
+BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
+
+USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
+
+You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
+
+ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
+
+TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
+
+The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
+
+FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
+
+DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
+
+SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
+
+EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
+
+EXPORT REGULATION: Licensee agrees to comply with any and all applicable
+U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
+
+SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
+
+NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
+
+GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
+
+ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
+
+
+
+************************************************************************
+
+THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
+
+This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
+
+1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
+
+COPYRIGHT
+
+All contributions by the University of California:
+Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+All rights reserved.
+
+All other contributions:
+Copyright (c) 2014-2017, the respective contributors
+All rights reserved.
+
+Caffe uses a shared copyright model: each contributor holds copyright over
+their contributions to Caffe. The project versioning records all such
+contribution and copyright details. If a contributor wants to further mark
+their specific copyright on a particular contribution, they should indicate
+their copyright solely in the commit message of the change when it is
+committed.
+
+LICENSE
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+CONTRIBUTION AGREEMENT
+
+By contributing to the BVLC/caffe repository through pull-request, comment,
+or otherwise, the contributor releases their content to the
+license and copyright terms herein.
+
+************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/openpose/__init__.py b/lib/model_zoo/controlnet_annotator/openpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae4fcd9683118a9ffb966ccc517bd3451dce28ab
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/__init__.py
@@ -0,0 +1,320 @@
+# Openpose
+# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
+# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
+# 3rd Edited by ControlNet
+# 4th Edited by ControlNet (added face and correct hands)
+# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
+# This preprocessor is licensed by CMU for non-commercial use only.
+
+
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import json
+import torch
+import numpy as np
+from . import util
+from .body import Body, BodyResult, Keypoint
+from .hand import Hand
+from .face import Face
+
+models_path = "pretrained/controlnet/preprocess"
+
+from typing import NamedTuple, Tuple, List, Callable, Union
+
+body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
+hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth"
+face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth"
+
+HandResult = List[Keypoint]
+FaceResult = List[Keypoint]
+
+class PoseResult(NamedTuple):
+ body: BodyResult
+ left_hand: Union[HandResult, None]
+ right_hand: Union[HandResult, None]
+ face: Union[FaceResult, None]
+
+def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
+ """
+ Draw the detected poses on an empty canvas.
+
+ Args:
+ poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
+ H (int): The height of the canvas.
+ W (int): The width of the canvas.
+ draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
+ draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
+ draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
+
+ Returns:
+ numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
+ """
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
+
+ for pose in poses:
+ if draw_body:
+ canvas = util.draw_bodypose(canvas, pose.body.keypoints)
+
+ if draw_hand:
+ canvas = util.draw_handpose(canvas, pose.left_hand)
+ canvas = util.draw_handpose(canvas, pose.right_hand)
+
+ if draw_face:
+ canvas = util.draw_facepose(canvas, pose.face)
+
+ return canvas
+
+def encode_poses_as_json(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str:
+ """ Encode the pose as a JSON string following openpose JSON output format:
+ https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md
+ """
+ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]:
+ if not keypoints:
+ return None
+
+ return [
+ value
+ for keypoint in keypoints
+ for value in (
+ [float(keypoint.x), float(keypoint.y), 1.0]
+ if keypoint is not None
+ else [0.0, 0.0, 0.0]
+ )
+ ]
+
+ return json.dumps({
+ 'people': [
+ {
+ 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints),
+ "face_keypoints_2d": compress_keypoints(pose.face),
+ "hand_left_keypoints_2d": compress_keypoints(pose.left_hand),
+ "hand_right_keypoints_2d":compress_keypoints(pose.right_hand),
+ }
+ for pose in poses
+ ],
+ 'canvas_height': canvas_height,
+ 'canvas_width': canvas_width,
+ }, indent=4)
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ from torch.hub import download_url_to_file, get_dir
+ from urllib.parse import urlparse
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+class OpenposeDetector:
+ """
+ A class for detecting human poses in images using the Openpose model.
+
+ Attributes:
+ model_dir (str): Path to the directory where the pose models are stored.
+ """
+ model_dir = os.path.join(models_path, "openpose")
+
+ def __init__(self, device):
+ self.device = device
+ self.body_estimation = None
+ self.hand_estimation = None
+ self.face_estimation = None
+
+ def load_model(self):
+ """
+ Load the Openpose body, hand, and face models.
+ """
+ body_modelpath = os.path.join(self.model_dir, "body_pose_model.pth")
+ hand_modelpath = os.path.join(self.model_dir, "hand_pose_model.pth")
+ face_modelpath = os.path.join(self.model_dir, "facenet.pth")
+
+ if not os.path.exists(body_modelpath):
+ load_file_from_url(body_model_path, model_dir=self.model_dir)
+
+ if not os.path.exists(hand_modelpath):
+ load_file_from_url(hand_model_path, model_dir=self.model_dir)
+
+ if not os.path.exists(face_modelpath):
+ load_file_from_url(face_model_path, model_dir=self.model_dir)
+
+ self.body_estimation = Body(body_modelpath)
+ self.hand_estimation = Hand(hand_modelpath)
+ self.face_estimation = Face(face_modelpath)
+
+ def unload_model(self):
+ """
+ Unload the Openpose models by moving them to the CPU.
+ """
+ if self.body_estimation is not None:
+ self.body_estimation.model.to("cpu")
+ self.hand_estimation.model.to("cpu")
+ self.face_estimation.model.to("cpu")
+
+ def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
+ left_hand = None
+ right_hand = None
+ H, W, _ = oriImg.shape
+ for x, y, w, is_left in util.handDetect(body, oriImg):
+ peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
+ if peaks.ndim == 2 and peaks.shape[1] == 2:
+ peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
+ peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
+
+ hand_result = [
+ Keypoint(x=peak[0], y=peak[1])
+ for peak in peaks
+ ]
+
+ if is_left:
+ left_hand = hand_result
+ else:
+ right_hand = hand_result
+
+ return left_hand, right_hand
+
+ def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
+ face = util.faceDetect(body, oriImg)
+ if face is None:
+ return None
+
+ x, y, w = face
+ H, W, _ = oriImg.shape
+ heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
+ peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
+ if peaks.ndim == 2 and peaks.shape[1] == 2:
+ peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
+ peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
+ return [
+ Keypoint(x=peak[0], y=peak[1])
+ for peak in peaks
+ ]
+
+ return None
+
+ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
+ """
+ Detect poses in the given image.
+ Args:
+ oriImg (numpy.ndarray): The input image for pose detection.
+ include_hand (bool, optional): Whether to include hand detection. Defaults to False.
+ include_face (bool, optional): Whether to include face detection. Defaults to False.
+
+ Returns:
+ List[PoseResult]: A list of PoseResult objects containing the detected poses.
+ """
+ if self.body_estimation is None:
+ self.load_model()
+
+ self.body_estimation.model.to(self.device)
+ self.hand_estimation.model.to(self.device)
+ self.face_estimation.model.to(self.device)
+
+ self.body_estimation.cn_device = self.device
+ self.hand_estimation.cn_device = self.device
+ self.face_estimation.cn_device = self.device
+
+ oriImg = oriImg[:, :, ::-1].copy()
+ H, W, C = oriImg.shape
+ with torch.no_grad():
+ candidate, subset = self.body_estimation(oriImg)
+ bodies = self.body_estimation.format_body_result(candidate, subset)
+
+ results = []
+ for body in bodies:
+ left_hand, right_hand, face = (None,) * 3
+ if include_hand:
+ left_hand, right_hand = self.detect_hands(body, oriImg)
+ if include_face:
+ face = self.detect_face(body, oriImg)
+
+ results.append(PoseResult(BodyResult(
+ keypoints=[
+ Keypoint(
+ x=keypoint.x / float(W),
+ y=keypoint.y / float(H)
+ ) if keypoint is not None else None
+ for keypoint in body.keypoints
+ ],
+ total_score=body.total_score,
+ total_parts=body.total_parts
+ ), left_hand, right_hand, face))
+
+ return results
+
+ def __call__(
+ self, oriImg, include_body=True, include_hand=False, include_face=False,
+ json_pose_callback: Callable[[str], None] = None,
+ ):
+ """
+ Detect and draw poses in the given image.
+
+ Args:
+ oriImg (numpy.ndarray): The input image for pose detection and drawing.
+ include_body (bool, optional): Whether to include body keypoints. Defaults to True.
+ include_hand (bool, optional): Whether to include hand keypoints. Defaults to False.
+ include_face (bool, optional): Whether to include face keypoints. Defaults to False.
+ json_pose_callback (Callable, optional): A callback that accepts the pose JSON string.
+
+ Returns:
+ numpy.ndarray: The image with detected and drawn poses.
+ """
+ H, W, _ = oriImg.shape
+ poses = self.detect_poses(oriImg, include_hand, include_face)
+ if json_pose_callback:
+ json_pose_callback(encode_poses_as_json(poses, H, W))
+ return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
+
+class OpenposeModel(object):
+ def __init__(self) -> None:
+ self.model_openpose = None
+
+ def run_model(
+ self,
+ img: np.ndarray,
+ include_body: bool,
+ include_hand: bool,
+ include_face: bool,
+ json_pose_callback: Callable[[str], None] = None,
+ device = 'cpu', ):
+
+ if json_pose_callback is None:
+ json_pose_callback = lambda x: None
+
+ if self.model_openpose is None:
+ self.model_openpose = OpenposeDetector(device=device)
+
+ return self.model_openpose(
+ img,
+ include_body=include_body,
+ include_hand=include_hand,
+ include_face=include_face,
+ json_pose_callback=json_pose_callback)
+
+ def unload(self):
+ if self.model_openpose is not None:
+ self.model_openpose.unload_model()
diff --git a/lib/model_zoo/controlnet_annotator/openpose/body.py b/lib/model_zoo/controlnet_annotator/openpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..11b10b8db047be9b88f5f0756592fdbae3d85027
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/body.py
@@ -0,0 +1,278 @@
+import cv2
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+from typing import NamedTuple, List, Union
+
+from . import util
+from .model import bodypose_model
+
+class Keypoint(NamedTuple):
+ x: float
+ y: float
+ score: float = 1.0
+ id: int = -1
+
+
+class BodyResult(NamedTuple):
+ # Note: Using `Union` instead of `|` operator as the ladder is a Python
+ # 3.10 feature.
+ # Annotator code should be Python 3.8 Compatible, as controlnet repo uses
+ # Python 3.8 environment.
+ # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
+ keypoints: List[Union[Keypoint, None]]
+ total_score: float
+ total_parts: int
+
+
+class Body(object):
+ def __init__(self, model_path):
+ self.model = bodypose_model()
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
+ scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre1 = 0.1
+ thre2 = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ data = data.to(self.cn_device)
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
+
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
+ paf = util.smart_resize_k(paf, fx=stride, fy=stride)
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
+
+ heatmap_avg += heatmap_avg + heatmap / len(multiplier)
+ paf_avg += + paf / len(multiplier)
+
+ all_peaks = []
+ peak_counter = 0
+
+ for part in range(18):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+ map_left = np.zeros(one_heatmap.shape)
+ map_left[1:, :] = one_heatmap[:-1, :]
+ map_right = np.zeros(one_heatmap.shape)
+ map_right[:-1, :] = one_heatmap[1:, :]
+ map_up = np.zeros(one_heatmap.shape)
+ map_up[:, 1:] = one_heatmap[:, :-1]
+ map_down = np.zeros(one_heatmap.shape)
+ map_down[:, :-1] = one_heatmap[:, 1:]
+
+ peaks_binary = np.logical_and.reduce(
+ (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
+ peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+ peak_id = range(peak_counter, peak_counter + len(peaks))
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+ all_peaks.append(peaks_with_score_and_id)
+ peak_counter += len(peaks)
+
+ # find connection in the specified sequence, center 29 is in the position 15
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+ # the middle joints heatmap correpondence
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+ [55, 56], [37, 38], [45, 46]]
+
+ connection_all = []
+ special_k = []
+ mid_num = 10
+
+ for k in range(len(mapIdx)):
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+ candA = all_peaks[limbSeq[k][0] - 1]
+ candB = all_peaks[limbSeq[k][1] - 1]
+ nA = len(candA)
+ nB = len(candB)
+ indexA, indexB = limbSeq[k]
+ if (nA != 0 and nB != 0):
+ connection_candidate = []
+ for i in range(nA):
+ for j in range(nB):
+ vec = np.subtract(candB[j][:2], candA[i][:2])
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+ norm = max(0.001, norm)
+ vec = np.divide(vec, norm)
+
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+ for I in range(len(startend))])
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+ for I in range(len(startend))])
+
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+ criterion2 = score_with_dist_prior > 0
+ if criterion1 and criterion2:
+ connection_candidate.append(
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+ connection = np.zeros((0, 5))
+ for c in range(len(connection_candidate)):
+ i, j, s = connection_candidate[c][0:3]
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+ if (len(connection) >= min(nA, nB)):
+ break
+
+ connection_all.append(connection)
+ else:
+ special_k.append(k)
+ connection_all.append([])
+
+ # last number in each row is the total parts number of that person
+ # the second last number in each row is the score of the overall configuration
+ subset = -1 * np.ones((0, 20))
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+ for k in range(len(mapIdx)):
+ if k not in special_k:
+ partAs = connection_all[k][:, 0]
+ partBs = connection_all[k][:, 1]
+ indexA, indexB = np.array(limbSeq[k]) - 1
+
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
+ found = 0
+ subset_idx = [-1, -1]
+ for j in range(len(subset)): # 1:size(subset,1):
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+ subset_idx[found] = j
+ found += 1
+
+ if found == 1:
+ j = subset_idx[0]
+ if subset[j][indexB] != partBs[i]:
+ subset[j][indexB] = partBs[i]
+ subset[j][-1] += 1
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+ elif found == 2: # if found 2 and disjoint, merge them
+ j1, j2 = subset_idx
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
+ subset[j1][-2:] += subset[j2][-2:]
+ subset[j1][-2] += connection_all[k][i][2]
+ subset = np.delete(subset, j2, 0)
+ else: # as like found == 1
+ subset[j1][indexB] = partBs[i]
+ subset[j1][-1] += 1
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+ # if find no partA in the subset, create a new subset
+ elif not found and k < 17:
+ row = -1 * np.ones(20)
+ row[indexA] = partAs[i]
+ row[indexB] = partBs[i]
+ row[-1] = 2
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+ subset = np.vstack([subset, row])
+ # delete some rows of subset which has few parts occur
+ deleteIdx = []
+ for i in range(len(subset)):
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+ deleteIdx.append(i)
+ subset = np.delete(subset, deleteIdx, axis=0)
+
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+ # candidate: x, y, score, id
+ return candidate, subset
+
+ @staticmethod
+ def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
+ """
+ Format the body results from the candidate and subset arrays into a list of BodyResult objects.
+
+ Args:
+ candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
+ for each body part.
+ subset (np.ndarray): An array of subsets containing indices to the candidate array for each
+ person detected. The last two columns of each row hold the total score and total parts
+ of the person.
+
+ Returns:
+ List[BodyResult]: A list of BodyResult objects, where each object represents a person with
+ detected keypoints, total score, and total parts.
+ """
+ return [
+ BodyResult(
+ keypoints=[
+ Keypoint(
+ x=candidate[candidate_index][0],
+ y=candidate[candidate_index][1],
+ score=candidate[candidate_index][2],
+ id=candidate[candidate_index][3]
+ ) if candidate_index != -1 else None
+ for candidate_index in person[:18].astype(int)
+ ],
+ total_score=person[18],
+ total_parts=person[19]
+ )
+ for person in subset
+ ]
+
+
+if __name__ == "__main__":
+ body_estimation = Body('../model/body_pose_model.pth')
+
+ test_image = '../images/ski.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ candidate, subset = body_estimation(oriImg)
+ bodies = body_estimation.format_body_result(candidate, subset)
+
+ canvas = oriImg
+ for body in bodies:
+ canvas = util.draw_bodypose(canvas, body)
+
+ plt.imshow(canvas[:, :, [2, 1, 0]])
+ plt.show()
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/openpose/face.py b/lib/model_zoo/controlnet_annotator/openpose/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3c46d77664aa9fa91c63785a1485a396f05cacc
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/face.py
@@ -0,0 +1,362 @@
+import logging
+import numpy as np
+from torchvision.transforms import ToTensor, ToPILImage
+import torch
+import torch.nn.functional as F
+import cv2
+
+from . import util
+from torch.nn import Conv2d, Module, ReLU, MaxPool2d, init
+
+
+class FaceNet(Module):
+ """Model the cascading heatmaps. """
+ def __init__(self):
+ super(FaceNet, self).__init__()
+ # cnn to make feature map
+ self.relu = ReLU()
+ self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
+ self.conv1_1 = Conv2d(in_channels=3, out_channels=64,
+ kernel_size=3, stride=1, padding=1)
+ self.conv1_2 = Conv2d(
+ in_channels=64, out_channels=64, kernel_size=3, stride=1,
+ padding=1)
+ self.conv2_1 = Conv2d(
+ in_channels=64, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+ self.conv2_2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_1 = Conv2d(
+ in_channels=128, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_2 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_3 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv3_4 = Conv2d(
+ in_channels=256, out_channels=256, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_1 = Conv2d(
+ in_channels=256, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_2 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_3 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv4_4 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_1 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_2 = Conv2d(
+ in_channels=512, out_channels=512, kernel_size=3, stride=1,
+ padding=1)
+ self.conv5_3_CPM = Conv2d(
+ in_channels=512, out_channels=128, kernel_size=3, stride=1,
+ padding=1)
+
+ # stage1
+ self.conv6_1_CPM = Conv2d(
+ in_channels=128, out_channels=512, kernel_size=1, stride=1,
+ padding=0)
+ self.conv6_2_CPM = Conv2d(
+ in_channels=512, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage2
+ self.Mconv1_stage2 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage2 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage2 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage3
+ self.Mconv1_stage3 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage3 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage3 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage4
+ self.Mconv1_stage4 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage4 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage4 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage5
+ self.Mconv1_stage5 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage5 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage5 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ # stage6
+ self.Mconv1_stage6 = Conv2d(
+ in_channels=199, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv2_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv3_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv4_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv5_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=7, stride=1,
+ padding=3)
+ self.Mconv6_stage6 = Conv2d(
+ in_channels=128, out_channels=128, kernel_size=1, stride=1,
+ padding=0)
+ self.Mconv7_stage6 = Conv2d(
+ in_channels=128, out_channels=71, kernel_size=1, stride=1,
+ padding=0)
+
+ for m in self.modules():
+ if isinstance(m, Conv2d):
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ """Return a list of heatmaps."""
+ heatmaps = []
+
+ h = self.relu(self.conv1_1(x))
+ h = self.relu(self.conv1_2(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv2_1(h))
+ h = self.relu(self.conv2_2(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv3_1(h))
+ h = self.relu(self.conv3_2(h))
+ h = self.relu(self.conv3_3(h))
+ h = self.relu(self.conv3_4(h))
+ h = self.max_pooling_2d(h)
+ h = self.relu(self.conv4_1(h))
+ h = self.relu(self.conv4_2(h))
+ h = self.relu(self.conv4_3(h))
+ h = self.relu(self.conv4_4(h))
+ h = self.relu(self.conv5_1(h))
+ h = self.relu(self.conv5_2(h))
+ h = self.relu(self.conv5_3_CPM(h))
+ feature_map = h
+
+ # stage1
+ h = self.relu(self.conv6_1_CPM(h))
+ h = self.conv6_2_CPM(h)
+ heatmaps.append(h)
+
+ # stage2
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage2(h))
+ h = self.relu(self.Mconv2_stage2(h))
+ h = self.relu(self.Mconv3_stage2(h))
+ h = self.relu(self.Mconv4_stage2(h))
+ h = self.relu(self.Mconv5_stage2(h))
+ h = self.relu(self.Mconv6_stage2(h))
+ h = self.Mconv7_stage2(h)
+ heatmaps.append(h)
+
+ # stage3
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage3(h))
+ h = self.relu(self.Mconv2_stage3(h))
+ h = self.relu(self.Mconv3_stage3(h))
+ h = self.relu(self.Mconv4_stage3(h))
+ h = self.relu(self.Mconv5_stage3(h))
+ h = self.relu(self.Mconv6_stage3(h))
+ h = self.Mconv7_stage3(h)
+ heatmaps.append(h)
+
+ # stage4
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage4(h))
+ h = self.relu(self.Mconv2_stage4(h))
+ h = self.relu(self.Mconv3_stage4(h))
+ h = self.relu(self.Mconv4_stage4(h))
+ h = self.relu(self.Mconv5_stage4(h))
+ h = self.relu(self.Mconv6_stage4(h))
+ h = self.Mconv7_stage4(h)
+ heatmaps.append(h)
+
+ # stage5
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage5(h))
+ h = self.relu(self.Mconv2_stage5(h))
+ h = self.relu(self.Mconv3_stage5(h))
+ h = self.relu(self.Mconv4_stage5(h))
+ h = self.relu(self.Mconv5_stage5(h))
+ h = self.relu(self.Mconv6_stage5(h))
+ h = self.Mconv7_stage5(h)
+ heatmaps.append(h)
+
+ # stage6
+ h = torch.cat([h, feature_map], dim=1) # channel concat
+ h = self.relu(self.Mconv1_stage6(h))
+ h = self.relu(self.Mconv2_stage6(h))
+ h = self.relu(self.Mconv3_stage6(h))
+ h = self.relu(self.Mconv4_stage6(h))
+ h = self.relu(self.Mconv5_stage6(h))
+ h = self.relu(self.Mconv6_stage6(h))
+ h = self.Mconv7_stage6(h)
+ heatmaps.append(h)
+
+ return heatmaps
+
+
+LOG = logging.getLogger(__name__)
+TOTEN = ToTensor()
+TOPIL = ToPILImage()
+
+
+params = {
+ 'gaussian_sigma': 2.5,
+ 'inference_img_size': 736, # 368, 736, 1312
+ 'heatmap_peak_thresh': 0.1,
+ 'crop_scale': 1.5,
+ 'line_indices': [
+ [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
+ [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
+ [13, 14], [14, 15], [15, 16],
+ [17, 18], [18, 19], [19, 20], [20, 21],
+ [22, 23], [23, 24], [24, 25], [25, 26],
+ [27, 28], [28, 29], [29, 30],
+ [31, 32], [32, 33], [33, 34], [34, 35],
+ [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
+ [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
+ [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
+ [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
+ [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
+ [66, 67], [67, 60]
+ ],
+}
+
+
+class Face(object):
+ """
+ The OpenPose face landmark detector model.
+
+ Args:
+ inference_size: set the size of the inference image size, suggested:
+ 368, 736, 1312, default 736
+ gaussian_sigma: blur the heatmaps, default 2.5
+ heatmap_peak_thresh: return landmark if over threshold, default 0.1
+
+ """
+ def __init__(self, face_model_path,
+ inference_size=None,
+ gaussian_sigma=None,
+ heatmap_peak_thresh=None):
+ self.inference_size = inference_size or params["inference_img_size"]
+ self.sigma = gaussian_sigma or params['gaussian_sigma']
+ self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
+ self.model = FaceNet()
+ self.model.load_state_dict(torch.load(face_model_path))
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
+ self.model.eval()
+
+ def __call__(self, face_img):
+ H, W, C = face_img.shape
+
+ w_size = 384
+ x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
+
+ x_data = x_data.to(self.cn_device)
+
+ with torch.no_grad():
+ hs = self.model(x_data[None, ...])
+ heatmaps = F.interpolate(
+ hs[-1],
+ (H, W),
+ mode='bilinear', align_corners=True).cpu().numpy()[0]
+ return heatmaps
+
+ def compute_peaks_from_heatmaps(self, heatmaps):
+ all_peaks = []
+ for part in range(heatmaps.shape[0]):
+ map_ori = heatmaps[part].copy()
+ binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
+
+ if np.sum(binary) == 0:
+ continue
+
+ positions = np.where(binary > 0.5)
+ intensities = map_ori[positions]
+ mi = np.argmax(intensities)
+ y, x = positions[0][mi], positions[1][mi]
+ all_peaks.append([x, y])
+
+ return np.array(all_peaks)
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/openpose/hand.py b/lib/model_zoo/controlnet_annotator/openpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..74767def506c72612954fe3b79056d17a83b1e16
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/hand.py
@@ -0,0 +1,94 @@
+import cv2
+import json
+import numpy as np
+import math
+import time
+from scipy.ndimage.filters import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from skimage.measure import label
+
+from .model import handpose_model
+from . import util
+
+class Hand(object):
+ def __init__(self, model_path):
+ self.model = handpose_model()
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImgRaw):
+ scale_search = [0.5, 1.0, 1.5, 2.0]
+ # scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre = 0.05
+ multiplier = [x * boxsize for x in scale_search]
+
+ wsize = 128
+ heatmap_avg = np.zeros((wsize, wsize, 22))
+
+ Hr, Wr, Cr = oriImgRaw.shape
+
+ oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = util.smart_resize(oriImg, (scale, scale))
+
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+
+ with torch.no_grad():
+ data = data.to(self.cn_device)
+ output = self.model(data).cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = util.smart_resize(heatmap, (wsize, wsize))
+
+ heatmap_avg += heatmap / len(multiplier)
+
+ all_peaks = []
+ for part in range(21):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+
+ if np.sum(binary) == 0:
+ all_peaks.append([0, 0])
+ continue
+ label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+ max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+ label_img[label_img != max_index] = 0
+ map_ori[label_img == 0] = 0
+
+ y, x = util.npmax(map_ori)
+ y = int(float(y) * float(Hr) / float(wsize))
+ x = int(float(x) * float(Wr) / float(wsize))
+ all_peaks.append([x, y])
+ return np.array(all_peaks)
+
+if __name__ == "__main__":
+ hand_estimation = Hand('../model/hand_pose_model.pth')
+
+ # test_image = '../images/hand.jpg'
+ test_image = '../images/hand.jpg'
+ oriImg = cv2.imread(test_image) # B,G,R order
+ peaks = hand_estimation(oriImg)
+ canvas = util.draw_handpose(oriImg, peaks, True)
+ cv2.imshow('', canvas)
+ cv2.waitKey(0)
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/openpose/model.py b/lib/model_zoo/controlnet_annotator/openpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..72dc79ad857933a7c108d21494d6395572b816e6
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/model.py
@@ -0,0 +1,218 @@
+import torch
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+def make_layers(block, no_relu_layers):
+ layers = []
+ for layer_name, v in block.items():
+ if 'pool' in layer_name:
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+ padding=v[2])
+ layers.append((layer_name, layer))
+ else:
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+ kernel_size=v[2], stride=v[3],
+ padding=v[4])
+ layers.append((layer_name, conv2d))
+ if layer_name not in no_relu_layers:
+ layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+ return nn.Sequential(OrderedDict(layers))
+
+class bodypose_model(nn.Module):
+ def __init__(self):
+ super(bodypose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+ 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
+ blocks = {}
+ block0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+ ('conv4_4_CPM', [256, 128, 3, 1, 1])
+ ])
+
+
+ # Stage 1
+ block1_1 = OrderedDict([
+ ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+ ])
+
+ block1_2 = OrderedDict([
+ ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+ ])
+ blocks['block1_1'] = block1_1
+ blocks['block1_2'] = block1_2
+
+ self.model0 = make_layers(block0, no_relu_layers)
+
+ # Stages 2 - 6
+ for i in range(2, 7):
+ blocks['block%d_1' % i] = OrderedDict([
+ ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+ ])
+
+ blocks['block%d_2' % i] = OrderedDict([
+ ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_1 = blocks['block1_1']
+ self.model2_1 = blocks['block2_1']
+ self.model3_1 = blocks['block3_1']
+ self.model4_1 = blocks['block4_1']
+ self.model5_1 = blocks['block5_1']
+ self.model6_1 = blocks['block6_1']
+
+ self.model1_2 = blocks['block1_2']
+ self.model2_2 = blocks['block2_2']
+ self.model3_2 = blocks['block3_2']
+ self.model4_2 = blocks['block4_2']
+ self.model5_2 = blocks['block5_2']
+ self.model6_2 = blocks['block6_2']
+
+
+ def forward(self, x):
+
+ out1 = self.model0(x)
+
+ out1_1 = self.model1_1(out1)
+ out1_2 = self.model1_2(out1)
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+ out2_1 = self.model2_1(out2)
+ out2_2 = self.model2_2(out2)
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+ out3_1 = self.model3_1(out3)
+ out3_2 = self.model3_2(out3)
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+ out4_1 = self.model4_1(out4)
+ out4_2 = self.model4_2(out4)
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+ out5_1 = self.model5_1(out5)
+ out5_2 = self.model5_2(out5)
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+ out6_1 = self.model6_1(out6)
+ out6_2 = self.model6_2(out6)
+
+ return out6_1, out6_2
+
+class handpose_model(nn.Module):
+ def __init__(self):
+ super(handpose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+ # stage 1
+ block1_0 = OrderedDict([
+ ('conv1_1', [3, 64, 3, 1, 1]),
+ ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3', [512, 512, 3, 1, 1]),
+ ('conv4_4', [512, 512, 3, 1, 1]),
+ ('conv5_1', [512, 512, 3, 1, 1]),
+ ('conv5_2', [512, 512, 3, 1, 1]),
+ ('conv5_3_CPM', [512, 128, 3, 1, 1])
+ ])
+
+ block1_1 = OrderedDict([
+ ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+ ('conv6_2_CPM', [512, 22, 1, 1, 0])
+ ])
+
+ blocks = {}
+ blocks['block1_0'] = block1_0
+ blocks['block1_1'] = block1_1
+
+ # stage 2-6
+ for i in range(2, 7):
+ blocks['block%d' % i] = OrderedDict([
+ ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+ ])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_0 = blocks['block1_0']
+ self.model1_1 = blocks['block1_1']
+ self.model2 = blocks['block2']
+ self.model3 = blocks['block3']
+ self.model4 = blocks['block4']
+ self.model5 = blocks['block5']
+ self.model6 = blocks['block6']
+
+ def forward(self, x):
+ out1_0 = self.model1_0(x)
+ out1_1 = self.model1_1(out1_0)
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
+ out_stage2 = self.model2(concat_stage2)
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+ out_stage3 = self.model3(concat_stage3)
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+ out_stage4 = self.model4(concat_stage4)
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+ out_stage5 = self.model5(concat_stage5)
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+ out_stage6 = self.model6(concat_stage6)
+ return out_stage6
+
diff --git a/lib/model_zoo/controlnet_annotator/openpose/util.py b/lib/model_zoo/controlnet_annotator/openpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0851ca409863dcee4bf731a47b472992569dd68
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/openpose/util.py
@@ -0,0 +1,383 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+from typing import List, Tuple, Union
+
+from .body import BodyResult, Keypoint
+
+eps = 0.01
+
+
+def smart_resize(x, s):
+ Ht, Wt = s
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
+
+
+def smart_resize_k(x, fx, fy):
+ if x.ndim == 2:
+ Ho, Wo = x.shape
+ Co = 1
+ else:
+ Ho, Wo, Co = x.shape
+ Ht, Wt = Ho * fy, Wo * fx
+ if Co == 3 or Co == 1:
+ k = float(Ht + Wt) / float(Ho + Wo)
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
+ else:
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
+
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
+
+
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+
+def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
+ """
+ Draw keypoints and limbs representing body pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
+ keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ H, W, C = canvas.shape
+ stickwidth = 4
+
+ limbSeq = [
+ [2, 3], [2, 6], [3, 4], [4, 5],
+ [6, 7], [7, 8], [2, 9], [9, 10],
+ [10, 11], [2, 12], [12, 13], [13, 14],
+ [2, 1], [1, 15], [15, 17], [1, 16],
+ [16, 18],
+ ]
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ for (k1_index, k2_index), color in zip(limbSeq, colors):
+ keypoint1 = keypoints[k1_index - 1]
+ keypoint2 = keypoints[k2_index - 1]
+
+ if keypoint1 is None or keypoint2 is None:
+ continue
+
+ Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
+ X = np.array([keypoint1.y, keypoint2.y]) * float(H)
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
+
+ for keypoint, color in zip(keypoints, colors):
+ if keypoint is None:
+ continue
+
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
+
+ return canvas
+
+
+def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
+ """
+ Draw keypoints and connections representing hand pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
+ or None if no keypoints are present.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not keypoints:
+ return canvas
+
+ H, W, C = canvas.shape
+
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+ for ie, (e1, e2) in enumerate(edges):
+ k1 = keypoints[e1]
+ k2 = keypoints[e2]
+ if k1 is None or k2 is None:
+ continue
+
+ x1 = int(k1.x * W)
+ y1 = int(k1.y * H)
+ x2 = int(k2.x * W)
+ y2 = int(k2.y * H)
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
+
+ for keypoint in keypoints:
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+ return canvas
+
+
+def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
+ """
+ Draw keypoints representing face pose on a given canvas.
+
+ Args:
+ canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
+ keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
+ or None if no keypoints are present.
+
+ Returns:
+ np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
+
+ Note:
+ The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
+ """
+ if not keypoints:
+ return canvas
+
+ H, W, C = canvas.shape
+ for keypoint in keypoints:
+ x, y = keypoint.x, keypoint.y
+ x = int(x * W)
+ y = int(y * H)
+ if x > eps and y > eps:
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
+ return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
+ """
+ Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
+
+ Args:
+ body (BodyResult): A BodyResult object containing the detected body pose keypoints.
+ oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
+
+ Returns:
+ List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
+ corner of the bounding box, the width (height) of the bounding box, and
+ a boolean flag indicating whether the hand is a left hand (True) or a
+ right hand (False).
+
+ Notes:
+ - The width and height of the bounding boxes are equal since the network requires squared input.
+ - The minimum bounding box size is 20 pixels.
+ """
+ ratioWristElbow = 0.33
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+
+ keypoints = body.keypoints
+ # right hand: wrist 4, elbow 3, shoulder 2
+ # left hand: wrist 7, elbow 6, shoulder 5
+ left_shoulder = keypoints[5]
+ left_elbow = keypoints[6]
+ left_wrist = keypoints[7]
+ right_shoulder = keypoints[2]
+ right_elbow = keypoints[3]
+ right_wrist = keypoints[4]
+
+ # if any of three not detected
+ has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
+ has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
+ if not (has_left or has_right):
+ return []
+
+ hands = []
+ #left hand
+ if has_left:
+ hands.append([
+ left_shoulder.x, left_shoulder.y,
+ left_elbow.x, left_elbow.y,
+ left_wrist.x, left_wrist.y,
+ True
+ ])
+ # right hand
+ if has_right:
+ hands.append([
+ right_shoulder.x, right_shoulder.y,
+ right_elbow.x, right_elbow.y,
+ right_wrist.x, right_wrist.y,
+ False
+ ])
+
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+ x = x3 + ratioWristElbow * (x3 - x2)
+ y = y3 + ratioWristElbow * (y3 - y2)
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+ # x-y refers to the center --> offset to topLeft point
+ # handRectangle.x -= handRectangle.width / 2.f;
+ # handRectangle.y -= handRectangle.height / 2.f;
+ x -= width / 2
+ y -= width / 2 # width = height
+ # overflow the image
+ if x < 0: x = 0
+ if y < 0: y = 0
+ width1 = width
+ width2 = width
+ if x + width > image_width: width1 = image_width - x
+ if y + width > image_height: width2 = image_height - y
+ width = min(width1, width2)
+ # the max hand box value is 20 pixels
+ if width >= 20:
+ detect_result.append((int(x), int(y), int(width), is_left))
+
+ '''
+ return value: [[x, y, w, True if left hand else False]].
+ width=height since the network require squared input.
+ x, y is the coordinate of top left
+ '''
+ return detect_result
+
+
+# Written by Lvmin
+def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
+ """
+ Detect the face in the input body pose keypoints and calculate the bounding box for the face.
+
+ Args:
+ body (BodyResult): A BodyResult object containing the detected body pose keypoints.
+ oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
+
+ Returns:
+ Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
+ bounding box and the width (height) of the bounding box, or None if the
+ face is not detected or the bounding box width is less than 20 pixels.
+
+ Notes:
+ - The width and height of the bounding box are equal.
+ - The minimum bounding box size is 20 pixels.
+ """
+ # left right eye ear 14 15 16 17
+ image_height, image_width = oriImg.shape[0:2]
+
+ keypoints = body.keypoints
+ head = keypoints[0]
+ left_eye = keypoints[14]
+ right_eye = keypoints[15]
+ left_ear = keypoints[16]
+ right_ear = keypoints[17]
+
+ if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
+ return None
+
+ width = 0.0
+ x0, y0 = head.x, head.y
+
+ if left_eye is not None:
+ x1, y1 = left_eye.x, left_eye.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if right_eye is not None:
+ x1, y1 = right_eye.x, right_eye.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 3.0)
+
+ if left_ear is not None:
+ x1, y1 = left_ear.x, left_ear.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ if right_ear is not None:
+ x1, y1 = right_ear.x, right_ear.y
+ d = max(abs(x0 - x1), abs(y0 - y1))
+ width = max(width, d * 1.5)
+
+ x, y = x0, y0
+
+ x -= width
+ y -= width
+
+ if x < 0:
+ x = 0
+
+ if y < 0:
+ y = 0
+
+ width1 = width * 2
+ width2 = width * 2
+
+ if x + width > image_width:
+ width1 = image_width - x
+
+ if y + width > image_height:
+ width2 = image_height - y
+
+ width = min(width1, width2)
+
+ if width >= 20:
+ return int(x), int(y), int(width)
+ else:
+ return None
+
+
+# get max index of 2d array
+def npmax(array):
+ arrayindex = array.argmax(1)
+ arrayvalue = array.max(1)
+ i = arrayvalue.argmax()
+ j = arrayindex[i]
+ return i, j
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/pidinet/LICENSE b/lib/model_zoo/controlnet_annotator/pidinet/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..913b6cf92c19d37b6ee4f7bc99c65f655e7f840c
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/pidinet/LICENSE
@@ -0,0 +1,21 @@
+It is just for research purpose, and commercial use should be contacted with authors first.
+
+Copyright (c) 2021 Zhuo Su
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/pidinet/__init__.py b/lib/model_zoo/controlnet_annotator/pidinet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d427b0688664fe76f4321318ea99b374d58c64f
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/pidinet/__init__.py
@@ -0,0 +1,101 @@
+import os
+import torch
+import numpy as np
+from einops import rearrange
+from .model import pidinet
+
+models_path = 'pretrained/controlnet/preprocess'
+
+netNetwork = None
+remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
+modeldir = os.path.join(models_path, "pidinet")
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+
+def safe_step(x, step=2):
+ y = x.astype(np.float32) * float(step + 1)
+ y = y.astype(np.int32).astype(np.float32) / float(step)
+ return y
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ from torch.hub import download_url_to_file, get_dir
+ from urllib.parse import urlparse
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+def load_state_dict(ckpt_path, location='cpu'):
+ def get_state_dict(d):
+ return d.get('state_dict', d)
+
+ _, extension = os.path.splitext(ckpt_path)
+ if extension.lower() == ".safetensors":
+ import safetensors.torch
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+ else:
+ state_dict = get_state_dict(torch.load(
+ ckpt_path, map_location=torch.device(location)))
+ state_dict = get_state_dict(state_dict)
+ print(f'Loaded state_dict from [{ckpt_path}]')
+ return state_dict
+
+def apply_pidinet(input_image, is_safe=False, apply_fliter=False, device='cpu'):
+ global netNetwork
+ if netNetwork is None:
+ modelpath = os.path.join(modeldir, "table5_pidinet.pth")
+ old_modelpath = os.path.join(old_modeldir, "table5_pidinet.pth")
+ if os.path.exists(old_modelpath):
+ modelpath = old_modelpath
+ elif not os.path.exists(modelpath):
+ load_file_from_url(remote_model_path, model_dir=modeldir)
+ netNetwork = pidinet()
+ ckp = load_state_dict(modelpath)
+ netNetwork.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
+
+ netNetwork = netNetwork.to(device)
+ netNetwork.eval()
+ assert input_image.ndim == 3
+ input_image = input_image[:, :, ::-1].copy()
+ with torch.no_grad():
+ image_pidi = torch.from_numpy(input_image).float().to(device)
+ image_pidi = image_pidi / 255.0
+ image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
+ edge = netNetwork(image_pidi)[-1]
+ edge = edge.cpu().numpy()
+ if apply_fliter:
+ edge = edge > 0.5
+ if is_safe:
+ edge = safe_step(edge)
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+
+ return edge[0][0]
+
+def unload_pid_model():
+ global netNetwork
+ if netNetwork is not None:
+ netNetwork.cpu()
\ No newline at end of file
diff --git a/lib/model_zoo/controlnet_annotator/pidinet/model.py b/lib/model_zoo/controlnet_annotator/pidinet/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7114da1344ee01c359198d005be0c97d42d49dd6
--- /dev/null
+++ b/lib/model_zoo/controlnet_annotator/pidinet/model.py
@@ -0,0 +1,680 @@
+"""
+Author: Zhuo Su, Wenzhe Liu
+Date: Feb 18, 2021
+"""
+
+import math
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+nets = {
+ 'baseline': {
+ 'layer0': 'cv',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'c-v15': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'a-v15': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'r-v15': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cvvv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'avvv4': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'rvvv4': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cccv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cv',
+ },
+ 'aaav4': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'cv',
+ },
+ 'rrrv4': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ 'c16': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cd',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cd',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cd',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cd',
+ },
+ 'a16': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'ad',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'ad',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'ad',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'ad',
+ },
+ 'r16': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'rd',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'rd',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'rd',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'rd',
+ },
+ 'carv4': {
+ 'layer0': 'cd',
+ 'layer1': 'ad',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'ad',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'ad',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'ad',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ }
+
+def createConvFunc(op_type):
+ assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
+ if op_type == 'cv':
+ return F.conv2d
+
+ if op_type == 'cd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
+ assert padding == dilation, 'padding for cd_conv set wrong'
+
+ weights_c = weights.sum(dim=[2, 3], keepdim=True)
+ yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
+ y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y - yc
+ return func
+ elif op_type == 'ad':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
+ assert padding == dilation, 'padding for ad_conv set wrong'
+
+ shape = weights.shape
+ weights = weights.view(shape[0], shape[1], -1)
+ weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
+ y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+ return func
+ elif op_type == 'rd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
+ padding = 2 * dilation
+
+ shape = weights.shape
+ if weights.is_cuda:
+ buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
+ else:
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5)
+ weights = weights.view(shape[0], shape[1], -1)
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
+ buffer[:, :, 12] = 0
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
+ y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+ return func
+ else:
+ print('impossible to be here unless you force that')
+ return None
+
+class Conv2d(nn.Module):
+ def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
+ super(Conv2d, self).__init__()
+ if in_channels % groups != 0:
+ raise ValueError('in_channels must be divisible by groups')
+ if out_channels % groups != 0:
+ raise ValueError('out_channels must be divisible by groups')
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+ self.pdc = pdc
+
+ def reset_parameters(self):
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input):
+
+ return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+class CSAM(nn.Module):
+ """
+ Compact Spatial Attention Module
+ """
+ def __init__(self, channels):
+ super(CSAM, self).__init__()
+
+ mid_channels = 4
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
+ self.sigmoid = nn.Sigmoid()
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ y = self.relu1(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = self.sigmoid(y)
+
+ return x * y
+
+class CDCM(nn.Module):
+ """
+ Compact Dilation Convolution based Module
+ """
+ def __init__(self, in_channels, out_channels):
+ super(CDCM, self).__init__()
+
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+ self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
+ self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
+ self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
+ self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ x = self.relu1(x)
+ x = self.conv1(x)
+ x1 = self.conv2_1(x)
+ x2 = self.conv2_2(x)
+ x3 = self.conv2_3(x)
+ x4 = self.conv2_4(x)
+ return x1 + x2 + x3 + x4
+
+
+class MapReduce(nn.Module):
+ """
+ Reduce feature maps into a single edge map
+ """
+ def __init__(self, channels):
+ super(MapReduce, self).__init__()
+ self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class PDCBlock(nn.Module):
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock, self).__init__()
+ self.stride=stride
+
+ self.stride=stride
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+class PDCBlock_converted(nn.Module):
+ """
+ CPDC, APDC can be converted to vanilla 3x3 convolution
+ RPDC can be converted to vanilla 5x5 convolution
+ """
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock_converted, self).__init__()
+ self.stride=stride
+
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ if pdc == 'rd':
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+class PiDiNet(nn.Module):
+ def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
+ super(PiDiNet, self).__init__()
+ self.sa = sa
+ if dil is not None:
+ assert isinstance(dil, int), 'dil should be an int'
+ self.dil = dil
+
+ self.fuseplanes = []
+
+ self.inplane = inplane
+ if convert:
+ if pdcs[0] == 'rd':
+ init_kernel_size = 5
+ init_padding = 2
+ else:
+ init_kernel_size = 3
+ init_padding = 1
+ self.init_block = nn.Conv2d(3, self.inplane,
+ kernel_size=init_kernel_size, padding=init_padding, bias=False)
+ block_class = PDCBlock_converted
+ else:
+ self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
+ block_class = PDCBlock
+
+ self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
+ self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
+ self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
+ self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
+ self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
+ self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 2C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
+ self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
+ self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
+ self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
+ self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
+ self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
+ self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.conv_reduces = nn.ModuleList()
+ if self.sa and self.dil is not None:
+ self.attentions = nn.ModuleList()
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.attentions.append(CSAM(self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ elif self.sa:
+ self.attentions = nn.ModuleList()
+ for i in range(4):
+ self.attentions.append(CSAM(self.fuseplanes[i]))
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+ elif self.dil is not None:
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ else:
+ for i in range(4):
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+
+ self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
+ nn.init.constant_(self.classifier.weight, 0.25)
+ nn.init.constant_(self.classifier.bias, 0)
+
+ # print('initialization done')
+
+ def get_weights(self):
+ conv_weights = []
+ bn_weights = []
+ relu_weights = []
+ for pname, p in self.named_parameters():
+ if 'bn' in pname:
+ bn_weights.append(p)
+ elif 'relu' in pname:
+ relu_weights.append(p)
+ else:
+ conv_weights.append(p)
+
+ return conv_weights, bn_weights, relu_weights
+
+ def forward(self, x):
+ H, W = x.size()[2:]
+
+ x = self.init_block(x)
+
+ x1 = self.block1_1(x)
+ x1 = self.block1_2(x1)
+ x1 = self.block1_3(x1)
+
+ x2 = self.block2_1(x1)
+ x2 = self.block2_2(x2)
+ x2 = self.block2_3(x2)
+ x2 = self.block2_4(x2)
+
+ x3 = self.block3_1(x2)
+ x3 = self.block3_2(x3)
+ x3 = self.block3_3(x3)
+ x3 = self.block3_4(x3)
+
+ x4 = self.block4_1(x3)
+ x4 = self.block4_2(x4)
+ x4 = self.block4_3(x4)
+ x4 = self.block4_4(x4)
+
+ x_fuses = []
+ if self.sa and self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](self.dilations[i](xi)))
+ elif self.sa:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](xi))
+ elif self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.dilations[i](xi))
+ else:
+ x_fuses = [x1, x2, x3, x4]
+
+ e1 = self.conv_reduces[0](x_fuses[0])
+ e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
+
+ e2 = self.conv_reduces[1](x_fuses[1])
+ e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
+
+ e3 = self.conv_reduces[2](x_fuses[2])
+ e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
+
+ e4 = self.conv_reduces[3](x_fuses[3])
+ e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
+
+ outputs = [e1, e2, e3, e4]
+
+ output = self.classifier(torch.cat(outputs, dim=1))
+ #if not self.training:
+ # return torch.sigmoid(output)
+
+ outputs.append(output)
+ outputs = [torch.sigmoid(r) for r in outputs]
+ return outputs
+
+def config_model(model):
+ model_options = list(nets.keys())
+ assert model in model_options, \
+ 'unrecognized model, please choose from %s' % str(model_options)
+
+ # print(str(nets[model]))
+
+ pdcs = []
+ for i in range(16):
+ layer_name = 'layer%d' % i
+ op = nets[model][layer_name]
+ pdcs.append(createConvFunc(op))
+
+ return pdcs
+
+def pidinet():
+ pdcs = config_model('carv4')
+ dil = 24 #if args.dil else None
+ return PiDiNet(60, pdcs, dil=dil, sa=True)
+
+
+if __name__ == '__main__':
+ model = pidinet()
+ ckp = torch.load('table5_pidinet.pth')['state_dict']
+ model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
+ im = cv2.imread('examples/test_my/cat_v4.png')
+ im = img2tensor(im).unsqueeze(0)/255.
+ res = model(im)[-1]
+ res = res>0.5
+ res = res.float()
+ res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
+ print(res.shape)
+ cv2.imwrite('edge.png', res)
\ No newline at end of file
diff --git a/lib/model_zoo/ddim.py b/lib/model_zoo/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..66f0dc0dd1a33bb254d561ece5429c1ac08754f0
--- /dev/null
+++ b/lib/model_zoo/ddim.py
@@ -0,0 +1,299 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
+ num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
+ verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ steps,
+ shape,
+ x_info,
+ c_info,
+ eta=0.,
+ temperature=1.,
+ noise_dropout=0.,
+ verbose=True,
+ log_every_t=100,):
+
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
+ samples, intermediates = self.ddim_sampling(
+ shape,
+ x_info=x_info,
+ c_info=c_info,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ log_every_t=log_every_t,)
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self,
+ shape,
+ x_info,
+ c_info,
+ noise_dropout=0.,
+ temperature=1.,
+ log_every_t=100,):
+
+ device = self.model.device
+ dtype = c_info['conditioning'].dtype
+ bs = shape[0]
+ timesteps = self.ddim_timesteps
+ if ('xt' in x_info) and (x_info['xt'] is not None):
+ xt = x_info['xt'].astype(dtype).to(device)
+ x_info['x'] = xt
+ elif ('x0' in x_info) and (x_info['x0'] is not None):
+ x0 = x_info['x0'].type(dtype).to(device)
+ ts = timesteps[x_info['x0_forward_timesteps']].repeat(bs)
+ ts = torch.Tensor(ts).long().to(device)
+ timesteps = timesteps[:x_info['x0_forward_timesteps']]
+ x0_nz = self.model.q_sample(x0, ts)
+ x_info['x'] = x0_nz
+ else:
+ x_info['x'] = torch.randn(shape, device=device, dtype=dtype)
+
+ intermediates = {'pred_xt': [], 'pred_x0': []}
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
+
+ outs = self.p_sample_ddim(
+ x_info, c_info, ts, index,
+ noise_dropout=noise_dropout,
+ temperature=temperature,)
+ pred_xt, pred_x0 = outs
+ x_info['x'] = pred_xt
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['pred_xt'].append(pred_xt)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return pred_xt, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x_info, c_info, t, index,
+ repeat_noise=False,
+ use_original_steps=False,
+ noise_dropout=0.,
+ temperature=1.,):
+
+ x = x_info['x']
+ unconditional_guidance_scale = c_info['unconditional_guidance_scale']
+
+ b, *_, device = *x.shape, x.device
+ if (unconditional_guidance_scale == 1.) or (c_info['unconditional_conditioning'] is None):
+ c_info['c'] = c_info['conditioning']
+ e_t = self.model.apply_model(x_info, t, c_info)
+ e_t = e_t * unconditional_guidance_scale
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([c_info['unconditional_conditioning'], c_info['conditioning']])
+ x_info['x'] = x_in
+ c_info['c'] = c_in
+ e_t_uncond, e_t = self.model.apply_model(x_info, t_in, c_info).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+
+ extended_shape = [b] + [1]*(len(e_t.shape)-1)
+ a_t = torch.full(extended_shape, alphas[index], device=device, dtype=x.dtype)
+ a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=x.dtype)
+ sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=x.dtype)
+ sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=x.dtype)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def sample_multicontext(self,
+ steps,
+ shape,
+ x_info,
+ c_info_list,
+ eta=0.,
+ temperature=1.,
+ noise_dropout=0.,
+ verbose=True,
+ log_every_t=100,):
+
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
+ samples, intermediates = self.ddim_sampling_multicontext(
+ shape,
+ x_info=x_info,
+ c_info_list=c_info_list,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ log_every_t=log_every_t,)
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling_multicontext(self,
+ shape,
+ x_info,
+ c_info_list,
+ noise_dropout=0.,
+ temperature=1.,
+ log_every_t=100,):
+
+ device = self.model.device
+ dtype = c_info_list[0]['conditioning'].dtype
+ bs = shape[0]
+ timesteps = self.ddim_timesteps
+ if ('xt' in x_info) and (x_info['xt'] is not None):
+ xt = x_info['xt'].astype(dtype).to(device)
+ x_info['x'] = xt
+ elif ('x0' in x_info) and (x_info['x0'] is not None):
+ x0 = x_info['x0'].type(dtype).to(device)
+ ts = timesteps[x_info['x0_forward_timesteps']].repeat(bs)
+ ts = torch.Tensor(ts).long().to(device)
+ timesteps = timesteps[:x_info['x0_forward_timesteps']]
+ x0_nz = self.model.q_sample(x0, ts)
+ x_info['x'] = x0_nz
+ else:
+ x_info['x'] = torch.randn(shape, device=device, dtype=dtype)
+
+ intermediates = {'pred_xt': [], 'pred_x0': []}
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
+
+ outs = self.p_sample_ddim_multicontext(
+ x_info, c_info_list, ts, index,
+ noise_dropout=noise_dropout,
+ temperature=temperature,)
+ pred_xt, pred_x0 = outs
+ x_info['x'] = pred_xt
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['pred_xt'].append(pred_xt)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return pred_xt, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim_multicontext(
+ self, x_info, c_info_list, t, index,
+ repeat_noise=False,
+ use_original_steps=False,
+ noise_dropout=0.,
+ temperature=1.,):
+
+ x = x_info['x']
+ b, *_, device = *x.shape, x.device
+ unconditional_guidance_scale = None
+
+ for c_info in c_info_list:
+ if unconditional_guidance_scale is None:
+ unconditional_guidance_scale = c_info['unconditional_guidance_scale']
+ else:
+ assert unconditional_guidance_scale==c_info['unconditional_guidance_scale'], \
+ "A different unconditional guidance scale between different context is not allowed!"
+
+ if unconditional_guidance_scale == 1.:
+ c_info['c'] = c_info['conditioning']
+
+ else:
+ c_in = torch.cat([c_info['unconditional_conditioning'], c_info['conditioning']])
+ c_info['c'] = c_in
+
+ if unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model_multicontext(x_info, t, c_info_list)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ x_info['x'] = x_in
+ e_t_uncond, e_t = self.model.apply_model_multicontext(x_info, t_in, c_info_list).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+
+ extended_shape = [b] + [1]*(len(e_t.shape)-1)
+ a_t = torch.full(extended_shape, alphas[index], device=device, dtype=x.dtype)
+ a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=x.dtype)
+ sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=x.dtype)
+ sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=x.dtype)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
diff --git a/lib/model_zoo/diffusion_utils.py b/lib/model_zoo/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b28b42dc6d2933d4a6159e973f70dc721f19701d
--- /dev/null
+++ b/lib/model_zoo/diffusion_utils.py
@@ -0,0 +1,250 @@
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ # return super().forward(x.float()).type(x.dtype)
+ return super().forward(x)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+def noise_like(x, repeat=False):
+ noise = torch.randn_like(x)
+ if repeat:
+ bs = x.shape[0]
+ noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1)))
+ return noise
+
+##########################
+# inherit from ldm.utils #
+##########################
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
diff --git a/lib/model_zoo/distributions.py b/lib/model_zoo/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/lib/model_zoo/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/lib/model_zoo/ema.py b/lib/model_zoo/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d61e90eadb4701c7c38d9ed63e4fca7afb78d9
--- /dev/null
+++ b/lib/model_zoo/ema.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_updates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_updates
+ else torch.tensor(-1,dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ #remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.','')
+ self.m_name2s_name.update({name:s_name})
+ self.register_buffer(s_name,p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/lib/model_zoo/openaimodel.py b/lib/model_zoo/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1cb6f1333ff941ec0a9ddb81a90f077b441f05c
--- /dev/null
+++ b/lib/model_zoo/openaimodel.py
@@ -0,0 +1,2975 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .diffusion_utils import \
+ checkpoint, conv_nd, linear, avg_pool_nd, \
+ zero_module, normalization, timestep_embedding
+
+from .attention import SpatialTransformer
+
+from lib.model_zoo.common.get_model import get_model, register
+
+symbol = 'openai'
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+@register('openai_unet')
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ #self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.") # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if disable_self_attentions is not None:
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if disable_self_attentions is not None:
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+#######################
+# Unet with self-attn #
+#######################
+
+from .attention import SpatialTransformerNoContext
+
+@register('openai_unet_nocontext')
+class UNetModelNoContext(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ num_attention_blocks=None, ):
+
+ super().__init__()
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ #self.num_res_blocks = num_res_blocks
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.") # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
+ if (num_attention_blocks is None) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformerNoContext(
+ ch, num_heads, dim_head, depth=transformer_depth
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformerNoContext( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
+ if (num_attention_blocks is None) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformerNoContext(
+ ch, num_heads, dim_head, depth=transformer_depth,
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(self, x, timesteps):
+ assert self.num_classes is None, \
+ "not supported"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+@register('openai_unet_nocontext_noatt')
+class UNetModelNoContextNoAtt(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ n_embed=None,):
+
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ #self.num_res_blocks = num_res_blocks
+
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(self, x, timesteps):
+ assert self.num_classes is None, \
+ "not supported"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+@register('openai_unet_nocontext_noatt_decoderonly')
+class UNetModelNoContextNoAttDecoderOnly(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ model_channels,
+ num_res_blocks,
+ dropout=0,
+ channel_mult=(4, 2, 1),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ n_embed=None,):
+
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ #self.num_res_blocks = num_res_blocks
+
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self._feature_size = model_channels
+
+ ch = model_channels * self.channel_mult[0]
+ self.output_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, ch, 3, padding=1)
+ )
+ ]
+ )
+
+ for level, mult in enumerate(channel_mult):
+ for i in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if level != len(channel_mult)-1 and (i == self.num_res_blocks[level]-1):
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(self, x, timesteps):
+ assert self.num_classes is None, \
+ "not supported"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ h = x.type(self.dtype)
+ for module in self.output_blocks:
+ h = module(h, emb)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+#########################
+# Double Attention Unet #
+#########################
+
+from .attention import DualSpatialTransformer
+
+class TimestepEmbedSequentialExtended(nn.Sequential, TimestepBlock):
+ def forward(self, x, emb, context=None, which_attn=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ elif isinstance(layer, DualSpatialTransformer):
+ x = layer(x, context, which=which_attn)
+ else:
+ x = layer(x)
+ return x
+
+@register('openai_unet_dual_context')
+class UNetModelDualContext(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None, ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ #self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.") # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequentialExtended(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if disable_self_attentions is not None:
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else DualSpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequentialExtended(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequentialExtended(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequentialExtended(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else DualSpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if disable_self_attentions is not None:
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else DualSpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequentialExtended(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def forward(self, x, timesteps=None, context=None, y=None, which_attn=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ t_emb = t_emb.to(context.dtype)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context, which_attn=which_attn)
+ hs.append(h)
+ h = self.middle_block(h, emb, context, which_attn=which_attn)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context, which_attn=which_attn)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+###########
+# VD Unet #
+###########
+
+from functools import partial
+
+@register('openai_unet_2d')
+class UNetModel2D(nn.Module):
+ def __init__(self,
+ input_channels,
+ model_channels,
+ output_channels,
+ context_dim=768,
+ num_noattn_blocks=(2, 2, 2, 2),
+ channel_mult=(1, 2, 4, 8),
+ with_attn=[True, True, True, False],
+ num_heads=8,
+ use_checkpoint=True, ):
+
+ super().__init__()
+
+ ResBlockPreset = partial(
+ ResBlock, dropout=0, dims=2, use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=False)
+
+ self.input_channels = input_channels
+ self.model_channels = model_channels
+ self.num_noattn_blocks = num_noattn_blocks
+ self.channel_mult = channel_mult
+ self.num_heads = num_heads
+
+ ##################
+ # Time embedding #
+ ##################
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),)
+
+ ################
+ # input_blocks #
+ ################
+ current_channel = model_channels
+ input_blocks = [
+ TimestepEmbedSequential(
+ nn.Conv2d(input_channels, model_channels, 3, padding=1, bias=True))]
+ input_block_channels = [current_channel]
+
+ for level_idx, mult in enumerate(channel_mult):
+ for _ in range(self.num_noattn_blocks[level_idx]):
+ layers = [
+ ResBlockPreset(
+ current_channel, time_embed_dim,
+ out_channels = mult * model_channels,)]
+
+ current_channel = mult * model_channels
+ dim_head = current_channel // num_heads
+ if with_attn[level_idx]:
+ layers += [
+ SpatialTransformer(
+ current_channel, num_heads, dim_head,
+ depth=1, context_dim=context_dim, )]
+
+ input_blocks += [TimestepEmbedSequential(*layers)]
+ input_block_channels.append(current_channel)
+
+ if level_idx != len(channel_mult) - 1:
+ input_blocks += [
+ TimestepEmbedSequential(
+ Downsample(
+ current_channel, use_conv=True,
+ dims=2, out_channels=current_channel,))]
+ input_block_channels.append(current_channel)
+
+ self.input_blocks = nn.ModuleList(input_blocks)
+
+ #################
+ # middle_blocks #
+ #################
+ middle_block = [
+ ResBlockPreset(
+ current_channel, time_embed_dim,),
+ SpatialTransformer(
+ current_channel, num_heads, dim_head,
+ depth=1, context_dim=context_dim, ),
+ ResBlockPreset(
+ current_channel, time_embed_dim,),]
+ self.middle_block = TimestepEmbedSequential(*middle_block)
+
+ #################
+ # output_blocks #
+ #################
+ output_blocks = []
+ for level_idx, mult in list(enumerate(channel_mult))[::-1]:
+ for block_idx in range(self.num_noattn_blocks[level_idx] + 1):
+ extra_channel = input_block_channels.pop()
+ layers = [
+ ResBlockPreset(
+ current_channel + extra_channel,
+ time_embed_dim,
+ out_channels = model_channels * mult,) ]
+
+ current_channel = model_channels * mult
+ dim_head = current_channel // num_heads
+ if with_attn[level_idx]:
+ layers += [
+ SpatialTransformer(
+ current_channel, num_heads, dim_head,
+ depth=1, context_dim=context_dim,)]
+
+ if level_idx!=0 and block_idx==self.num_noattn_blocks[level_idx]:
+ layers += [
+ Upsample(
+ current_channel, use_conv=True,
+ dims=2, out_channels=current_channel)]
+
+ output_blocks += [TimestepEmbedSequential(*layers)]
+
+ self.output_blocks = nn.ModuleList(output_blocks)
+
+ self.out = nn.Sequential(
+ normalization(current_channel),
+ nn.SiLU(),
+ zero_module(nn.Conv2d(model_channels, output_channels, 3, padding=1)),)
+
+ def forward(self, x, timesteps=None, context=None):
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ return self.out(h)
+
+class FCBlock(TimestepBlock):
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_checkpoint = use_checkpoint
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ nn.Conv2d(channels, self.out_channels, 1, padding=0),)
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(emb_channels, self.out_channels,),)
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(nn.Conv2d(self.out_channels, self.out_channels, 1, padding=0)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ else:
+ self.skip_connection = nn.Conv2d(channels, self.out_channels, 1, padding=0)
+
+ def forward(self, x, emb):
+ if len(x.shape) == 2:
+ x = x[:, :, None, None]
+ elif len(x.shape) == 4:
+ pass
+ else:
+ raise ValueError
+ y = checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint)
+ if len(x.shape) == 2:
+ return y[:, :, 0, 0]
+ elif len(x.shape) == 4:
+ return y
+
+ def _forward(self, x, emb):
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+@register('openai_unet_0d')
+class UNetModel0D(nn.Module):
+ def __init__(self,
+ input_channels,
+ model_channels,
+ output_channels,
+ context_dim=768,
+ num_noattn_blocks=(2, 2, 2, 2),
+ channel_mult=(1, 2, 4, 8),
+ with_attn=[True, True, True, False],
+ num_heads=8,
+ use_checkpoint=True, ):
+
+ super().__init__()
+
+ FCBlockPreset = partial(FCBlock, dropout=0, use_checkpoint=use_checkpoint)
+
+ self.input_channels = input_channels
+ self.model_channels = model_channels
+ self.num_noattn_blocks = num_noattn_blocks
+ self.channel_mult = channel_mult
+ self.num_heads = num_heads
+
+ ##################
+ # Time embedding #
+ ##################
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),)
+
+ ################
+ # input_blocks #
+ ################
+ current_channel = model_channels
+ input_blocks = [
+ TimestepEmbedSequential(
+ nn.Conv2d(input_channels, model_channels, 1, padding=0, bias=True))]
+ input_block_channels = [current_channel]
+
+ for level_idx, mult in enumerate(channel_mult):
+ for _ in range(self.num_noattn_blocks[level_idx]):
+ layers = [
+ FCBlockPreset(
+ current_channel, time_embed_dim,
+ out_channels = mult * model_channels,)]
+
+ current_channel = mult * model_channels
+ dim_head = current_channel // num_heads
+ if with_attn[level_idx]:
+ layers += [
+ SpatialTransformer(
+ current_channel, num_heads, dim_head,
+ depth=1, context_dim=context_dim, )]
+
+ input_blocks += [TimestepEmbedSequential(*layers)]
+ input_block_channels.append(current_channel)
+
+ if level_idx != len(channel_mult) - 1:
+ input_blocks += [
+ TimestepEmbedSequential(
+ Downsample(
+ current_channel, use_conv=True,
+ dims=2, out_channels=current_channel,))]
+ input_block_channels.append(current_channel)
+
+ self.input_blocks = nn.ModuleList(input_blocks)
+
+ #################
+ # middle_blocks #
+ #################
+ middle_block = [
+ FCBlockPreset(
+ current_channel, time_embed_dim,),
+ SpatialTransformer(
+ current_channel, num_heads, dim_head,
+ depth=1, context_dim=context_dim, ),
+ FCBlockPreset(
+ current_channel, time_embed_dim,),]
+ self.middle_block = TimestepEmbedSequential(*middle_block)
+
+ #################
+ # output_blocks #
+ #################
+ output_blocks = []
+ for level_idx, mult in list(enumerate(channel_mult))[::-1]:
+ for block_idx in range(self.num_noattn_blocks[level_idx] + 1):
+ extra_channel = input_block_channels.pop()
+ layers = [
+ FCBlockPreset(
+ current_channel + extra_channel,
+ time_embed_dim,
+ out_channels = model_channels * mult,) ]
+
+ current_channel = model_channels * mult
+ dim_head = current_channel // num_heads
+ if with_attn[level_idx]:
+ layers += [
+ SpatialTransformer(
+ current_channel, num_heads, dim_head,
+ depth=1, context_dim=context_dim,)]
+
+ if level_idx!=0 and block_idx==self.num_noattn_blocks[level_idx]:
+ layers += [
+ nn.Conv2d(current_channel, current_channel, 1, padding=0)]
+
+ output_blocks += [TimestepEmbedSequential(*layers)]
+
+ self.output_blocks = nn.ModuleList(output_blocks)
+
+ self.out = nn.Sequential(
+ normalization(current_channel),
+ nn.SiLU(),
+ zero_module(nn.Conv2d(model_channels, output_channels, 1, padding=0)),)
+
+ def forward(self, x, timesteps=None, context=None):
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ return self.out(h)
+
+class Linear_MultiDim(nn.Linear):
+ def __init__(self, in_features, out_features, *args, **kwargs):
+
+ in_features = [in_features] if isinstance(in_features, int) else list(in_features)
+ out_features = [out_features] if isinstance(out_features, int) else list(out_features)
+ self.in_features_multidim = in_features
+ self.out_features_multidim = out_features
+ super().__init__(
+ np.array(in_features).prod(),
+ np.array(out_features).prod(),
+ *args, **kwargs)
+
+ def forward(self, x):
+ shape = x.shape
+ n = len(shape) - len(self.in_features_multidim)
+ x = x.view(*shape[:n], self.in_features)
+ y = super().forward(x)
+ y = y.view(*shape[:n], *self.out_features_multidim)
+ return y
+
+class FCBlock_MultiDim(FCBlock):
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_checkpoint=False,):
+ channels = [channels] if isinstance(channels, int) else list(channels)
+ channels_all = np.array(channels).prod()
+ self.channels_multidim = channels
+
+ if out_channels is not None:
+ out_channels = [out_channels] if isinstance(out_channels, int) else list(out_channels)
+ out_channels_all = np.array(out_channels).prod()
+ self.out_channels_multidim = out_channels
+ else:
+ out_channels_all = channels_all
+ self.out_channels_multidim = self.channels_multidim
+
+ self.channels = channels
+ super().__init__(
+ channels = channels_all,
+ emb_channels = emb_channels,
+ dropout = dropout,
+ out_channels = out_channels_all,
+ use_checkpoint = use_checkpoint,)
+
+ def forward(self, x, emb):
+ shape = x.shape
+ n = len(self.channels_multidim)
+ x = x.view(*shape[0:-n], self.channels, 1, 1)
+ x = x.view(-1, self.channels, 1, 1)
+ y = checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint)
+ y = y.view(*shape[0:-n], -1)
+ y = y.view(*shape[0:-n], *self.out_channels_multidim)
+ return y
+
+@register('openai_unet_0dmd')
+class UNetModel0D_MultiDim(nn.Module):
+ def __init__(self,
+ input_channels,
+ model_channels,
+ output_channels,
+ context_dim=768,
+ num_noattn_blocks=(2, 2, 2, 2),
+ channel_mult=(1, 2, 4, 8),
+ second_dim=(4, 4, 4, 4),
+ with_attn=[True, True, True, False],
+ num_heads=8,
+ use_checkpoint=True, ):
+
+ super().__init__()
+
+ FCBlockPreset = partial(FCBlock_MultiDim, dropout=0, use_checkpoint=use_checkpoint)
+
+ self.input_channels = input_channels
+ self.model_channels = model_channels
+ self.num_noattn_blocks = num_noattn_blocks
+ self.channel_mult = channel_mult
+ self.second_dim = second_dim
+ self.num_heads = num_heads
+
+ ##################
+ # Time embedding #
+ ##################
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),)
+
+ ################
+ # input_blocks #
+ ################
+ sdim = second_dim[0]
+ current_channel = [model_channels, sdim, 1]
+ input_blocks = [
+ TimestepEmbedSequential(
+ Linear_MultiDim([input_channels, 1, 1], current_channel, bias=True))]
+ input_block_channels = [current_channel]
+
+ for level_idx, (mult, sdim) in enumerate(zip(channel_mult, second_dim)):
+ for _ in range(self.num_noattn_blocks[level_idx]):
+ layers = [
+ FCBlockPreset(
+ current_channel,
+ time_embed_dim,
+ out_channels = [mult*model_channels, sdim, 1],)]
+
+ current_channel = [mult*model_channels, sdim, 1]
+ dim_head = current_channel[0] // num_heads
+ if with_attn[level_idx]:
+ layers += [
+ SpatialTransformer(
+ current_channel[0], num_heads, dim_head,
+ depth=1, context_dim=context_dim, )]
+
+ input_blocks += [TimestepEmbedSequential(*layers)]
+ input_block_channels.append(current_channel)
+
+ if level_idx != len(channel_mult) - 1:
+ input_blocks += [
+ TimestepEmbedSequential(
+ Linear_MultiDim(current_channel, current_channel, bias=True, ))]
+ input_block_channels.append(current_channel)
+
+ self.input_blocks = nn.ModuleList(input_blocks)
+
+ #################
+ # middle_blocks #
+ #################
+ middle_block = [
+ FCBlockPreset(
+ current_channel, time_embed_dim, ),
+ SpatialTransformer(
+ current_channel[0], num_heads, dim_head,
+ depth=1, context_dim=context_dim, ),
+ FCBlockPreset(
+ current_channel, time_embed_dim, ),]
+ self.middle_block = TimestepEmbedSequential(*middle_block)
+
+ #################
+ # output_blocks #
+ #################
+ output_blocks = []
+ for level_idx, (mult, sdim) in list(enumerate(zip(channel_mult, second_dim)))[::-1]:
+ for block_idx in range(self.num_noattn_blocks[level_idx] + 1):
+ extra_channel = input_block_channels.pop()
+ layers = [
+ FCBlockPreset(
+ [current_channel[0] + extra_channel[0]] + current_channel[1:],
+ time_embed_dim,
+ out_channels = [mult*model_channels, sdim, 1], )]
+
+ current_channel = [mult*model_channels, sdim, 1]
+ dim_head = current_channel[0] // num_heads
+ if with_attn[level_idx]:
+ layers += [
+ SpatialTransformer(
+ current_channel[0], num_heads, dim_head,
+ depth=1, context_dim=context_dim,)]
+
+ if level_idx!=0 and block_idx==self.num_noattn_blocks[level_idx]:
+ layers += [
+ Linear_MultiDim(current_channel, current_channel, bias=True, )]
+
+ output_blocks += [TimestepEmbedSequential(*layers)]
+
+ self.output_blocks = nn.ModuleList(output_blocks)
+
+ self.out = nn.Sequential(
+ normalization(current_channel[0]),
+ nn.SiLU(),
+ zero_module(Linear_MultiDim(current_channel, [output_channels, 1, 1], bias=True, )),)
+
+ def forward(self, x, timesteps=None, context=None):
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ return self.out(h)
+
+@register('openai_unet_vd')
+class UNetModelVD(nn.Module):
+ def __init__(self,
+ unet_image_cfg,
+ unet_text_cfg, ):
+
+ super().__init__()
+ self.unet_image = get_model()(unet_image_cfg)
+ self.unet_text = get_model()(unet_text_cfg)
+ self.time_embed = self.unet_image.time_embed
+ del self.unet_image.time_embed
+ del self.unet_text.time_embed
+
+ self.model_channels = self.unet_image.model_channels
+
+ def forward(self, x, timesteps, context, xtype='image', ctype='prompt'):
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb.to(x.dtype))
+
+ if xtype == 'text':
+ x = x[:, :, None, None]
+
+ h = x
+ for i_module, t_module in zip(self.unet_image.input_blocks, self.unet_text.input_blocks):
+ h = self.mixed_run(i_module, t_module, h, emb, context, xtype, ctype)
+ hs.append(h)
+ h = self.mixed_run(
+ self.unet_image.middle_block, self.unet_text.middle_block,
+ h, emb, context, xtype, ctype)
+ for i_module, t_module in zip(self.unet_image.output_blocks, self.unet_text.output_blocks):
+ h = th.cat([h, hs.pop()], dim=1)
+ h = self.mixed_run(i_module, t_module, h, emb, context, xtype, ctype)
+ if xtype == 'image':
+ return self.unet_image.out(h)
+ elif xtype == 'text':
+ return self.unet_text.out(h).squeeze(-1).squeeze(-1)
+
+ def mixed_run(self, inet, tnet, x, emb, context, xtype, ctype):
+
+ h = x
+ for ilayer, tlayer in zip(inet, tnet):
+ if isinstance(ilayer, TimestepBlock) and xtype=='image':
+ h = ilayer(h, emb)
+ elif isinstance(tlayer, TimestepBlock) and xtype=='text':
+ h = tlayer(h, emb)
+ elif isinstance(ilayer, SpatialTransformer) and ctype=='vision':
+ h = ilayer(h, context)
+ elif isinstance(ilayer, SpatialTransformer) and ctype=='prompt':
+ h = tlayer(h, context)
+ elif xtype=='image':
+ h = ilayer(h)
+ elif xtype == 'text':
+ h = tlayer(h)
+ else:
+ raise ValueError
+ return h
+
+ def forward_dc(self, x, timesteps, c0, c1, xtype, c0_type, c1_type, mixed_ratio):
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb.to(x.dtype))
+
+ if xtype == 'text':
+ x = x[:, :, None, None]
+ h = x
+ for i_module, t_module in zip(self.unet_image.input_blocks, self.unet_text.input_blocks):
+ h = self.mixed_run_dc(i_module, t_module, h, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio)
+ hs.append(h)
+ h = self.mixed_run_dc(
+ self.unet_image.middle_block, self.unet_text.middle_block,
+ h, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio)
+ for i_module, t_module in zip(self.unet_image.output_blocks, self.unet_text.output_blocks):
+ h = th.cat([h, hs.pop()], dim=1)
+ h = self.mixed_run_dc(i_module, t_module, h, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio)
+ if xtype == 'image':
+ return self.unet_image.out(h)
+ elif xtype == 'text':
+ return self.unet_text.out(h).squeeze(-1).squeeze(-1)
+
+ def mixed_run_dc(self, inet, tnet, x, emb, c0, c1, xtype, c0_type, c1_type, mixed_ratio):
+ h = x
+ for ilayer, tlayer in zip(inet, tnet):
+ if isinstance(ilayer, TimestepBlock) and xtype=='image':
+ h = ilayer(h, emb)
+ elif isinstance(tlayer, TimestepBlock) and xtype=='text':
+ h = tlayer(h, emb)
+ elif isinstance(ilayer, SpatialTransformer):
+ h0 = ilayer(h, c0)-h if c0_type=='vision' else tlayer(h, c0)-h
+ h1 = ilayer(h, c1)-h if c1_type=='vision' else tlayer(h, c1)-h
+ h = h0*mixed_ratio + h1*(1-mixed_ratio) + h
+ # h = ilayer(h, c0)
+ elif xtype=='image':
+ h = ilayer(h)
+ elif xtype == 'text':
+ h = tlayer(h)
+ else:
+ raise ValueError
+ return h
+
+################
+# VD Next Unet #
+################
+
+from functools import partial
+import copy
+
+@register('openai_unet_2d_next')
+class UNetModel2D_Next(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ context_dim,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ use_checkpoint=False,
+ num_heads=8,
+ num_head_channels=None,
+ parts = ['global', 'data', 'context']):
+
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+
+ self.attention_resolutions = attention_resolutions
+ self.context_dim = context_dim
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ assert (num_heads is None) + (num_head_channels is None) == 1, \
+ "One of num_heads or num_head_channels need to be set"
+
+ self.parts = parts if isinstance(parts, list) else [parts]
+ self.glayer_included = 'global' in self.parts
+ self.dlayer_included = 'data' in self.parts
+ self.clayer_included = 'context' in self.parts
+ self.layer_sequence_ordering = []
+
+ #################
+ # global layers #
+ #################
+
+ time_embed_dim = model_channels * 4
+ if self.glayer_included:
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ ################
+ # input layers #
+ ################
+
+ if self.dlayer_included:
+ self.data_blocks = nn.ModuleList([])
+ ResBlockDefault = partial(
+ ResBlock,
+ emb_channels=time_embed_dim,
+ dropout=dropout,
+ dims=2,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=False, )
+ else:
+ def dummy(*args, **kwargs):
+ return None
+ ResBlockDefault = dummy
+
+ if self.clayer_included:
+ self.context_blocks = nn.ModuleList([])
+ CrossAttnDefault = partial(
+ SpatialTransformer,
+ context_dim=context_dim,
+ disable_self_attn=False, )
+ else:
+ def dummy(*args, **kwargs):
+ return None
+ CrossAttnDefault = dummy
+
+ self.add_data_layer(conv_nd(2, in_channels, model_channels, 3, padding=1))
+ self.layer_sequence_ordering.append('save_hidden_feature')
+ input_block_chans = [model_channels]
+
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(self.num_res_blocks[level]):
+ layer = ResBlockDefault(
+ channels=ch, out_channels=mult*model_channels,)
+ self.add_data_layer(layer)
+ ch = mult * model_channels
+
+ if (ds in attention_resolutions):
+ d_head, n_heads = self.get_d_head_n_heads(ch)
+ layer = CrossAttnDefault(
+ in_channels=ch, d_head=d_head, n_heads=n_heads,)
+ self.add_context_layer(layer)
+ input_block_chans.append(ch)
+ self.layer_sequence_ordering.append('save_hidden_feature')
+
+ if level != len(channel_mult) - 1:
+ layer = Downsample(
+ ch, use_conv=True, dims=2, out_channels=ch)
+ self.add_data_layer(layer)
+ input_block_chans.append(ch)
+ self.layer_sequence_ordering.append('save_hidden_feature')
+ ds *= 2
+
+ self.i_order = copy.deepcopy(self.layer_sequence_ordering)
+ self.layer_sequence_ordering = []
+
+ #################
+ # middle layers #
+ #################
+
+ self.add_data_layer(ResBlockDefault(channels=ch))
+ d_head, n_heads = self.get_d_head_n_heads(ch)
+ self.add_context_layer(CrossAttnDefault(in_channels=ch, d_head=d_head, n_heads=n_heads))
+ self.add_data_layer(ResBlockDefault(channels=ch))
+
+ self.m_order = copy.deepcopy(self.layer_sequence_ordering)
+ self.layer_sequence_ordering = []
+
+ #################
+ # output layers #
+ #################
+
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for _ in range(self.num_res_blocks[level] + 1):
+ self.layer_sequence_ordering.append('load_hidden_feature')
+ ich = input_block_chans.pop()
+ layer = ResBlockDefault(
+ channels=ch+ich, out_channels=model_channels*mult,)
+ ch = model_channels * mult
+ self.add_data_layer(layer)
+
+ if ds in attention_resolutions:
+ d_head, n_heads = self.get_d_head_n_heads(ch)
+ layer = CrossAttnDefault(
+ in_channels=ch, d_head=d_head, n_heads=n_heads)
+ self.add_context_layer(layer)
+
+ if level != 0:
+ layer = Upsample(ch, conv_resample, dims=2, out_channels=ch)
+ self.add_data_layer(layer)
+ ds //= 2
+
+ layer = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(2, model_channels, out_channels, 3, padding=1)),
+ )
+ self.add_data_layer(layer)
+
+ self.o_order = copy.deepcopy(self.layer_sequence_ordering)
+ self.layer_order = copy.deepcopy(self.i_order + self.m_order + self.o_order)
+ del self.layer_sequence_ordering
+
+ self.parameter_group = {}
+ if self.glayer_included:
+ self.parameter_group['global'] = self.time_embed
+ if self.dlayer_included:
+ self.parameter_group['data'] = self.data_blocks
+ if self.clayer_included:
+ self.parameter_group['context'] = self.context_blocks
+
+ def get_d_head_n_heads(self, ch):
+ if self.num_head_channels is None:
+ d_head = ch // self.num_heads
+ n_heads = self.num_heads
+ else:
+ d_head = self.num_head_channels
+ n_heads = ch // self.num_head_channels
+ return d_head, n_heads
+
+ def add_data_layer(self, layer):
+ if self.dlayer_included:
+ if not isinstance(layer, (list, tuple)):
+ layer = [layer]
+ self.data_blocks.append(TimestepEmbedSequential(*layer))
+ self.layer_sequence_ordering.append('d')
+
+ def add_context_layer(self, layer):
+ if self.clayer_included:
+ if not isinstance(layer, (list, tuple)):
+ layer = [layer]
+ self.context_blocks.append(TimestepEmbedSequential(*layer))
+ self.layer_sequence_ordering.append('c')
+
+ def forward(self, x, timesteps, context):
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ d_iter = iter(self.data_blocks)
+ c_iter = iter(self.context_blocks)
+
+ h = x
+ for ltype in self.i_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, context)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, context)
+ elif ltype == 'save_hidden_feature':
+ hs.append(h)
+
+ for ltype in self.m_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, context)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, context)
+
+ for ltype in self.i_order:
+ if ltype == 'load_hidden_feature':
+ h = th.cat([h, hs.pop()], dim=1)
+ elif ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, context)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, context)
+ o = h
+
+ return o
+
+@register('openai_unet_0d_next')
+class UNetModel0D_Next(UNetModel2D_Next):
+ def __init__(
+ self,
+ input_channels,
+ model_channels,
+ output_channels,
+ context_dim = 788,
+ num_noattn_blocks=(2, 2, 2, 2),
+ channel_mult=(1, 2, 4, 8),
+ second_dim=(4, 4, 4, 4),
+ with_attn=[True, True, True, False],
+ num_heads=8,
+ num_head_channels=None,
+ use_checkpoint=False,
+ parts = ['global', 'data', 'context']):
+
+ super(UNetModel2D_Next, self).__init__()
+
+ self.input_channels = input_channels
+ self.model_channels = model_channels
+ self.output_channels = output_channels
+ self.num_noattn_blocks = num_noattn_blocks
+ self.channel_mult = channel_mult
+ self.second_dim = second_dim
+ self.with_attn = with_attn
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+
+ self.parts = parts if isinstance(parts, list) else [parts]
+ self.glayer_included = 'global' in self.parts
+ self.dlayer_included = 'data' in self.parts
+ self.clayer_included = 'context' in self.parts
+ self.layer_sequence_ordering = []
+
+ #################
+ # global layers #
+ #################
+
+ time_embed_dim = model_channels * 4
+ if self.glayer_included:
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ ################
+ # input layers #
+ ################
+
+ if self.dlayer_included:
+ self.data_blocks = nn.ModuleList([])
+ FCBlockDefault = partial(
+ FCBlock_MultiDim, dropout=0, use_checkpoint=use_checkpoint)
+ else:
+ def dummy(*args, **kwargs):
+ return None
+ FCBlockDefault = dummy
+
+ if self.clayer_included:
+ self.context_blocks = nn.ModuleList([])
+ CrossAttnDefault = partial(
+ SpatialTransformer,
+ context_dim=context_dim,
+ disable_self_attn=False, )
+ else:
+ def dummy(*args, **kwargs):
+ return None
+ CrossAttnDefault = dummy
+
+ sdim = second_dim[0]
+ current_channel = [model_channels, sdim, 1]
+ one_layer = Linear_MultiDim([input_channels], current_channel, bias=True)
+ self.add_data_layer(one_layer)
+ self.layer_sequence_ordering.append('save_hidden_feature')
+ input_block_channels = [current_channel]
+
+ for level_idx, (mult, sdim) in enumerate(zip(channel_mult, second_dim)):
+ for _ in range(self.num_noattn_blocks[level_idx]):
+ layer = FCBlockDefault(
+ current_channel,
+ time_embed_dim,
+ out_channels = [mult*model_channels, sdim, 1],)
+
+ self.add_data_layer(layer)
+ current_channel = [mult*model_channels, sdim, 1]
+
+ if with_attn[level_idx]:
+ d_head, n_heads = self.get_d_head_n_heads(current_channel[0])
+ layer = CrossAttnDefault(
+ in_channels=current_channel[0],
+ d_head=d_head, n_heads=n_heads,)
+ self.add_context_layer(layer)
+
+ input_block_channels.append(current_channel)
+ self.layer_sequence_ordering.append('save_hidden_feature')
+
+ if level_idx != len(channel_mult) - 1:
+ layer = Linear_MultiDim(current_channel, current_channel, bias=True,)
+ self.add_data_layer(layer)
+ input_block_channels.append(current_channel)
+ self.layer_sequence_ordering.append('save_hidden_feature')
+
+ self.i_order = copy.deepcopy(self.layer_sequence_ordering)
+ self.layer_sequence_ordering = []
+
+ #################
+ # middle layers #
+ #################
+
+ self.add_data_layer(FCBlockDefault(current_channel, time_embed_dim, ))
+ d_head, n_heads = self.get_d_head_n_heads(current_channel[0])
+ self.add_context_layer(CrossAttnDefault(in_channels=current_channel[0], d_head=d_head, n_heads=n_heads))
+ self.add_data_layer(FCBlockDefault(current_channel, time_embed_dim, ))
+
+ self.m_order = copy.deepcopy(self.layer_sequence_ordering)
+ self.layer_sequence_ordering = []
+
+ #################
+ # output layers #
+ #################
+ for level_idx, (mult, sdim) in list(enumerate(zip(channel_mult, second_dim)))[::-1]:
+ for _ in range(self.num_noattn_blocks[level_idx] + 1):
+ self.layer_sequence_ordering.append('load_hidden_feature')
+ extra_channel = input_block_channels.pop()
+ layer = FCBlockDefault(
+ [current_channel[0] + extra_channel[0]] + current_channel[1:],
+ time_embed_dim,
+ out_channels = [mult*model_channels, sdim, 1], )
+
+ self.add_data_layer(layer)
+ current_channel = [mult*model_channels, sdim, 1]
+
+ if with_attn[level_idx]:
+ d_head, n_heads = self.get_d_head_n_heads(current_channel[0])
+ layer = CrossAttnDefault(
+ in_channels=current_channel[0], d_head=d_head, n_heads=n_heads)
+ self.add_context_layer(layer)
+
+ if level_idx != 0:
+ layer = Linear_MultiDim(current_channel, current_channel, bias=True, )
+ self.add_data_layer(layer)
+
+ layer = nn.Sequential(
+ normalization(current_channel[0]),
+ nn.SiLU(),
+ zero_module(Linear_MultiDim(current_channel, [output_channels], bias=True, )),
+ )
+ self.add_data_layer(layer)
+
+ self.o_order = copy.deepcopy(self.layer_sequence_ordering)
+ self.layer_order = copy.deepcopy(self.i_order + self.m_order + self.o_order)
+ del self.layer_sequence_ordering
+
+ self.parameter_group = {}
+ if self.glayer_included:
+ self.parameter_group['global'] = self.time_embed
+ if self.dlayer_included:
+ self.parameter_group['data'] = self.data_blocks
+ if self.clayer_included:
+ self.parameter_group['context'] = self.context_blocks
diff --git a/lib/model_zoo/pfd.py b/lib/model_zoo/pfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..19cf31c329016d37edf69d9fa3d6d09bdf9aa1b2
--- /dev/null
+++ b/lib/model_zoo/pfd.py
@@ -0,0 +1,528 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import numpy.random as npr
+import copy
+from functools import partial
+from contextlib import contextmanager
+from lib.model_zoo.common.get_model import get_model, register
+from lib.log_service import print_log
+
+symbol = 'pfd'
+
+from .diffusion_utils import \
+ count_params, extract_into_tensor, make_beta_schedule
+from .distributions import normal_kl, DiagonalGaussianDistribution
+
+from .autokl import AutoencoderKL
+from .ema import LitEma
+
+def highlight_print(info):
+ print_log('')
+ print_log(''.join(['#']*(len(info)+4)))
+ print_log('# '+info+' #')
+ print_log(''.join(['#']*(len(info)+4)))
+ print_log('')
+
+@register('pfd')
+class PromptFreeDiffusion(nn.Module):
+ def __init__(self,
+ vae_cfg_list,
+ ctx_cfg_list,
+ diffuser_cfg_list,
+ global_layer_ptr=None,
+
+ parameterization="eps",
+ timesteps=1000,
+ use_ema=False,
+
+ beta_schedule="linear",
+ beta_linear_start=1e-4,
+ beta_linear_end=2e-2,
+ given_betas=None,
+ cosine_s=8e-3,
+
+ loss_type="l2",
+ l_simple_weight=1.,
+ l_elbo_weight=0.,
+
+ v_posterior=0.,
+ learn_logvar=False,
+ logvar_init=0,
+
+ latent_scale_factor=None,):
+
+ super().__init__()
+ assert parameterization in ["eps", "x0"], \
+ 'currently only supporting "eps" and "x0"'
+ self.parameterization = parameterization
+ highlight_print("Running in {} mode".format(self.parameterization))
+
+ self.vae = self.get_model_list(vae_cfg_list)
+ self.ctx = self.get_model_list(ctx_cfg_list)
+ self.diffuser = self.get_model_list(diffuser_cfg_list)
+ self.global_layer_ptr = global_layer_ptr
+
+ assert self.check_diffuser(), 'diffuser layers are not aligned!'
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print_log(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.loss_type = loss_type
+ self.l_simple_weight = l_simple_weight
+ self.l_elbo_weight = l_elbo_weight
+ self.v_posterior = v_posterior
+
+ self.register_schedule(
+ given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ timesteps=timesteps,
+ linear_start=beta_linear_start,
+ linear_end=beta_linear_end,
+ cosine_s=cosine_s)
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.latent_scale_factor = {} if latent_scale_factor is None else latent_scale_factor
+
+ self.parameter_group = {}
+ for namei, diffuseri in self.diffuser.items():
+ self.parameter_group.update({
+ 'diffuser_{}_{}'.format(namei, pgni):pgi for pgni, pgi in diffuseri.parameter_group.items()
+ })
+
+ def to(self, device):
+ self.device = device
+ super().to(device)
+
+ def get_model_list(self, cfg_list):
+ net = nn.ModuleDict()
+ for name, cfg in cfg_list:
+ net[name] = get_model()(cfg)
+ return net
+
+ def register_schedule(self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3):
+ if given_betas is not None:
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, \
+ 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print_log(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print_log(f"{context}: Restored training weights")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ value1 = extract_into_tensor(
+ self.sqrt_recip_alphas_cumprod, t, x_t.shape)
+ value2 = extract_into_tensor(
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ return value1*x_t -value2*noise
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = torch.randn_like(x_start) if noise is None else noise
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def forward(self, x_info, c_info):
+ x = x_info['x']
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x_info, t, c_info)
+
+ def p_losses(self, x_info, t, c_info, noise=None):
+ x = x_info['x']
+ noise = torch.randn_like(x) if noise is None else noise
+ x_noisy = self.q_sample(x_start=x, t=t, noise=noise)
+ x_info['x'] = x_noisy
+ model_output = self.apply_model(x_info, t, c_info)
+
+ loss_dict = {}
+
+ if self.parameterization == "x0":
+ target = x
+ elif self.parameterization == "eps":
+ target = noise
+ else:
+ raise NotImplementedError()
+
+ bs = model_output.shape[0]
+ loss_simple = self.get_loss(model_output, target, mean=False).view(bs, -1).mean(-1)
+ loss_dict['loss_simple'] = loss_simple.mean()
+
+ # logvar_t = self.logvar[t].to(self.device)
+ logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+
+ if self.learn_logvar:
+ loss_dict['loss_gamma'] = loss.mean()
+ loss_dict['logvar' ] = self.logvar.data.mean()
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).view(bs, -1).mean(-1)
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict['loss_vlb'] = loss_vlb
+ loss_dict.update({'Loss': loss})
+
+ return loss, loss_dict
+
+ @torch.no_grad()
+ def vae_encode(self, x, which, **kwargs):
+ z = self.vae[which].encode(x, **kwargs)
+ if self.latent_scale_factor is not None:
+ if self.latent_scale_factor.get(which, None) is not None:
+ scale = self.latent_scale_factor[which]
+ return scale * z
+ return z
+
+ @torch.no_grad()
+ def vae_decode(self, z, which, **kwargs):
+ if self.latent_scale_factor is not None:
+ if self.latent_scale_factor.get(which, None) is not None:
+ scale = self.latent_scale_factor[which]
+ z = 1./scale * z
+ x = self.vae[which].decode(z, **kwargs)
+ return x
+
+ @torch.no_grad()
+ def ctx_encode(self, x, which, **kwargs):
+ if which.find('vae_') == 0:
+ return self.vae[which[4:]].encode(x, **kwargs)
+ else:
+ return self.ctx[which].encode(x, **kwargs)
+
+ def ctx_encode_trainable(self, x, which, **kwargs):
+ if which.find('vae_') == 0:
+ return self.vae[which[4:]].encode(x, **kwargs)
+ else:
+ return self.ctx[which].encode(x, **kwargs)
+
+ def check_diffuser(self):
+ for idx, (_, diffuseri) in enumerate(self.diffuser.items()):
+ if idx==0:
+ order = diffuseri.layer_order
+ else:
+ if not order == diffuseri.layer_order:
+ return False
+ return True
+
+ @torch.no_grad()
+ def on_train_batch_start(self, x):
+ pass
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def apply_model(self, x_info, timesteps, c_info):
+ x_type, x = x_info['type'], x_info['x']
+ c_type, c = c_info['type'], c_info['c']
+ dtype = x.dtype
+
+ hs = []
+
+ from .openaimodel import timestep_embedding
+
+ glayer_ptr = x_type if self.global_layer_ptr is None else self.global_layer_ptr
+ model_channels = self.diffuser[glayer_ptr].model_channels
+ t_emb = timestep_embedding(timesteps, model_channels, repeat_only=False).to(dtype)
+ emb = self.diffuser[glayer_ptr].time_embed(t_emb)
+
+ d_iter = iter(self.diffuser[x_type].data_blocks)
+ c_iter = iter(self.diffuser[c_type].context_blocks)
+
+ i_order = self.diffuser[x_type].i_order
+ m_order = self.diffuser[x_type].m_order
+ o_order = self.diffuser[x_type].o_order
+
+ h = x
+ for ltype in i_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, c)
+ elif ltype == 'save_hidden_feature':
+ hs.append(h)
+
+ for ltype in m_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, c)
+
+ for ltype in o_order:
+ if ltype == 'load_hidden_feature':
+ h = torch.cat([h, hs.pop()], dim=1)
+ elif ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, c)
+ o = h
+
+ return o
+
+ def context_mixing(self, x, emb, context_module_list, context_info_list, mixing_type):
+ nm = len(context_module_list)
+ nc = len(context_info_list)
+ assert nm == nc
+ context = [c_info['c'] for c_info in context_info_list]
+ cratio = np.array([c_info['ratio'] for c_info in context_info_list])
+ cratio = cratio / cratio.sum()
+
+ if mixing_type == 'attention':
+ h = None
+ for module, c, r in zip(context_module_list, context, cratio):
+ hi = module(x, emb, c) * r
+ h = h+hi if h is not None else hi
+ return h
+ elif mixing_type == 'layer':
+ ni = npr.choice(nm, p=cratio)
+ module = context_module_list[ni]
+ c = context[ni]
+ h = module(x, emb, c)
+ return h
+
+ def apply_model_multicontext(self, x_info, timesteps, c_info_list, mixing_type='attention'):
+ '''
+ context_info_list: [[context_type, context, ratio]] for 'attention'
+ '''
+
+ x_type, x = x_info['type'], x_info['x']
+ dtype = x.dtype
+
+ hs = []
+
+ from .openaimodel import timestep_embedding
+ model_channels = self.diffuser[x_type].model_channels
+ t_emb = timestep_embedding(timesteps, model_channels, repeat_only=False).to(dtype)
+ emb = self.diffuser[x_type].time_embed(t_emb)
+
+ d_iter = iter(self.diffuser[x_type].data_blocks)
+ c_iter_list = [iter(self.diffuser[c_info['type']].context_blocks) for c_info in c_info_list]
+
+ i_order = self.diffuser[x_type].i_order
+ m_order = self.diffuser[x_type].m_order
+ o_order = self.diffuser[x_type].o_order
+
+ h = x
+ for ltype in i_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module_list = [next(c_iteri) for c_iteri in c_iter_list]
+ h = self.context_mixing(h, emb, module_list, c_info_list, mixing_type)
+ elif ltype == 'save_hidden_feature':
+ hs.append(h)
+
+ for ltype in m_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module_list = [next(c_iteri) for c_iteri in c_iter_list]
+ h = self.context_mixing(h, emb, module_list, c_info_list, mixing_type)
+
+ for ltype in o_order:
+ if ltype == 'load_hidden_feature':
+ h = torch.cat([h, hs.pop()], dim=1)
+ elif ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module_list = [next(c_iteri) for c_iteri in c_iter_list]
+ h = self.context_mixing(h, emb, module_list, c_info_list, mixing_type)
+ o = h
+ return o
+
+ def get_device(self):
+ one_param = next(self.parameters())
+ return one_param.device
+
+ def get_dtype(self):
+ one_param = next(self.parameters())
+ return one_param.dtype
+
+ @torch.no_grad()
+ def print_debug_checksum(self):
+ csum = {
+ ki : next(self.parameter_group[ki][0].parameters()).abs().sum().item()
+ for ki in self.parameter_group.keys()
+ }
+ print(csum)
+
+@register('pfd_with_control')
+class PromptFreeDiffusion_with_control(PromptFreeDiffusion):
+ def __init__(self, *args, **kwargs):
+ ctl_cfg = kwargs.pop('ctl_cfg')
+ super().__init__(*args, **kwargs)
+ self.ctl = get_model()(ctl_cfg)
+ self.control_scales = [1.0] * 13
+ self.parameter_group['ctl'] = [self.ctl]
+
+ def apply_model(self, x_info, timesteps, c_info):
+ x_type, x = x_info['type'], x_info['x']
+ c_type, c = c_info['type'], c_info['c']
+ cc = c_info.get('control', None)
+ dtype = x.dtype
+
+ if cc is not None:
+ ccs = self.ctl(x, hint=cc, timesteps=timesteps, context=c)
+ else:
+ class ccs_zeros(object):
+ def __init__(self): pass
+ def pop(self): return 0
+ ccs = ccs_zeros()
+
+ hs = []
+
+ from .openaimodel import timestep_embedding
+
+ glayer_ptr = x_type if self.global_layer_ptr is None else self.global_layer_ptr
+ model_channels = self.diffuser[glayer_ptr].model_channels
+ t_emb = timestep_embedding(timesteps, model_channels, repeat_only=False).to(dtype)
+ emb = self.diffuser[glayer_ptr].time_embed(t_emb)
+
+ d_iter = iter(self.diffuser[x_type].data_blocks)
+ c_iter = iter(self.diffuser[c_type].context_blocks)
+
+ i_order = self.diffuser[x_type].i_order
+ m_order = self.diffuser[x_type].m_order
+ o_order = self.diffuser[x_type].o_order
+
+ h = x
+ for ltype in i_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, c)
+ elif ltype == 'save_hidden_feature':
+ hs.append(h)
+
+ for ltype in m_order:
+ if ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, c)
+
+ h = h + ccs.pop()
+
+ for ltype in o_order:
+ if ltype == 'load_hidden_feature':
+ h = torch.cat([h, hs.pop()+ccs.pop()], dim=1)
+ elif ltype == 'd':
+ module = next(d_iter)
+ h = module(h, emb, None)
+ elif ltype == 'c':
+ module = next(c_iter)
+ h = module(h, emb, c)
+ o = h
+
+ return o
diff --git a/lib/model_zoo/sampler.py b/lib/model_zoo/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2586d7eba4c10720041a478bd9e72b13b78f28e7
--- /dev/null
+++ b/lib/model_zoo/sampler.py
@@ -0,0 +1,104 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+
+def append_dims(x, target_dims):
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+def default_noise_sampler(x):
+ return lambda sigma, sigma_next: torch.randn_like(x)
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.):
+ if not eta:
+ return sigma_to, 0.
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+ return sigma_down, sigma_up
+
+def to_d(x, sigma, denoised):
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+class Sampler(object):
+ def __init__(self, net, type="ddim", steps=50, output_dim=[512, 512], n_samples=4, scale=7.5):
+ super().__init__()
+ self.net = net
+ self.type = type
+ self.steps = steps
+ self.output_dim = output_dim
+ self.n_samples = n_samples
+ self.scale = scale
+ self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5
+ self.log_sigmas = self.sigmas.log()
+
+ def t_to_sigma(self, t):
+ t = t.float()
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+ return log_sigma.exp()
+
+ def get_sigmas(self, n=None):
+ def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+ if n is None:
+ return append_zero(self.sigmas.flip(0))
+ t_max = len(self.sigmas) - 1
+ t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
+ return append_zero(self.t_to_sigma(t))
+
+ @torch.no_grad()
+ def sample(self, x_info, c_info):
+ h, w = self.output_dim
+ shape = [self.n_samples, 4, h//8, w//8]
+ device, dtype = self.net.get_device(), self.net.get_dtype()
+
+ if ('xt' in x_info) and (x_info['xt'] is not None):
+ xt = x_info['xt'].astype(dtype).to(device)
+ x_info['x'] = xt
+ elif ('x0' in x_info) and (x_info['x0'] is not None):
+ x0 = x_info['x0'].type(dtype).to(device)
+ ts = timesteps[x_info['x0_forward_timesteps']].repeat(self.n_samples)
+ ts = torch.Tensor(ts).long().to(device)
+ timesteps = timesteps[:x_info['x0_forward_timesteps']]
+ x0_nz = self.model.q_sample(x0, ts)
+ x_info['x'] = x0_nz
+ else:
+ x_info['x'] = torch.randn(shape, device=device, dtype=dtype)
+
+ sigmas = self.get_sigmas(n=self.steps)
+
+ if self.type == 'eular_a':
+ rv = self.sample_euler_ancestral(
+ x_info=x_info,
+ c_info=c_info,
+ sigmas = sigmas)
+ return rv
+
+ @torch.no_grad()
+ def sample_euler_ancestral(
+ self, x_info, c_info, sigmas, eta=1., s_noise=1.,):
+
+ x = x_info['x']
+ x = x * sigmas[0]
+
+ noise_sampler = default_noise_sampler(x)
+
+ s_in = x.new_ones([x.shape[0]])
+ for i in range(len(sigmas)-1):
+ denoised = self.net.apply_model(x, sigmas[i] * s_in, )
+
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+ d = to_d(x, sigmas[i], denoised)
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ if sigmas[i + 1] > 0:
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+ return x
diff --git a/lib/model_zoo/seecoder.py b/lib/model_zoo/seecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad331180113db5ee33186a5abce81e871e0c7c9
--- /dev/null
+++ b/lib/model_zoo/seecoder.py
@@ -0,0 +1,576 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+
+from .seecoder_utils import with_pos_embed
+from lib.model_zoo.common.get_model import get_model, register
+
+symbol = 'seecoder'
+
+###########
+# helpers #
+###########
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+def c2_xavier_fill(module):
+ # Caffe2 implementation of XavierFill in fact
+ nn.init.kaiming_uniform_(module.weight, a=1)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+def with_pos_embed(x, pos):
+ return x if pos is None else x + pos
+
+###########
+# Modules #
+###########
+
+class Conv2d_Convenience(nn.Conv2d):
+ def __init__(self, *args, **kwargs):
+ norm = kwargs.pop("norm", None)
+ activation = kwargs.pop("activation", None)
+ super().__init__(*args, **kwargs)
+ self.norm = norm
+ self.activation = activation
+
+ def forward(self, x):
+ x = F.conv2d(
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+class DecoderLayer(nn.Module):
+ def __init__(self,
+ dim=256,
+ feedforward_dim=1024,
+ dropout=0.1,
+ activation="relu",
+ n_heads=8,):
+
+ super().__init__()
+
+ self.self_attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(dim)
+
+ self.linear1 = nn.Linear(dim, feedforward_dim)
+ self.activation = _get_activation_fn(activation)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(feedforward_dim, dim)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(dim)
+
+ def forward(self, x):
+ h = x
+ h1 = self.self_attn(x, x, x, attn_mask=None)[0]
+ h = h + self.dropout1(h1)
+ h = self.norm1(h)
+
+ h2 = self.linear2(self.dropout2(self.activation(self.linear1(h))))
+ h = h + self.dropout3(h2)
+ h = self.norm2(h)
+ return h
+
+class DecoderLayerStacked(nn.Module):
+ def __init__(self, layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, x):
+ h = x
+ for _, layer in enumerate(self.layers):
+ h = layer(h)
+ if self.norm is not None:
+ h = self.norm(h)
+ return h
+
+class SelfAttentionLayer(nn.Module):
+ def __init__(self, channels, nhead, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(channels, nhead, dropout=dropout)
+
+ self.norm = nn.LayerNorm(channels)
+ self.dropout = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward_post(self,
+ qkv,
+ qk_pos = None,
+ mask = None,):
+ h = qkv
+ qk = with_pos_embed(qkv, qk_pos).transpose(0, 1)
+ v = qkv.transpose(0, 1)
+ h1 = self.self_attn(qk, qk, v, attn_mask=mask)[0]
+ h1 = h1.transpose(0, 1)
+ h = h + self.dropout(h1)
+ h = self.norm(h)
+ return h
+
+ def forward_pre(self, tgt,
+ tgt_mask = None,
+ tgt_key_padding_mask = None,
+ query_pos = None):
+ # deprecated
+ assert False
+ tgt2 = self.norm(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward(self, *args, **kwargs):
+ if self.normalize_before:
+ return self.forward_pre(*args, **kwargs)
+ return self.forward_post(*args, **kwargs)
+
+class CrossAttentionLayer(nn.Module):
+ def __init__(self, channels, nhead, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.multihead_attn = nn.MultiheadAttention(channels, nhead, dropout=dropout)
+
+ self.norm = nn.LayerNorm(channels)
+ self.dropout = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward_post(self,
+ q,
+ kv,
+ q_pos = None,
+ k_pos = None,
+ mask = None,):
+ h = q
+ q = with_pos_embed(q, q_pos).transpose(0, 1)
+ k = with_pos_embed(kv, k_pos).transpose(0, 1)
+ v = kv.transpose(0, 1)
+ h1 = self.multihead_attn(q, k, v, attn_mask=mask)[0]
+ h1 = h1.transpose(0, 1)
+ h = h + self.dropout(h1)
+ h = self.norm(h)
+ return h
+
+ def forward_pre(self, tgt, memory,
+ memory_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ # Deprecated
+ assert False
+ tgt2 = self.norm(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward(self, *args, **kwargs):
+ if self.normalize_before:
+ return self.forward_pre(*args, **kwargs)
+ return self.forward_post(*args, **kwargs)
+
+class FeedForwardLayer(nn.Module):
+ def __init__(self, channels, hidden_channels=2048, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.linear1 = nn.Linear(channels, hidden_channels)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(hidden_channels, channels)
+ self.norm = nn.LayerNorm(channels)
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward_post(self, x):
+ h = x
+ h1 = self.linear2(self.dropout(self.activation(self.linear1(h))))
+ h = h + self.dropout(h1)
+ h = self.norm(h)
+ return h
+
+ def forward_pre(self, x):
+ xn = self.norm(x)
+ h = x
+ h1 = self.linear2(self.dropout(self.activation(self.linear1(xn))))
+ h = h + self.dropout(h1)
+ return h
+
+ def forward(self, *args, **kwargs):
+ if self.normalize_before:
+ return self.forward_pre(*args, **kwargs)
+ return self.forward_post(*args, **kwargs)
+
+class MLP(nn.Module):
+ def __init__(self, in_channels, channels, out_channels, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [channels] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k)
+ for n, k in zip([in_channels]+h, h+[out_channels]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+class PPE_MLP(nn.Module):
+ def __init__(self, freq_num=20, freq_max=None, out_channel=768, mlp_layer=3):
+ import math
+ super().__init__()
+ self.freq_num = freq_num
+ self.freq_max = freq_max
+ self.out_channel = out_channel
+ self.mlp_layer = mlp_layer
+ self.twopi = 2 * math.pi
+
+ mlp = []
+ in_channel = freq_num*4
+ for idx in range(mlp_layer):
+ linear = nn.Linear(in_channel, out_channel, bias=True)
+ nn.init.xavier_normal_(linear.weight)
+ nn.init.constant_(linear.bias, 0)
+ mlp.append(linear)
+ if idx != mlp_layer-1:
+ mlp.append(nn.SiLU())
+ in_channel = out_channel
+ self.mlp = nn.Sequential(*mlp)
+ nn.init.constant_(self.mlp[-1].weight, 0)
+
+ def forward(self, x, mask=None):
+ assert mask is None, "Mask not implemented"
+ h, w = x.shape[-2:]
+ minlen = min(h, w)
+
+ h_embed, w_embed = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
+ if self.training:
+ import numpy.random as npr
+ pertube_h, pertube_w = npr.uniform(-0.5, 0.5), npr.uniform(-0.5, 0.5)
+ else:
+ pertube_h, pertube_w = 0, 0
+
+ h_embed = (h_embed+0.5 - h/2 + pertube_h) / (minlen) * self.twopi
+ w_embed = (w_embed+0.5 - w/2 + pertube_w) / (minlen) * self.twopi
+ h_embed, w_embed = h_embed.to(x.device).to(x.dtype), w_embed.to(x.device).to(x.dtype)
+
+ dim_t = torch.linspace(0, 1, self.freq_num, dtype=torch.float32, device=x.device)
+ freq_max = self.freq_max if self.freq_max is not None else minlen/2
+ dim_t = freq_max ** dim_t.to(x.dtype)
+
+ pos_h = h_embed[:, :, None] * dim_t
+ pos_w = w_embed[:, :, None] * dim_t
+ pos = torch.cat((pos_h.sin(), pos_h.cos(), pos_w.sin(), pos_w.cos()), dim=-1)
+ pos = self.mlp(pos)
+ pos = pos.permute(2, 0, 1)[None]
+ return pos
+
+ def __repr__(self, _repr_indent=4):
+ head = "Positional encoding " + self.__class__.__name__
+ body = [
+ "num_pos_feats: {}".format(self.num_pos_feats),
+ "temperature: {}".format(self.temperature),
+ "normalize: {}".format(self.normalize),
+ "scale: {}".format(self.scale),
+ ]
+ # _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
+
+###########
+# Decoder #
+###########
+
+@register('seecoder_decoder')
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ inchannels,
+ trans_input_tags,
+ trans_num_layers,
+ trans_dim,
+ trans_nheads,
+ trans_dropout,
+ trans_feedforward_dim,):
+
+ super().__init__()
+ trans_inchannels = {
+ k: v for k, v in inchannels.items() if k in trans_input_tags}
+ fpn_inchannels = {
+ k: v for k, v in inchannels.items() if k not in trans_input_tags}
+
+ self.trans_tags = sorted(list(trans_inchannels.keys()))
+ self.fpn_tags = sorted(list(fpn_inchannels.keys()))
+ self.all_tags = sorted(list(inchannels.keys()))
+
+ if len(self.trans_tags)==0:
+ assert False # Not allowed
+
+ self.num_trans_lvls = len(self.trans_tags)
+
+ self.inproj_layers = nn.ModuleDict()
+ for tagi in self.trans_tags:
+ layeri = nn.Sequential(
+ nn.Conv2d(trans_inchannels[tagi], trans_dim, kernel_size=1),
+ nn.GroupNorm(32, trans_dim),)
+ nn.init.xavier_uniform_(layeri[0].weight, gain=1)
+ nn.init.constant_(layeri[0].bias, 0)
+ self.inproj_layers[tagi] = layeri
+
+ tlayer = DecoderLayer(
+ dim = trans_dim,
+ n_heads = trans_nheads,
+ dropout = trans_dropout,
+ feedforward_dim = trans_feedforward_dim,
+ activation = 'relu',)
+
+ self.transformer = DecoderLayerStacked(tlayer, trans_num_layers)
+ for p in self.transformer.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ self.level_embed = nn.Parameter(torch.Tensor(len(self.trans_tags), trans_dim))
+ nn.init.normal_(self.level_embed)
+
+ self.lateral_layers = nn.ModuleDict()
+ self.output_layers = nn.ModuleDict()
+ for tagi in self.all_tags:
+ lateral_conv = Conv2d_Convenience(
+ inchannels[tagi], trans_dim, kernel_size=1,
+ bias=False, norm=nn.GroupNorm(32, trans_dim))
+ c2_xavier_fill(lateral_conv)
+ self.lateral_layers[tagi] = lateral_conv
+
+ for tagi in self.fpn_tags:
+ output_conv = Conv2d_Convenience(
+ trans_dim, trans_dim, kernel_size=3, stride=1, padding=1,
+ bias=False, norm=nn.GroupNorm(32, trans_dim), activation=F.relu,)
+ c2_xavier_fill(output_conv)
+ self.output_layers[tagi] = output_conv
+
+ def forward(self, features):
+ x = []
+ spatial_shapes = {}
+ for idx, tagi in enumerate(self.trans_tags[::-1]):
+ xi = features[tagi]
+ xi = self.inproj_layers[tagi](xi)
+ bs, _, h, w = xi.shape
+ spatial_shapes[tagi] = (h, w)
+ xi = xi.flatten(2).transpose(1, 2) + self.level_embed[idx].view(1, 1, -1)
+ x.append(xi)
+
+ x_length = [xi.shape[1] for xi in x]
+ x_concat = torch.cat(x, 1)
+ y_concat = self.transformer(x_concat)
+ y = torch.split(y_concat, x_length, dim=1)
+
+ out = {}
+ for idx, tagi in enumerate(self.trans_tags[::-1]):
+ h, w = spatial_shapes[tagi]
+ yi = y[idx].transpose(1, 2).view(bs, -1, h, w)
+ out[tagi] = yi
+
+ for idx, tagi in enumerate(self.all_tags[::-1]):
+ lconv = self.lateral_layers[tagi]
+ if tagi in self.trans_tags:
+ out[tagi] = out[tagi] + lconv(features[tagi])
+ tag_save = tagi
+ else:
+ oconv = self.output_layers[tagi]
+ h = lconv(features[tagi])
+ oprev = out[tag_save]
+ h = h + F.interpolate(oconv(oprev), size=h.shape[-2:], mode="bilinear", align_corners=False)
+ out[tagi] = h
+
+ return out
+
+#####################
+# Query Transformer #
+#####################
+
+@register('seecoder_query_transformer')
+class QueryTransformer(nn.Module):
+ def __init__(self,
+ in_channels,
+ hidden_dim,
+ num_queries = [8, 144],
+ nheads = 8,
+ num_layers = 9,
+ feedforward_dim = 2048,
+ mask_dim = 256,
+ pre_norm = False,
+ num_feature_levels = 3,
+ enforce_input_project = False,
+ with_fea2d_pos = True):
+
+ super().__init__()
+
+ if with_fea2d_pos:
+ self.pe_layer = PPE_MLP(freq_num=20, freq_max=None, out_channel=hidden_dim, mlp_layer=3)
+ else:
+ self.pe_layer = None
+
+ if in_channels!=hidden_dim or enforce_input_project:
+ self.input_proj = nn.ModuleList()
+ for _ in range(num_feature_levels):
+ self.input_proj.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1))
+ c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj = None
+
+ self.num_heads = nheads
+ self.num_layers = num_layers
+ self.transformer_selfatt_layers = nn.ModuleList()
+ self.transformer_crossatt_layers = nn.ModuleList()
+ self.transformer_feedforward_layers = nn.ModuleList()
+
+ for _ in range(self.num_layers):
+ self.transformer_selfatt_layers.append(
+ SelfAttentionLayer(
+ channels=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm, ))
+
+ self.transformer_crossatt_layers.append(
+ CrossAttentionLayer(
+ channels=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm, ))
+
+ self.transformer_feedforward_layers.append(
+ FeedForwardLayer(
+ channels=hidden_dim,
+ hidden_channels=feedforward_dim,
+ dropout=0.0,
+ normalize_before=pre_norm, ))
+
+ self.num_queries = num_queries
+ num_gq, num_lq = self.num_queries
+ self.init_query = nn.Embedding(num_gq+num_lq, hidden_dim)
+ self.query_pos_embedding = nn.Embedding(num_gq+num_lq, hidden_dim)
+
+ self.num_feature_levels = num_feature_levels
+ self.level_embed = nn.Embedding(num_feature_levels, hidden_dim)
+
+ def forward(self, x):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels
+ fea2d = []
+ fea2d_pos = []
+ size_list = []
+
+ for i in range(self.num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ if self.pe_layer is not None:
+ pi = self.pe_layer(x[i], None).flatten(2)
+ pi = pi.transpose(1, 2)
+ else:
+ pi = None
+ xi = self.input_proj[i](x[i]) if self.input_proj is not None else x[i]
+ xi = xi.flatten(2) + self.level_embed.weight[i][None, :, None]
+ xi = xi.transpose(1, 2)
+ fea2d.append(xi)
+ fea2d_pos.append(pi)
+
+ bs, _, _ = fea2d[0].shape
+ num_gq, num_lq = self.num_queries
+ gquery = self.init_query.weight[:num_gq].unsqueeze(0).repeat(bs, 1, 1)
+ lquery = self.init_query.weight[num_gq:].unsqueeze(0).repeat(bs, 1, 1)
+
+ gquery_pos = self.query_pos_embedding.weight[:num_gq].unsqueeze(0).repeat(bs, 1, 1)
+ lquery_pos = self.query_pos_embedding.weight[num_gq:].unsqueeze(0).repeat(bs, 1, 1)
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+
+ qout = self.transformer_crossatt_layers[i](
+ q = lquery,
+ kv = fea2d[level_index],
+ q_pos = lquery_pos,
+ k_pos = fea2d_pos[level_index],
+ mask = None,)
+ lquery = qout
+
+ qout = self.transformer_selfatt_layers[i](
+ qkv = torch.cat([gquery, lquery], dim=1),
+ qk_pos = torch.cat([gquery_pos, lquery_pos], dim=1),)
+
+ qout = self.transformer_feedforward_layers[i](qout)
+
+ gquery = qout[:, :num_gq]
+ lquery = qout[:, num_gq:]
+
+ output = torch.cat([gquery, lquery], dim=1)
+
+ return output
+
+##################
+# Main structure #
+##################
+
+@register('seecoder')
+class SemanticExtractionEncoder(nn.Module):
+ def __init__(self,
+ imencoder_cfg,
+ imdecoder_cfg,
+ qtransformer_cfg):
+ super().__init__()
+ self.imencoder = get_model()(imencoder_cfg)
+ self.imdecoder = get_model()(imdecoder_cfg)
+ self.qtransformer = get_model()(qtransformer_cfg)
+
+ def forward(self, x):
+ fea = self.imencoder(x)
+ hs = {'res3' : fea['res3'],
+ 'res4' : fea['res4'],
+ 'res5' : fea['res5'], }
+ hs = self.imdecoder(hs)
+ hs = [hs['res3'], hs['res4'], hs['res5']]
+ q = self.qtransformer(hs)
+ return q
+
+ def encode(self, x):
+ return self(x)
diff --git a/lib/model_zoo/seecoder_decoder.py b/lib/model_zoo/seecoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..49ba6f5fa550d183d2424dc39eeddc6a78d64f97
--- /dev/null
+++ b/lib/model_zoo/seecoder_decoder.py
@@ -0,0 +1,15 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+
+from lib.model_zoo.common.get_model import get_model, register
+
+from .seecoder_utils import PositionEmbeddingSine, _get_clones, \
+ _get_activation_fn, _is_power_of_2, c2_xavier_fill, Conv2d_Convenience
+
+###########
+# modules #
+###########
+
diff --git a/lib/model_zoo/seecoder_utils.py b/lib/model_zoo/seecoder_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..40acf4d8eb70cecc98aa46f999a6218c09b5bd0f
--- /dev/null
+++ b/lib/model_zoo/seecoder_utils.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+import math
+import copy
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n-1) == 0) and n != 0
+
+def c2_xavier_fill(module):
+ # Caffe2 implementation of XavierFill in fact
+ nn.init.kaiming_uniform_(module.weight, a=1)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+def with_pos_embed(x, pos):
+ return x if pos is None else x + pos
+
+class PositionEmbeddingSine(nn.Module):
+ def __init__(self, num_pos_feats=64, temperature=256, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ h, w = not_mask.shape[-2:]
+ minlen = min(h, w)
+ h_embed = not_mask.cumsum(1, dtype=torch.float32)
+ w_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ h_embed = (h_embed - h/2) / (minlen + eps) * self.scale
+ w_embed = (w_embed - w/2) / (minlen + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_w = w_embed[:, :, :, None] / dim_t
+ pos_h = h_embed[:, :, :, None] / dim_t
+ pos_w = torch.stack(
+ (pos_w[:, :, :, 0::2].sin(), pos_w[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_h = torch.stack(
+ (pos_h[:, :, :, 0::2].sin(), pos_h[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_h, pos_w), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self, _repr_indent=4):
+ head = "Positional encoding " + self.__class__.__name__
+ body = [
+ "num_pos_feats: {}".format(self.num_pos_feats),
+ "temperature: {}".format(self.temperature),
+ "normalize: {}".format(self.normalize),
+ "scale: {}".format(self.scale),
+ ]
+ # _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
+
+class Conv2d_Convenience(nn.Conv2d):
+ def __init__(self, *args, **kwargs):
+ norm = kwargs.pop("norm", None)
+ activation = kwargs.pop("activation", None)
+ super().__init__(*args, **kwargs)
+ self.norm = norm
+ self.activation = activation
+
+ def forward(self, x):
+ if not torch.jit.is_scripting():
+ if x.numel() == 0 and self.training:
+ assert not isinstance(
+ self.norm, torch.nn.SyncBatchNorm
+ ), "SyncBatchNorm does not support empty inputs!"
+ x = F.conv2d(
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
diff --git a/lib/model_zoo/seet_tdecoder.py b/lib/model_zoo/seet_tdecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c395bd144e2cd9d15e10c08f9f0dec076f724836
--- /dev/null
+++ b/lib/model_zoo/seet_tdecoder.py
@@ -0,0 +1,699 @@
+import fvcore.nn.weight_init as weight_init
+from typing import Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .msdeformattn import PositionEmbeddingSine, _get_clones, _get_activation_fn
+from lib.model_zoo.common.get_model import get_model, register
+
+##########
+# helper #
+##########
+
+def with_pos_embed(x, pos):
+ return x if pos is None else x + pos
+
+##############
+# One Former #
+##############
+
+class Transformer(nn.Module):
+ def __init__(self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,):
+
+ super().__init__()
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,)
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, query_embed, pos_embed, task_token=None):
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ if mask is not None:
+ mask = mask.flatten(1)
+
+ if task_token is None:
+ tgt = torch.zeros_like(query_embed)
+ else:
+ tgt = task_token.repeat(query_embed.shape[0], 1, 1)
+
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # src = memory
+ hs = self.decoder(
+ tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+ )
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, src, mask=None, src_key_padding_mask=None, pos=None,):
+ output = src
+ for layer in self.layers:
+ output = layer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+ )
+ if self.norm is not None:
+ output = self.norm(output)
+ return output
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None,
+ pos=None,
+ query_pos=None,):
+
+ output = tgt
+ intermediate = []
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False, ):
+
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, x, pos):
+ return x if pos is None else x + pos
+
+ def forward_post(
+ self,
+ src,
+ src_mask = None,
+ src_key_padding_mask = None,
+ pos = None,):
+
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(
+ self,
+ src,
+ src_mask = None,
+ src_key_padding_mask = None,
+ pos = None,):
+
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(
+ self,
+ src,
+ src_mask = None,
+ src_key_padding_mask = None,
+ pos = None,):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,):
+
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, x, pos):
+ return x if pos is None else x + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask = None,
+ memory_mask = None,
+ tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None,):
+
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask = None,
+ memory_mask = None,
+ tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None,):
+
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask = None,
+ memory_mask = None,
+ tgt_key_padding_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None, ):
+
+ if self.normalize_before:
+ return self.forward_pre(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,)
+ return self.forward_post(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,)
+
+class SelfAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def with_pos_embed(self, tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt,
+ tgt_mask = None,
+ tgt_key_padding_mask = None,
+ query_pos = None):
+ q = k = self.with_pos_embed(tgt, query_pos).transpose(0 ,1)
+ tgt2 = self.self_attn(q, k, value=tgt.transpose(0 ,1), attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2.transpose(0 ,1))
+ tgt = self.norm(tgt)
+
+ return tgt
+
+ def forward_pre(self, tgt,
+ tgt_mask = None,
+ tgt_key_padding_mask = None,
+ query_pos = None):
+ tgt2 = self.norm(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+
+ return tgt
+
+ def forward(self, tgt,
+ tgt_mask = None,
+ tgt_key_padding_mask = None,
+ query_pos = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, tgt_mask,
+ tgt_key_padding_mask, query_pos)
+ return self.forward_post(tgt, tgt_mask,
+ tgt_key_padding_mask, query_pos)
+
+class CrossAttentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.norm = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def with_pos_embed(self, tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt, memory,
+ memory_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos).transpose(0, 1),
+ key=self.with_pos_embed(memory, pos).transpose(0, 1),
+ value=memory.transpose(0, 1), attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2.transpose(0, 1))
+ tgt = self.norm(tgt)
+
+ return tgt
+
+ def forward_pre(self, tgt, memory,
+ memory_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory, attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)[0]
+ tgt = tgt + self.dropout(tgt2)
+
+ return tgt
+
+ def forward(self, tgt, memory,
+ memory_mask = None,
+ memory_key_padding_mask = None,
+ pos = None,
+ query_pos = None):
+ if self.normalize_before:
+ return self.forward_pre(tgt, memory, memory_mask,
+ memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, memory_mask,
+ memory_key_padding_mask, pos, query_pos)
+
+class FFNLayer(nn.Module):
+
+ def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
+ activation="relu", normalize_before=False):
+ super().__init__()
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm = nn.LayerNorm(d_model)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def with_pos_embed(self, tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self, tgt):
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout(tgt2)
+ tgt = self.norm(tgt)
+ return tgt
+
+ def forward_pre(self, tgt):
+ tgt2 = self.norm(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout(tgt2)
+ return tgt
+
+ def forward(self, tgt):
+ if self.normalize_before:
+ return self.forward_pre(tgt)
+ return self.forward_post(tgt)
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+@register('seet_oneformer_tdecoder')
+class Seet_OneFormer_TDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ mask_classification,
+ num_classes,
+ hidden_dim,
+ num_queries,
+ nheads,
+ dropout,
+ dim_feedforward,
+ enc_layers,
+ is_train,
+ dec_layers,
+ class_dec_layers,
+ pre_norm,
+ mask_dim,
+ enforce_input_project,
+ use_task_norm,):
+
+ super().__init__()
+
+ assert mask_classification, "Only support mask classification model"
+ self.mask_classification = mask_classification
+ self.is_train = is_train
+ self.use_task_norm = use_task_norm
+
+ # positional encoding
+ N_steps = hidden_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ self.class_transformer = Transformer(
+ d_model=hidden_dim,
+ dropout=dropout,
+ nhead=nheads,
+ dim_feedforward=dim_feedforward,
+ num_encoder_layers=enc_layers,
+ num_decoder_layers=class_dec_layers,
+ normalize_before=pre_norm,
+ return_intermediate_dec=False,
+ )
+
+ # define Transformer decoder here
+ self.num_heads = nheads
+ self.num_layers = dec_layers
+ self.transformer_self_attention_layers = nn.ModuleList()
+ self.transformer_cross_attention_layers = nn.ModuleList()
+ self.transformer_ffn_layers = nn.ModuleList()
+
+ for _ in range(self.num_layers):
+ self.transformer_self_attention_layers.append(
+ SelfAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_cross_attention_layers.append(
+ CrossAttentionLayer(
+ d_model=hidden_dim,
+ nhead=nheads,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.transformer_ffn_layers.append(
+ FFNLayer(
+ d_model=hidden_dim,
+ dim_feedforward=dim_feedforward,
+ dropout=0.0,
+ normalize_before=pre_norm,
+ )
+ )
+
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+ self.num_queries = num_queries
+ # learnable query p.e.
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+ # level embedding (we always use 3 scales)
+ self.num_feature_levels = 3
+ self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+ self.input_proj = nn.ModuleList()
+ for _ in range(self.num_feature_levels):
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1))
+ weight_init.c2_xavier_fill(self.input_proj[-1])
+ else:
+ self.input_proj.append(nn.Sequential())
+
+ self.class_input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
+ weight_init.c2_xavier_fill(self.class_input_proj)
+
+ # output FFNs
+ if self.mask_classification:
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+ def forward(self, x, mask_features, tasks):
+ # x is a list of multi-scale feature
+ assert len(x) == self.num_feature_levels
+ src = []
+ pos = []
+ size_list = []
+
+ for i in range(self.num_feature_levels):
+ size_list.append(x[i].shape[-2:])
+ pos.append(self.pe_layer(x[i], None).flatten(2))
+ src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+ pos[-1] = pos[-1].transpose(1, 2)
+ src[-1] = src[-1].transpose(1, 2)
+
+ bs, _, _ = src[0].shape
+
+ query_embed = self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+
+ tasks = tasks.unsqueeze(0)
+ if self.use_task_norm:
+ tasks = self.decoder_norm(tasks)
+
+ feats = self.pe_layer(mask_features, None)
+
+ out_t, _ = self.class_transformer(
+ feats, None,
+ self.query_embed.weight[:-1],
+ self.class_input_proj(mask_features),
+ tasks if self.use_task_norm else None)
+ out_t = out_t[0]
+
+ out = torch.cat([out_t, tasks], dim=1)
+
+ output = out.clone()
+
+ predictions_class = []
+ predictions_mask = []
+
+ # prediction heads on learnable query features
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
+ output, mask_features, attn_mask_target_size=size_list[0])
+ predictions_class.append(outputs_class)
+ predictions_mask.append(outputs_mask)
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ output = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None,
+ pos=pos[level_index], query_pos=query_embed, )
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=None,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed, )
+
+ # FFN
+ output = self.transformer_ffn_layers[i](output)
+
+ outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
+ output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
+ predictions_class.append(outputs_class)
+ predictions_mask.append(outputs_mask)
+
+ assert len(predictions_class) == self.num_layers + 1
+
+ out = {
+ 'pred_logits': predictions_class[-1],
+ 'pred_masks': predictions_mask[-1],}
+
+ return out
+
+ def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
+ decoder_output = self.decoder_norm(output)
+ outputs_class = self.class_embed(decoder_output)
+ mask_embed = self.mask_embed(decoder_output)
+ outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+ attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+ attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+ attn_mask = attn_mask.detach()
+
+ return outputs_class, outputs_mask, attn_mask
diff --git a/lib/model_zoo/swin.py b/lib/model_zoo/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6191009f528911b2b9cb518550ec9c48204bdb6
--- /dev/null
+++ b/lib/model_zoo/swin.py
@@ -0,0 +1,659 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from lib.model_zoo.common.get_model import register
+
+
+##############################
+# timm.models.layers helpers #
+##############################
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+def _ntuple(n):
+ def parse(x):
+ from itertools import repeat
+ import collections.abc
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+ return parse
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ import warnings
+ import math
+
+ def norm_cdf(x):
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+ tensor.erfinv_()
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ with torch.no_grad():
+ return _trunc_normal_(tensor, mean, std, a, b)
+
+#############
+# main swin #
+#############
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """ Forward function.
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ """ Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device, dtype=x.dtype) # 1 Hp Wp 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+@register('swin')
+class SwinTransformer(nn.Module):
+ """ Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute postion embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ use_checkpoint=False):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ outputs = {
+ 'res2' : outs[0],
+ 'res3' : outs[1],
+ 'res4' : outs[2],
+ 'res5' : outs[3],}
+ return outputs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+ return self
diff --git a/lib/sync.py b/lib/sync.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb82e9d1663424f5777f3087e3b766fcc0120454
--- /dev/null
+++ b/lib/sync.py
@@ -0,0 +1,254 @@
+from multiprocessing import shared_memory
+# import multiprocessing
+# if hasattr(multiprocessing, "shared_memory"):
+# from multiprocessing import shared_memory
+# else:
+# # workaround for single gpu inference on colab
+# shared_memory = None
+
+import random
+import pickle
+import time
+import copy
+import torch
+import torch.distributed as dist
+from lib.cfg_holder import cfg_unique_holder as cfguh
+
+def singleton(class_):
+ instances = {}
+ def getinstance(*args, **kwargs):
+ if class_ not in instances:
+ instances[class_] = class_(*args, **kwargs)
+ return instances[class_]
+ return getinstance
+
+def is_ddp():
+ return dist.is_available() and dist.is_initialized()
+
+def get_rank(type='local'):
+ ddp = is_ddp()
+ global_rank = dist.get_rank() if ddp else 0
+ local_world_size = torch.cuda.device_count()
+ if type == 'global':
+ return global_rank
+ elif type == 'local':
+ return global_rank % local_world_size
+ elif type == 'node':
+ return global_rank // local_world_size
+ elif type == 'all':
+ return global_rank, \
+ global_rank % local_world_size, \
+ global_rank // local_world_size
+ else:
+ assert False, 'Unknown type'
+
+def get_world_size(type='local'):
+ ddp = is_ddp()
+ global_rank = dist.get_rank() if ddp else 0
+ global_world_size = dist.get_world_size() if ddp else 1
+ local_world_size = torch.cuda.device_count()
+ if type == 'global':
+ return global_world_size
+ elif type == 'local':
+ return local_world_size
+ elif type == 'node':
+ return global_world_size // local_world_size
+ elif type == 'all':
+ return global_world_size, local_world_size, \
+ global_world_size // local_world_size
+ else:
+ assert False, 'Unknown type'
+
+class barrier_lock(object):
+ def __init__(self, n):
+ self.n = n
+ id = int(random.random()*10000) + int(time.time())*10000
+ self.lock_shmname = 'barrier_lock_{}'.format(id)
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname, create=True, size=n)
+ for i in range(n):
+ lock_shm.buf[i] = 0
+ lock_shm.close()
+
+ def destroy(self):
+ try:
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname)
+ lock_shm.close()
+ lock_shm.unlink()
+ except:
+ return
+
+ def wait(self, k):
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname)
+ assert lock_shm.buf[k] == 0, 'Two waits on the same id is not allowed.'
+ lock_shm.buf[k] = 1
+ if k == 0:
+ while sum([lock_shm.buf[i]==0 for i in range(self.n)]) != 0:
+ pass
+ for i in range(self.n):
+ lock_shm.buf[i] = 0
+ return
+ else:
+ while lock_shm.buf[k] != 0:
+ pass
+
+class default_lock(object):
+ def __init__(self):
+ id = int(random.random()*10000) + int(time.time())*10000
+ self.lock_shmname = 'lock_{}'.format(id)
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname, create=True, size=2)
+ for i in range(2):
+ lock_shm.buf[i] = 0
+ lock_shm.close()
+
+ def destroy(self):
+ try:
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname)
+ lock_shm.close()
+ lock_shm.unlink()
+ except:
+ return
+
+ def lock(self, k):
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname)
+ while lock_shm.buf[0] == 1:
+ pass
+ lock_shm.buf[0] = 1
+ lock_shm.buf[1] = k
+
+ def unlock(self, k):
+ lock_shm = shared_memory.SharedMemory(
+ name=self.lock_shmname)
+ if lock_shm.buf[1] != k:
+ return
+ lock_shm.buf[0] = 0
+ return
+
+class nodewise_sync_global(object):
+ """
+ This is the global part of nodewise_sync that need to call at master process
+ before spawn.
+ """
+ def __init__(self):
+ self.local_world_size = get_world_size('local')
+ self.reg_lock = default_lock()
+ self.b_lock = barrier_lock(self.local_world_size)
+ id = int(random.random()*10000) + int(time.time())*10000
+ self.id_shmname = 'nodewise_sync_id_shm_{}'.format(id)
+
+ def destroy(self):
+ self.reg_lock.destroy()
+ self.b_lock.destroy()
+ try:
+ shm = shared_memory.SharedMemory(name=self.id_shmname)
+ shm.close()
+ shm.unlink()
+ except:
+ return
+
+@singleton
+class nodewise_sync(object):
+ """
+ A class that centralize nodewise sync activities.
+ The backend is multiprocess sharememory, not torch, as torch not support this.
+ """
+ def __init__(self):
+ pass
+
+ def copy_global(self, reference):
+ self.local_world_size = reference.local_world_size
+ self.b_lock = reference.b_lock
+ self.reg_lock = reference.reg_lock
+ self.id_shmname = reference.id_shmname
+ return self
+
+ def local_init(self):
+ self.ddp = is_ddp()
+ self.global_rank, self.local_rank, self.node_rank = get_rank('all')
+ self.global_world_size, self.local_world_size, self.nodes = get_world_size('all')
+ if self.local_rank == 0:
+ temp = int(random.random()*10000) + int(time.time())*10000
+ temp = pickle.dumps(temp)
+ shm = shared_memory.SharedMemory(
+ name=self.id_shmname, create=True, size=len(temp))
+ shm.close()
+ return self
+
+ def random_sync_id(self):
+ assert self.local_rank is not None, 'Not initialized!'
+ if self.local_rank == 0:
+ sync_id = int(random.random()*10000) + int(time.time())*10000
+ data = pickle.dumps(sync_id)
+ shm = shared_memory.SharedMemory(name=self.id_shmname)
+ shm.buf[0:len(data)] = data[0:len(data)]
+ self.barrier()
+ shm.close()
+ else:
+ self.barrier()
+ shm = shared_memory.SharedMemory(name=self.id_shmname)
+ sync_id = pickle.loads(shm.buf)
+ shm.close()
+ return sync_id
+
+ def barrier(self):
+ self.b_lock.wait(self.local_rank)
+
+ def lock(self):
+ self.reg_lock.lock(self.local_rank)
+
+ def unlock(self):
+ self.reg_lock.unlock(self.local_rank)
+
+ def broadcast_r0(self, data=None):
+ assert self.local_rank is not None, 'Not initialized!'
+ id = self.random_sync_id()
+ shmname = 'broadcast_r0_{}'.format(id)
+ if self.local_rank == 0:
+ assert data!=None, 'Rank 0 needs to input data!'
+ data = pickle.dumps(data)
+ datan = len(data)
+ load_info_shm = shared_memory.SharedMemory(
+ name=shmname, create=True, size=datan)
+ load_info_shm.buf[0:datan] = data[0:datan]
+ self.barrier()
+ self.barrier()
+ load_info_shm.close()
+ load_info_shm.unlink()
+ return None
+ else:
+ assert data==None, 'Rank other than 1 should input None as data!'
+ self.barrier()
+ shm = shared_memory.SharedMemory(name=shmname)
+ data = pickle.loads(shm.buf)
+ shm.close()
+ self.barrier()
+ return data
+
+ def destroy(self):
+ self.barrier.destroy()
+ try:
+ shm = shared_memory.SharedMemory(name=self.id_shmname)
+ shm.close()
+ shm.unlink()
+ except:
+ return
+
+# import contextlib
+
+# @contextlib.contextmanager
+# def weight_sync(module, sync):
+# assert isinstance(module, torch.nn.Module)
+# if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+# yield
+# else:
+# with module.no_sync():
+# yield
+
+# def weight_sync(net):
+# for parameters in net.parameters():
+# dist.all_reduce(parameters, dist.ReduceOp.AVG)
\ No newline at end of file
diff --git a/lib/utils.py b/lib/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd37cb9a7e670aefd46d473ea5b5958a827ca443
--- /dev/null
+++ b/lib/utils.py
@@ -0,0 +1,651 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+# cudnn.enabled = True
+# cudnn.benchmark = True
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import os
+import os.path as osp
+import sys
+import numpy as np
+import random
+import pprint
+import timeit
+import time
+import copy
+import matplotlib.pyplot as plt
+
+from .cfg_holder import cfg_unique_holder as cfguh
+
+from .data_factory import \
+ get_dataset, collate, \
+ get_loader, \
+ get_transform, \
+ get_estimator, \
+ get_formatter, \
+ get_sampler
+
+from .model_zoo import \
+ get_model, get_optimizer, get_scheduler
+
+from .log_service import print_log, distributed_log_manager
+
+from .evaluator import get_evaluator
+from . import sync
+
+class train_stage(object):
+ """
+ This is a template for a train stage,
+ (can be either train or test or anything)
+ Usually, it takes RANK
+ one dataloader, one model, one optimizer, one scheduler.
+ But it is not limited to these parameters.
+ """
+ def __init__(self):
+ self.nested_eval_stage = None
+ self.rv_keep = None
+
+ def is_better(self, x):
+ return (self.rv_keep is None) or (x>self.rv_keep)
+
+ def set_model(self, net, mode):
+ if mode == 'train':
+ return net.train()
+ elif mode == 'eval':
+ return net.eval()
+ else:
+ raise ValueError
+
+ def __call__(self,
+ **paras):
+ cfg = cfguh().cfg
+ cfgt = cfg.train
+ logm = distributed_log_manager()
+ epochn, itern_local, itern, samplen = 0, 0, 0, 0
+
+ step_type = cfgt.get('step_type', 'iter')
+ assert step_type in ['epoch', 'iter', 'sample'], \
+ 'Step type must be in [epoch, iter, sample]'
+
+ step_num = cfgt.get('step_num' , None)
+ gradacc_every = cfgt.get('gradacc_every', 1 )
+ log_every = cfgt.get('log_every' , None)
+ ckpt_every = cfgt.get('ckpt_every' , None)
+ eval_start = cfgt.get('eval_start' , 0 )
+ eval_every = cfgt.get('eval_every' , None)
+
+ if paras.get('resume_step', None) is not None:
+ resume_step = paras['resume_step']
+ assert step_type == resume_step['type']
+ epochn = resume_step['epochn']
+ itern = resume_step['itern']
+ itern_local = itern * gradacc_every
+ samplen = resume_step['samplen']
+ del paras['resume_step']
+
+ trainloader = paras['trainloader']
+ if trainloader is None:
+ import itertools
+ trainloader = itertools.cycle([None])
+ optimizer = paras['optimizer']
+ scheduler = paras['scheduler']
+ net = paras['net']
+
+ GRANK, LRANK, NRANK = sync.get_rank('all')
+ GWSIZE, LWSIZE, NODES = sync.get_world_size('all')
+
+ weight_path = osp.join(cfgt.log_dir, 'weight')
+ if (GRANK==0) and (not osp.isdir(weight_path)):
+ os.makedirs(weight_path)
+ if (GRANK==0) and (cfgt.save_init_model):
+ self.save(net, is_init=True, step=0, optimizer=optimizer)
+
+ epoch_time = timeit.default_timer()
+ end_flag = False
+ net.train()
+
+ while True:
+ if step_type == 'epoch':
+ lr = scheduler[epochn] if scheduler is not None else None
+ for batch in trainloader:
+ # so first element of batch (usually image) can be [tensor]
+ if batch is None:
+ bs = cfgt.batch_size_per_gpu
+ elif not isinstance(batch[0], list):
+ bs = batch[0].shape[0]
+ else:
+ bs = len(batch[0])
+ if cfgt.skip_partial_batch and (bs != cfgt.batch_size_per_gpu):
+ continue
+
+ itern_local_next = itern_local + 1
+ samplen_next = samplen + bs*GWSIZE
+
+ if step_type == 'iter':
+ lr = scheduler[itern] if scheduler is not None else None
+ grad_update = itern_local%gradacc_every==(gradacc_every-1)
+ elif step_type == 'sample':
+ lr = scheduler[samplen] if scheduler is not None else None
+ # TODO:
+ # grad_update = samplen%gradacc_every==(gradacc_every-1)
+
+ itern_next = itern + 1 if grad_update else itern
+
+ # timeDebug = timeit.default_timer()
+ paras_new = self.main(
+ batch=batch,
+ lr=lr,
+ itern_local=itern_local,
+ itern=itern,
+ epochn=epochn,
+ samplen=samplen,
+ isinit=False,
+ grad_update=grad_update,
+ **paras)
+ # print_log(timeit.default_timer() - timeDebug)
+
+ paras.update(paras_new)
+ logm.accumulate(bs, **paras['log_info'])
+
+ #######
+ # log #
+ #######
+
+ display_flag = False
+ if log_every is not None:
+ display_i = (itern//log_every) != (itern_next//log_every)
+ display_s = (samplen//log_every) != (samplen_next//log_every)
+ display_flag = (display_i and (step_type=='iter')) \
+ or (display_s and (step_type=='sample'))
+
+ if display_flag:
+ tbstep = itern_next if step_type=='iter' else samplen_next
+ console_info = logm.train_summary(
+ itern_next, epochn, samplen_next, lr, tbstep=tbstep)
+ logm.clear()
+ print_log(console_info)
+
+ ########
+ # eval #
+ ########
+
+ eval_flag = False
+ if (self.nested_eval_stage is not None) and (eval_every is not None) and (NRANK == 0):
+ if step_type=='iter':
+ eval_flag = (itern//eval_every) != (itern_next//eval_every)
+ eval_flag = eval_flag and (itern_next>=eval_start)
+ eval_flag = eval_flag or itern_local==0
+ if step_type=='sample':
+ eval_flag = (samplen//eval_every) != (samplen_next//eval_every)
+ eval_flag = eval_flag and (samplen_next>=eval_start)
+ eval_flag = eval_flag or samplen==0
+
+ if eval_flag:
+ eval_cnt = itern_next if step_type=='iter' else samplen_next
+ net = self.set_model(net, 'eval')
+ rv = self.nested_eval_stage(
+ eval_cnt=eval_cnt, **paras)
+ rv = rv.get('eval_rv', None)
+ if rv is not None:
+ logm.tensorboard_log(eval_cnt, rv, mode='eval')
+ if self.is_better(rv):
+ self.rv_keep = rv
+ if GRANK==0:
+ step = {'epochn':epochn, 'itern':itern_next,
+ 'samplen':samplen_next, 'type':step_type, }
+ self.save(net, is_best=True, step=step, optimizer=optimizer)
+ net = self.set_model(net, 'train')
+
+ ########
+ # ckpt #
+ ########
+
+ ckpt_flag = False
+ if (GRANK==0) and (ckpt_every is not None):
+ # not distributed
+ ckpt_i = (itern//ckpt_every) != (itern_next//ckpt_every)
+ ckpt_s = (samplen//ckpt_every) != (samplen_next//ckpt_every)
+ ckpt_flag = (ckpt_i and (step_type=='iter')) \
+ or (ckpt_s and (step_type=='sample'))
+
+ if ckpt_flag:
+ if step_type == 'iter':
+ print_log('Checkpoint... {}'.format(itern_next))
+ step = {'epochn':epochn, 'itern':itern_next,
+ 'samplen':samplen_next, 'type':step_type, }
+ self.save(net, itern=itern_next, step=step, optimizer=optimizer)
+ else:
+ print_log('Checkpoint... {}'.format(samplen_next))
+ step = {'epochn':epochn, 'itern':itern_next,
+ 'samplen':samplen_next, 'type':step_type, }
+ self.save(net, samplen=samplen_next, step=step, optimizer=optimizer)
+
+ #######
+ # end #
+ #######
+
+ itern_local = itern_local_next
+ itern = itern_next
+ samplen = samplen_next
+
+ if step_type is not None:
+ end_flag = (itern>=step_num and (step_type=='iter')) \
+ or (samplen>=step_num and (step_type=='sample'))
+ if end_flag:
+ break
+ # loop end
+
+ epochn += 1
+ print_log('Epoch {} time:{:.2f}s.'.format(
+ epochn, timeit.default_timer()-epoch_time))
+ epoch_time = timeit.default_timer()
+
+ if end_flag:
+ break
+ elif step_type != 'epoch':
+ # This is temporarily added to resolve the data issue
+ trainloader = self.trick_update_trainloader(trainloader)
+ continue
+
+ #######
+ # log #
+ #######
+
+ display_flag = False
+ if (log_every is not None) and (step_type=='epoch'):
+ display_flag = (epochn==1) or (epochn%log_every==0)
+
+ if display_flag:
+ console_info = logm.train_summary(
+ itern, epochn, samplen, lr, tbstep=epochn)
+ logm.clear()
+ print_log(console_info)
+
+ ########
+ # eval #
+ ########
+
+ eval_flag = False
+ if (self.nested_eval_stage is not None) and (eval_every is not None) \
+ and (step_type=='epoch') and (NRANK==0):
+ eval_flag = (epochn%eval_every==0) and (itern_next>=eval_start)
+ eval_flag = (epochn==1) or eval_flag
+
+ if eval_flag:
+ net = self.set_model(net, 'eval')
+ rv = self.nested_eval_stage(
+ eval_cnt=epochn,
+ **paras)['eval_rv']
+ if rv is not None:
+ logm.tensorboard_log(epochn, rv, mode='eval')
+ if self.is_better(rv):
+ self.rv_keep = rv
+ if (GRANK==0):
+ step = {'epochn':epochn, 'itern':itern,
+ 'samplen':samplen, 'type':step_type, }
+ self.save(net, is_best=True, step=step, optimizer=optimizer)
+ net = self.set_model(net, 'train')
+
+ ########
+ # ckpt #
+ ########
+
+ ckpt_flag = False
+ if (ckpt_every is not None) and (GRANK==0) and (step_type=='epoch'):
+ # not distributed
+ ckpt_flag = epochn%ckpt_every==0
+
+ if ckpt_flag:
+ print_log('Checkpoint... {}'.format(itern_next))
+ step = {'epochn':epochn, 'itern':itern,
+ 'samplen':samplen, 'type':step_type, }
+ self.save(net, epochn=epochn, step=step, optimizer=optimizer)
+
+ #######
+ # end #
+ #######
+ if (step_type=='epoch') and (epochn>=step_num):
+ break
+ # loop end
+
+ # This is temporarily added to resolve the data issue
+ trainloader = self.trick_update_trainloader(trainloader)
+
+ logm.tensorboard_close()
+ return {}
+
+ def main(self, **paras):
+ raise NotImplementedError
+
+ def trick_update_trainloader(self, trainloader):
+ return trainloader
+
+ def save_model(self, net, path_noext, **paras):
+ cfgt = cfguh().cfg.train
+ path = path_noext+'.pth'
+ if isinstance(net, (torch.nn.DataParallel,
+ torch.nn.parallel.DistributedDataParallel)):
+ netm = net.module
+ else:
+ netm = net
+ torch.save(netm.state_dict(), path)
+ print_log('Saving model file {0}'.format(path))
+
+ def save(self, net, itern=None, epochn=None, samplen=None,
+ is_init=False, is_best=False, is_last=False, **paras):
+ exid = cfguh().cfg.env.experiment_id
+ cfgt = cfguh().cfg.train
+ cfgm = cfguh().cfg.model
+ if isinstance(net, (torch.nn.DataParallel,
+ torch.nn.parallel.DistributedDataParallel)):
+ netm = net.module
+ else:
+ netm = net
+ net_symbol = cfgm.symbol
+
+ check = sum([
+ itern is not None, samplen is not None, epochn is not None,
+ is_init, is_best, is_last])
+ assert check<2
+
+ if itern is not None:
+ path_noexp = '{}_{}_iter_{}'.format(exid, net_symbol, itern)
+ elif samplen is not None:
+ path_noexp = '{}_{}_samplen_{}'.format(exid, net_symbol, samplen)
+ elif epochn is not None:
+ path_noexp = '{}_{}_epoch_{}'.format(exid, net_symbol, epochn)
+ elif is_init:
+ path_noexp = '{}_{}_init'.format(exid, net_symbol)
+ elif is_best:
+ path_noexp = '{}_{}_best'.format(exid, net_symbol)
+ elif is_last:
+ path_noexp = '{}_{}_last'.format(exid, net_symbol)
+ else:
+ path_noexp = '{}_{}_default'.format(exid, net_symbol)
+
+ path_noexp = osp.join(cfgt.log_dir, 'weight', path_noexp)
+ self.save_model(net, path_noexp, **paras)
+
+class eval_stage(object):
+ def __init__(self):
+ self.evaluator = None
+
+ def create_dir(self, path):
+ grank = sync.get_rank('global')
+ if (not osp.isdir(path)) and (grank == 0):
+ os.makedirs(path)
+ sync.nodewise_sync().barrier()
+
+ def __call__(self,
+ evalloader,
+ net,
+ **paras):
+ cfgt = cfguh().cfg.eval
+ local_rank = sync.get_rank('local')
+ if self.evaluator is None:
+ evaluator = get_evaluator()(cfgt.evaluator)
+ self.evaluator = evaluator
+ else:
+ evaluator = self.evaluator
+
+ time_check = timeit.default_timer()
+
+ for idx, batch in enumerate(evalloader):
+ rv = self.main(batch, net)
+ evaluator.add_batch(**rv)
+ if cfgt.output_result:
+ try:
+ self.output_f(**rv, cnt=paras['eval_cnt'])
+ except:
+ self.output_f(**rv)
+ if idx%cfgt.log_display == cfgt.log_display-1:
+ print_log('processed.. {}, Time:{:.2f}s'.format(
+ idx+1, timeit.default_timer() - time_check))
+ time_check = timeit.default_timer()
+ # break
+
+ evaluator.set_sample_n(len(evalloader.dataset))
+ eval_rv = evaluator.compute()
+ if local_rank == 0:
+ evaluator.one_line_summary()
+ evaluator.save(cfgt.log_dir)
+ evaluator.clear_data()
+ return {
+ 'eval_rv' : eval_rv
+ }
+
+class exec_container(object):
+ """
+ This is the base functor for all types of executions.
+ One execution can have multiple stages,
+ but are only allowed to use the same
+ config, network, dataloader.
+ Thus, in most of the cases, one exec_container is one
+ training/evaluation/demo...
+ If DPP is in use, this functor should be spawn.
+ """
+ def __init__(self,
+ cfg,
+ **kwargs):
+ self.cfg = cfg
+ self.registered_stages = []
+ self.node_rank = None
+ self.local_rank = None
+ self.global_rank = None
+ self.local_world_size = None
+ self.global_world_size = None
+ self.nodewise_sync_global_obj = sync.nodewise_sync_global()
+
+ def register_stage(self, stage):
+ self.registered_stages.append(stage)
+
+ def __call__(self,
+ local_rank,
+ **kwargs):
+ cfg = self.cfg
+ cfguh().save_cfg(cfg)
+
+ self.node_rank = cfg.env.node_rank
+ self.local_rank = local_rank
+ self.nodes = cfg.env.nodes
+ self.local_world_size = cfg.env.gpu_count
+
+ self.global_rank = self.local_rank + self.node_rank * self.local_world_size
+ self.global_world_size = self.nodes * self.local_world_size
+
+ print('init {}/{}'.format(self.global_rank, self.global_world_size))
+ dist.init_process_group(
+ backend = cfg.env.dist_backend,
+ init_method = cfg.env.dist_url,
+ rank = self.global_rank,
+ world_size = self.global_world_size,)
+ torch.cuda.set_device(local_rank)
+ sync.nodewise_sync().copy_global(self.nodewise_sync_global_obj).local_init()
+
+ if isinstance(cfg.env.rnd_seed, int):
+ random.seed(cfg.env.rnd_seed + self.global_rank + 200)
+ np.random.seed(cfg.env.rnd_seed + self.global_rank + 100)
+ torch.manual_seed(cfg.env.rnd_seed + self.global_rank)
+
+ time_start = timeit.default_timer()
+
+ para = {'itern_total' : 0,}
+ dl_para = self.prepare_dataloader()
+ assert isinstance(dl_para, dict)
+ para.update(dl_para)
+
+ md_para = self.prepare_model()
+ assert isinstance(md_para, dict)
+ para.update(md_para)
+
+ for stage in self.registered_stages:
+ stage_para = stage(**para)
+ if stage_para is not None:
+ para.update(stage_para)
+
+ if self.global_rank==0:
+ self.save_last_model(**para)
+
+ print_log(
+ 'Total {:.2f} seconds'.format(timeit.default_timer() - time_start))
+ dist.destroy_process_group()
+
+ def prepare_dataloader(self):
+ """
+ Prepare the dataloader from config.
+ """
+ return {
+ 'trainloader' : None,
+ 'evalloader' : None}
+
+ def prepare_model(self):
+ """
+ Prepare the model from config.
+ """
+ return {'net' : None}
+
+ def save_last_model(self, **para):
+ return
+
+ def destroy(self):
+ self.nodewise_sync_global_obj.destroy()
+
+class train(exec_container):
+ def prepare_dataloader(self):
+ cfg = cfguh().cfg
+ trainset = get_dataset()(cfg.train.dataset)
+ trainloader = None
+ if trainset is not None:
+ sampler = get_sampler()(
+ dataset=trainset, cfg=cfg.train.dataset.get('sampler', 'default_train'))
+ trainloader = torch.utils.data.DataLoader(
+ trainset,
+ batch_size = cfg.train.batch_size_per_gpu,
+ sampler = sampler,
+ num_workers = cfg.train.dataset_num_workers_per_gpu,
+ drop_last = False,
+ pin_memory = cfg.train.dataset.get('pin_memory', False),
+ collate_fn = collate(),)
+
+ evalloader = None
+ if 'eval' in cfg:
+ evalset = get_dataset()(cfg.eval.dataset)
+ if evalset is not None:
+ sampler = get_sampler()(
+ dataset=evalset, cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
+ evalloader = torch.utils.data.DataLoader(
+ evalset,
+ batch_size = cfg.eval.batch_size_per_gpu,
+ sampler = sampler,
+ num_workers = cfg.eval.dataset_num_workers_per_gpu,
+ drop_last = False,
+ pin_memory = cfg.eval.dataset.get('pin_memory', False),
+ collate_fn = collate(),)
+
+ return {
+ 'trainloader' : trainloader,
+ 'evalloader' : evalloader,}
+
+ def prepare_model(self):
+ cfg = cfguh().cfg
+ net = get_model()(cfg.model)
+ find_unused_parameters=cfg.model.get('find_unused_parameters', False)
+ if cfg.env.cuda:
+ net.to(self.local_rank)
+ net = torch.nn.parallel.DistributedDataParallel(
+ net, device_ids=[self.local_rank],
+ find_unused_parameters=find_unused_parameters)
+ net.train()
+ scheduler = get_scheduler()(cfg.train.scheduler)
+ optimizer = get_optimizer()(net, cfg.train.optimizer)
+ return {
+ 'net' : net,
+ 'optimizer' : optimizer,
+ 'scheduler' : scheduler,}
+
+ def save_last_model(self, **para):
+ cfgt = cfguh().cfg.train
+ net = para['net']
+ net_symbol = cfguh().cfg.model.symbol
+ if isinstance(net, (torch.nn.DataParallel,
+ torch.nn.parallel.DistributedDataParallel)):
+ netm = net.module
+ else:
+ netm = net
+ path = osp.join(cfgt.log_dir, '{}_{}_last.pth'.format(
+ cfgt.experiment_id, net_symbol))
+ torch.save(netm.state_dict(), path)
+ print_log('Saving model file {0}'.format(path))
+
+class eval(exec_container):
+ def prepare_dataloader(self):
+ cfg = cfguh().cfg
+ evalloader = None
+ if cfg.eval.get('dataset', None) is not None:
+ evalset = get_dataset()(cfg.eval.dataset)
+ if evalset is None:
+ return
+ sampler = get_sampler()(
+ dataset=evalset, cfg=getattr(cfg.eval.dataset, 'sampler', 'default_eval'))
+ evalloader = torch.utils.data.DataLoader(
+ evalset,
+ batch_size = cfg.eval.batch_size_per_gpu,
+ sampler = sampler,
+ num_workers = cfg.eval.dataset_num_workers_per_gpu,
+ drop_last = False,
+ pin_memory = False,
+ collate_fn = collate(), )
+ return {
+ 'trainloader' : None,
+ 'evalloader' : evalloader,}
+
+ def prepare_model(self):
+ cfg = cfguh().cfg
+ net = get_model()(cfg.model)
+ if cfg.env.cuda:
+ net.to(self.local_rank)
+ net = torch.nn.parallel.DistributedDataParallel(
+ net, device_ids=[self.local_rank],
+ find_unused_parameters=True)
+ net.eval()
+ return {'net' : net,}
+
+ def save_last_model(self, **para):
+ return
+
+###############
+# some helper #
+###############
+
+def torch_to_numpy(*argv):
+ if len(argv) > 1:
+ data = list(argv)
+ else:
+ data = argv[0]
+
+ if isinstance(data, torch.Tensor):
+ return data.to('cpu').detach().numpy()
+ elif isinstance(data, (list, tuple)):
+ out = []
+ for di in data:
+ out.append(torch_to_numpy(di))
+ return out
+ elif isinstance(data, dict):
+ out = {}
+ for ni, di in data.items():
+ out[ni] = torch_to_numpy(di)
+ return out
+ else:
+ return data
+
+import importlib
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a73c60900489caceefd7dc69756c35ed5c37a3fc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+--extra-index-url https://download.pytorch.org/whl/cu117
+
+torch==2.0.0+cu117
+torchvision==0.15.1
+
+pyyaml==5.4.1
+easydict==1.9
+protobuf==3.20.3
+fsspec==2022.7.1
+
+tqdm==4.60.0
+transformers==4.24.0
+torchmetrics==0.7.3
+
+einops==0.3.0
+omegaconf==2.1.1
+huggingface-hub==0.11.1
+gradio==3.17.1
+
+safetensors==0.3.1
+
+opencv-python
+scikit-image