diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea4c52f916b518b093b04331f1620d758507a34d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,793 @@
+import gradio as gr
+import torch
+from omegaconf import OmegaConf
+from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
+
+import json
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+from functools import partial
+from collections import Counter
+import math
+import gc
+
+from gradio import processing_utils
+from typing import Optional
+
+import warnings
+
+from datetime import datetime
+
+from example_component import create_examples
+
+from huggingface_hub import hf_hub_download
+hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
+import cv2
+import sys
+sys.tracebacklimit = 0
+
+
+def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
+ return torch.load(cache_file, map_location='cpu')
+
+def load_ckpt_config_from_hf(modality):
+ ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
+ config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
+ return ckpt, config
+
+
+def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
+ pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
+ config = OmegaConf.create( config["_content"] ) # config used in training
+ config.alpha_scale = 1.0
+
+ if common_instances is None:
+        common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename='common.pth', subfolder='model')
+ common_instances = load_common_ckpt(config, common_ckpt)
+
+ loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
+
+ return loaded_model_list, common_instances
+
+
+class Instance:
+ def __init__(self, capacity = 2):
+ self.model_type = 'base'
+ self.loaded_model_list = {}
+ self.counter = Counter()
+ self.global_counter = Counter()
+ self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
+ 'gligen-generation-text-box',
+ is_inpaint=False, is_style=False, common_instances=None
+ )
+ self.capacity = capacity
+
+ def _log(self, model_type, batch_size, instruction, phrase_list):
+ self.counter[model_type] += 1
+ self.global_counter[model_type] += 1
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
+ current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
+ ))
+
+ def get_model(self, model_type, batch_size, instruction, phrase_list):
+ if model_type in self.loaded_model_list:
+ self._log(model_type, batch_size, instruction, phrase_list)
+ return self.loaded_model_list[model_type]
+
+ if self.capacity == len(self.loaded_model_list):
+ least_used_type = self.counter.most_common()[-1][0]
+ del self.loaded_model_list[least_used_type]
+ del self.counter[least_used_type]
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ self.loaded_model_list[model_type] = self._get_model(model_type)
+ self._log(model_type, batch_size, instruction, phrase_list)
+ return self.loaded_model_list[model_type]
+
+ def _get_model(self, model_type):
+ if model_type == 'base':
+ return ckpt_load_helper(
+ 'gligen-generation-text-box',
+ is_inpaint=False, is_style=False, common_instances=self.common_instances
+ )[0]
+ elif model_type == 'inpaint':
+ return ckpt_load_helper(
+ 'gligen-inpainting-text-box',
+ is_inpaint=True, is_style=False, common_instances=self.common_instances
+ )[0]
+ elif model_type == 'style':
+ return ckpt_load_helper(
+ 'gligen-generation-text-image-box',
+ is_inpaint=False, is_style=True, common_instances=self.common_instances
+ )[0]
+
+ assert False
+
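+# Instance is a small model cache: it keeps at most `capacity` GLIGEN variants
+# ('base', 'inpaint', 'style') loaded at once, sharing `common_instances` between them.
+# When the cache is full, the variant with the lowest request count since it was loaded
+# (tracked in self.counter) is evicted before the requested one is loaded.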
+instance = Instance()
+
+
+def load_clip_model():
+ from transformers import CLIPProcessor, CLIPModel
+ version = "openai/clip-vit-large-patch14"
+ model = CLIPModel.from_pretrained(version).cuda()
+ processor = CLIPProcessor.from_pretrained(version)
+
+ return {
+ 'version': version,
+ 'model': model,
+ 'processor': processor,
+ }
+
+clip_model = load_clip_model()
+
+
+class ImageMask(gr.components.Image):
+ """
+ Sets: source="canvas", tool="sketch"
+ """
+
+ is_template = True
+
+ def __init__(self, **kwargs):
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
+
+    def preprocess(self, x):
+        if x is None:
+            return x
+        if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
+            # A bare base64 image (no user-drawn mask yet): derive the mask from the image itself.
+            decode_image = processing_utils.decode_base64_to_image(x)
+            img = np.asarray(decode_image)
+            return {'image': img, 'mask': binarize_2(img)}
+
+        hh = super().preprocess(x)
+        if (hh['image'].min() != 255) and (hh['mask'][:, :, :3].max() == 0):
+            # The pad holds an image but no strokes: fall back to a mask derived from the image.
+            hh['mask'] = binarize_2(hh['image'])
+
+        return hh
+
+
+class Blocks(gr.Blocks):
+
+ def __init__(
+ self,
+ theme: str = "default",
+ analytics_enabled: Optional[bool] = None,
+ mode: str = "blocks",
+ title: str = "Gradio",
+ css: Optional[str] = None,
+ **kwargs,
+ ):
+
+ self.extra_configs = {
+ 'thumbnail': kwargs.pop('thumbnail', ''),
+ 'url': kwargs.pop('url', 'https://gradio.app/'),
+ 'creator': kwargs.pop('creator', '@teamGradio'),
+ }
+
+ super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
+ warnings.filterwarnings("ignore")
+
+ def get_config_file(self):
+ config = super(Blocks, self).get_config_file()
+
+ for k, v in self.extra_configs.items():
+ config[k] = v
+
+ return config
+
+'''
+inference model
+'''
+
+# @torch.no_grad()
+def inference(task, language_instruction, phrase_list, location_list, inpainting_boxes_nodrop, image,
+ alpha_sample, guidance_scale, batch_size,
+ fix_seed, rand_seed, actual_mask, style_image,
+ *args, **kwargs):
+ # import pdb; pdb.set_trace()
+
+ # grounding_instruction = json.loads(grounding_instruction)
+ # phrase_list, location_list = [], []
+ # for k, v in grounding_instruction.items():
+ # phrase_list.append(k)
+ # location_list.append(v)
+
+ placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
+ image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
+
+ batch_size = int(batch_size)
+ if not 1 <= batch_size <= 4:
+ batch_size = 1
+
+    if style_image is None:
+ has_text_mask = 1
+ has_image_mask = 0 # then we hack above 'image_list'
+ else:
+ valid_phrase_len = len(phrase_list)
+
+ phrase_list += ['placeholder']
+ has_text_mask = [1]*valid_phrase_len + [0]
+
+ image_list = [placeholder_image]*valid_phrase_len + [style_image]
+ has_image_mask = [0]*valid_phrase_len + [1]
+
+ location_list += [ [0.0, 0.0, 1, 0.01] ] # style image grounding location
+
+ instruction = dict(
+ prompt = language_instruction,
+ phrases = phrase_list,
+ images = image_list,
+ locations = location_list,
+ alpha_type = [alpha_sample, 0, 1.0 - alpha_sample],
+ has_text_mask = has_text_mask,
+ has_image_mask = has_image_mask,
+ save_folder_name = language_instruction,
+ guidance_scale = guidance_scale,
+ batch_size = batch_size,
+ fix_seed = bool(fix_seed),
+ rand_seed = int(rand_seed),
+ actual_mask = actual_mask,
+ inpainting_boxes_nodrop = inpainting_boxes_nodrop,
+ )
+
+ get_model = partial(instance.get_model,
+ batch_size=batch_size,
+ instruction=language_instruction,
+ phrase_list=phrase_list)
+
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
+        if task in ('User provide boxes', 'Available boxes'):
+            if style_image is None:
+ result = grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
+ torch.cuda.empty_cache()
+ return result
+ else:
+ return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
+
+
+def draw_box(boxes=[], texts=[], img=None):
+ if len(boxes) == 0 and img is None:
+ return None
+
+ if img is None:
+ img = Image.new('RGB', (512, 512), (255, 255, 255))
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+ draw = ImageDraw.Draw(img)
+ font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
+ for bid, box in enumerate(boxes):
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
+ anno_text = texts[bid]
+ draw.rectangle([box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]], outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
+ draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size*1.2)], anno_text, font=font, fill=(255,255,255))
+ return img
+
+def get_concat(ims):
+ if len(ims) == 1:
+ n_col = 1
+ else:
+ n_col = 2
+ n_row = math.ceil(len(ims) / 2)
+ dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
+ for i, im in enumerate(ims):
+ row_id = i // n_col
+ col_id = i % n_col
+ dst.paste(im, (im.width * col_id, im.height * row_id))
+ return dst
+
+
+def auto_append_grounding(language_instruction, grounding_texts):
+ for grounding_text in grounding_texts:
+ if grounding_text.lower() not in language_instruction.lower() and grounding_text != 'auto':
+ language_instruction += "; " + grounding_text
+ return language_instruction
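+# For example, auto_append_grounding("a photo of a park", ["a dog", "a bench"])
+# returns "a photo of a park; a dog; a bench"; phrases already present in the
+# caption (and the literal phrase 'auto') are not appended.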
+
+
+
+
+def generate(task, language_instruction, grounding_texts, sketch_pad,
+ alpha_sample, guidance_scale, batch_size,
+ fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
+ state):
+
+ if 'boxes' not in state:
+ state['boxes'] = []
+
+ boxes = state['boxes']
+ grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+ # assert len(boxes) == len(grounding_texts)
+ if len(boxes) != len(grounding_texts):
+ if len(boxes) < len(grounding_texts):
+ raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+Number of boxes drawn: {}, number of grounding tokens: {}.
+Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+ grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+ boxes = (np.asarray(boxes) / 512).tolist()
+ grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
+ image = None
+ actual_mask = None
+
+
+ if append_grounding:
+ language_instruction = auto_append_grounding(language_instruction, grounding_texts)
+
+ gen_images, gen_overlays = inference(
+ task, language_instruction, grounding_texts,boxes, boxes, image,
+ alpha_sample, guidance_scale, batch_size,
+ fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
+ )
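+    # The UI exposes four fixed output image slots; pad the generated batch with
+    # empty (visible) and hidden slots so the two-column gallery layout stays aligned.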
+ blank_samples = batch_size % 2 if batch_size > 1 else 0
+ gen_images = [gr.Image.update(value=x, visible=True) for i,x in enumerate(gen_images)] \
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
+
+ return gen_images + [state]
+
+
+def binarize(x):
+    # Treat any non-zero pixel as part of the mask.
+    return (x != 0).astype('uint8') * 255
+
+def binarize_2(x):
+    # Treat any non-white pixel (after grayscale conversion) as part of the mask.
+    gray_image = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
+    return (gray_image != 255).astype('uint8') * 255
+
+def sized_center_crop(img, cropx, cropy):
+ y, x = img.shape[:2]
+ startx = x // 2 - (cropx // 2)
+ starty = y // 2 - (cropy // 2)
+ return img[starty:starty+cropy, startx:startx+cropx]
+
+def sized_center_fill(img, fill, cropx, cropy):
+ y, x = img.shape[:2]
+ startx = x // 2 - (cropx // 2)
+ starty = y // 2 - (cropy // 2)
+ img[starty:starty+cropy, startx:startx+cropx] = fill
+ return img
+
+def sized_center_mask(img, cropx, cropy):
+ y, x = img.shape[:2]
+ startx = x // 2 - (cropx // 2)
+ starty = y // 2 - (cropy // 2)
+ center_region = img[starty:starty+cropy, startx:startx+cropx].copy()
+ img = (img * 0.2).astype('uint8')
+ img[starty:starty+cropy, startx:startx+cropx] = center_region
+ return img
+
+def center_crop(img, HW=None, tgt_size=(512, 512)):
+ if HW is None:
+ H, W = img.shape[:2]
+ HW = min(H, W)
+ img = sized_center_crop(img, HW, HW)
+ img = Image.fromarray(img)
+ img = img.resize(tgt_size)
+ return np.array(img)
+
+def draw(task, input, grounding_texts, new_image_trigger, state, generate_parsed, box_image):
+ print('input', generate_parsed)
+
+ if type(input) == dict:
+ image = input['image']
+ mask = input['mask']
+ if generate_parsed==1:
+ generate_parsed = 0
+ # import pdb; pdb.set_trace()
+ print('do nothing')
+
+ return [box_image, new_image_trigger, 1., state, generate_parsed]
+
+ else:
+ mask = input
+
+ if mask.ndim == 3:
+ mask = mask[..., 0]
+
+ image_scale = 1.0
+
+    print('entering draw --------------------')
+ mask = binarize(mask)
+ if mask.shape != (512, 512):
+ # assert False, "should not receive any non- 512x512 masks."
+ if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
+ mask = center_crop(mask, state['inpaint_hw'])
+ image = center_crop(state['original_image'], state['inpaint_hw'])
+ else:
+ mask = np.zeros((512, 512), dtype=np.uint8)
+ mask = binarize(mask)
+
+    if type(mask) != np.ndarray:
+        mask = np.array(mask)
+
+    if mask.sum() == 0:
+        state = {}
+        print('delete state')
+
+    # The parsed sketch pad is always redrawn on a blank canvas rather than over the input image.
+    image = None
+
+ if 'boxes' not in state:
+ state['boxes'] = []
+
+ if 'masks' not in state or len(state['masks']) == 0 :
+ state['masks'] = []
+ last_mask = np.zeros_like(mask)
+ else:
+ last_mask = state['masks'][-1]
+
+ if type(mask) == np.ndarray and mask.size > 1 :
+ diff_mask = mask - last_mask
+ else:
+ diff_mask = np.zeros([])
+
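+    # The newly drawn stroke is isolated by differencing the current mask against the
+    # previous one; its bounding box (if larger than 5x5 px) becomes a new grounding box.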
+ if diff_mask.sum() > 0:
+ x1x2 = np.where(diff_mask.max(0) > 1)[0]
+ y1y2 = np.where(diff_mask.max(1) > 1)[0]
+ y1, y2 = y1y2.min(), y1y2.max()
+ x1, x2 = x1x2.min(), x1x2.max()
+
+ if (x2 - x1 > 5) and (y2 - y1 > 5):
+ state['masks'].append(mask.copy())
+ state['boxes'].append((x1, y1, x2, y2))
+
+ grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+ grounding_texts = [x for x in grounding_texts if len(x) > 0]
+ if len(grounding_texts) < len(state['boxes']):
+ grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
+
+ box_image = draw_box(state['boxes'], grounding_texts, image)
+ generate_parsed = 0
+
+ return [box_image, new_image_trigger, image_scale, state, generate_parsed]
+
+def change_state(bboxes,layout, state, instruction, trigger_stage, boxes):
+    if trigger_stage == 0:
+        return [boxes, state, 0]
+ state['boxes'] = []
+ state['masks'] = []
+ image = None
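+    # `bboxes` is a "/"-separated string of "(x0, y0, x1, y1)" tuples in sketch-pad pixels,
+    # e.g. '(291, 88, 481, 301)/(25, 64, 260, 391)' from the examples below.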
+ list_boxes = bboxes.split('/')
+ result =[]
+ for b in list_boxes:
+ ints = b[1:-1].split(',')
+ l = []
+ for i in ints:
+ l.append(int(i))
+ result.append(l)
+ print('run change state')
+
+ for box in result:
+ state['boxes'].append(box)
+ grounding_texts = [x.strip() for x in instruction.split(';')]
+ grounding_texts = [x for x in grounding_texts if len(x) > 0]
+ if len(grounding_texts) < len(result):
+ grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(result))]
+
+ box_image = draw_box(result, grounding_texts)
+
+ mask = binarize_2(layout['image'])
+ state['masks'].append(mask.copy())
+ # print('done change state', state)
+ print('done change state')
+ # import pdb; pdb.set_trace()
+ return [box_image,state, trigger_stage]
+
+def example_click(name, grounding_instruction, instruction, bboxes,generate_parsed, trigger_parsed):
+
+ list_boxes = bboxes.split('/')
+ result =[]
+
+ for b in list_boxes:
+ ints = b[1:-1].split(',')
+ l = []
+ for i in ints:
+ l.append(int(i))
+ result.append(l)
+    print('parse example boxes')
+
+ box_image = draw_box(result, instruction)
+ trigger_parsed += 1
+ print('done the example click')
+ return [box_image, trigger_parsed]
+
+def clear(task, sketch_pad_trigger, batch_size, state,trigger_stage, switch_task=False):
+
+ sketch_pad_trigger = sketch_pad_trigger + 1
+ trigger_stage = 0
+ blank_samples = batch_size % 2 if batch_size > 1 else 0
+ out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
+ state = {}
+ return [None, sketch_pad_trigger, None, 1.0] + out_images + [state] + [trigger_stage]
+
+css = """
+#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
+{
+ height: var(--height) !important;
+ max-height: var(--height) !important;
+ min-height: var(--height) !important;
+}
+#paper-info a {
+ color:#008AD7;
+ text-decoration: none;
+}
+#paper-info a:hover {
+ cursor: pointer;
+ text-decoration: none;
+}
+#my_image > div.fixed-height
+{
+ height: var(--height) !important;
+}
+"""
+
+rescale_js = """
+function(x) {
+ const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
+ let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
+ const image_width = root.querySelector('#img2img_image').clientWidth;
+ const target_height = parseInt(image_width * image_scale);
+ document.body.style.setProperty('--height', `${target_height}px`);
+ root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
+ root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
+ return x;
+}
+"""
+with Blocks(
+ css=css,
+ analytics_enabled=False,
+ title="Attention-refocusing demo",
+) as main:
+ description = """
+ Grounded Text-to-Image Synthesis with Attention Refocusing
+
+
+ [Project Page]
+
+ [GitHub]
+
+
+
+    To specify where each concept should appear, you need to (1) ⌨️ enter the names of the concepts you're interested in under Grounding Instruction, and (2) 🖱️ draw their corresponding bounding boxes on the Sketch Pad -- the parsed boxes will automatically show up once you've drawn them.
+
+    For faster inference without waiting in the queue, you may duplicate the space and upgrade to a GPU in the settings.
+
+ """
+ gr.HTML(description)
+
+ with gr.Row():
+ with gr.Column(scale=4):
+ sketch_pad_trigger = gr.Number(value=0, visible=False)
+ sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
+ trigger_stage = gr.Number(value=0, visible=False)
+
+ init_white_trigger = gr.Number(value=0, visible=False)
+ image_scale = gr.Number(value=1.0, elem_id="image_scale", visible=False)
+ new_image_trigger = gr.Number(value=0, visible=False)
+ text_box = gr.Textbox(visible=False)
+ generate_parsed = gr.Number(value=0, visible=False)
+
+ task = gr.Radio(
+ choices=["Available boxes", 'User provide boxes'],
+ type="value",
+ value="User provide boxes",
+ label="Task",
+ visible=False
+
+ )
+ language_instruction = gr.Textbox(
+ label="Language instruction",
+ )
+ grounding_instruction = gr.Textbox(
+ label="Grounding instruction (Separated by semicolon)",
+ )
+ with gr.Row():
+ sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
+ out_imagebox = gr.Image(type="pil",elem_id="my_image" ,label="Parsed Sketch Pad", shape=(512,512))
+ with gr.Row():
+ clear_btn = gr.Button(value='Clear')
+ gen_btn = gr.Button(value='Generate')
+ with gr.Row():
+ parsed_btn = gr.Button(value='generate parsed boxes', visible=False)
+
+ with gr.Accordion("Advanced Options", open=False):
+ with gr.Column():
+ alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
+ guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
+ batch_size = gr.Slider(minimum=1, maximum=4,visible=False, step=1, value=1, label="Number of Samples")
+ append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
+ use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
+ with gr.Row():
+ fix_seed = gr.Checkbox(value=True, label="Fixed seed")
+ rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
+
+ with gr.Row():
+ use_style_cond = gr.Checkbox(value=False,visible=False, label="Enable Style Condition")
+ style_cond_image = gr.Image(type="pil",visible=False, label="Style Condition", interactive=True)
+ with gr.Column(scale=4):
+ gr.HTML('Generated Images')
+ with gr.Row():
+ out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
+ out_gen_2 = gr.Image(type="pil", visible=False, show_label=False)
+ with gr.Row():
+ out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
+ out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
+
+ state = gr.State({})
+
+
+ class Controller:
+ def __init__(self):
+ self.calls = 0
+ self.tracks = 0
+ self.resizes = 0
+ self.scales = 0
+
+ def init_white(self, init_white_trigger):
+ self.calls += 1
+ return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger+1
+
+ def change_n_samples(self, n_samples):
+ blank_samples = n_samples % 2 if n_samples > 1 else 0
+ return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
+ + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
+
+ controller = Controller()
+ main.load(
+ lambda x:x+1,
+ inputs=sketch_pad_trigger,
+ outputs=sketch_pad_trigger,
+ queue=False)
+
+ sketch_pad.edit(
+ draw,
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed, out_imagebox],
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed],
+ queue=False,
+ )
+ trigger_stage.change(
+ change_state,
+ inputs=[text_box,sketch_pad, state, grounding_instruction, trigger_stage,out_imagebox],
+ outputs=[out_imagebox,state,trigger_stage],
+ queue=True
+ )
+ grounding_instruction.change(
+ draw,
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state, generate_parsed,out_imagebox],
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state, generate_parsed],
+ queue=False,
+ )
+ clear_btn.click(
+ clear,
+        inputs=[task, sketch_pad_trigger, batch_size, state, trigger_stage],
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state, trigger_stage],
+ queue=False)
+
+ sketch_pad_trigger.change(
+ controller.init_white,
+ inputs=[init_white_trigger],
+ outputs=[sketch_pad, image_scale, init_white_trigger],
+ queue=False)
+
+ gen_btn.click(
+ generate,
+ inputs=[
+ task, language_instruction, grounding_instruction, sketch_pad,
+ alpha_sample, guidance_scale, batch_size,
+ fix_seed, rand_seed,
+ use_actual_mask,
+ append_grounding, style_cond_image,
+ state,
+ ],
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
+ queue=True
+ )
+ init_white_trigger.change(
+ None,
+ None,
+ init_white_trigger,
+ _js=rescale_js,
+ queue=False)
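+    # Each example row feeds create_examples(inputs=...) at the bottom of this file:
+    # [sketch-pad image, grounding instruction, language instruction,
+    #  "(x0, y0, x1, y1)/..." box string, generate_parsed flag, trigger_stage value].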
+ examples = [
+ [
+ 'guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg',
+ "a cat;a dog",
+ "a cat on the right of a dog",
+ '(291, 88, 481, 301)/(25, 64, 260, 391)',
+ 1, 1
+ ],
+ [
+            'guide_imgs/0_a_bus_on_the_left_of_a_car.jpg',
+ "a bus;a car",
+ "a bus and a car",
+            '(8,128,266,384)/(300,196,502,316)',
+ 1, 2
+ ],
+ [
+ 'guide_imgs/1_Two_cars_on_the_street..jpg',
+ "a car;a car",
+ "Two cars on the street.",
+ '(34, 98, 247, 264)/(271, 122, 481, 293)',
+ 1, 3
+ ],
+ [
+ 'guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg',
+ "an apple;an apple",
+ "two apples lay side by side on a wooden table, their glossy red and green skins glinting in the sunlight.",
+ '(40, 210, 235, 450)/(275, 210, 470, 450)',
+ 1, 4
+ ],
+ [
+ 'guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg',
+ "a banana;an apple",
+ "A banana on the left of an apple.",
+ '(62, 193, 225, 354)/(300, 184, 432, 329)',
+ 1, 5
+ ],
+ [
+ 'guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg',
+ "a pizza ;a suitcase",
+ "A pizza on the right of a suitcase.",
+ '(307, 112, 490, 280)/(41, 120, 244, 270)',
+ 1, 6
+ ],
+ [
+ 'guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg',
+ "a wine glass;a dog",
+ "A wine glass on top of a dog.",
+ '(206, 78, 306, 214)/(137, 222, 367, 432)',
+ 1, 7
+        ],
+ [
+ 'guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg',
+ "a bicycle;a boat",
+ "A bicycle on top of a boat.",
+ '(185, 110, 335, 205)/(111, 228, 401, 373)',
+ 1, 8
+        ],
+ [
+ 'guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg',
+ "a laptop;a teddy bear",
+ "A laptop on top of a teddy bear.",
+ '(180, 70, 332, 210)/(150, 240, 362, 420)',
+ 1, 9
+        ],
+ [
+ 'guide_imgs/0_A_train_on_top_of_a_surfboard..jpg',
+ "a train;a surfboard",
+ "A train on top of a surfboard.",
+ '(130, 80, 385, 240)/(75, 260, 440, 450)',
+ 1, 10
+ ]
+ ]
+
+ with gr.Column():
+
+ create_examples(
+ examples=examples,
+ inputs=[sketch_pad, grounding_instruction,language_instruction , text_box, generate_parsed, trigger_stage],
+ outputs=None,
+ fn=None,
+ cache_examples=False,
+
+ )
+
+main.queue(concurrency_count=1, api_open=False)
+main.launch(share=False, show_api=False, show_error=True, debug=False, server_name="0.0.0.0")
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dataset/__pycache__/__init__.cpython-38.pyc b/dataset/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05b0ef16d074d28a35ecbe025b6821120d8605da
Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-38.pyc differ
diff --git a/dataset/__pycache__/catalog.cpython-38.pyc b/dataset/__pycache__/catalog.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc5b0f96eb8b80206ac3e28b4e5237e9d1cf29a3
Binary files /dev/null and b/dataset/__pycache__/catalog.cpython-38.pyc differ
diff --git a/dataset/__pycache__/concat_dataset.cpython-38.pyc b/dataset/__pycache__/concat_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0d24742b9b477ac163ee9f1bd4a17b85f370051
Binary files /dev/null and b/dataset/__pycache__/concat_dataset.cpython-38.pyc differ
diff --git a/dataset/base_dataset.py b/dataset/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3005bfc7cbef54b20006ca88ee01783cec9425c3
--- /dev/null
+++ b/dataset/base_dataset.py
@@ -0,0 +1,220 @@
+import torch
+from PIL import Image, ImageDraw
+import torchvision.transforms as transforms
+import torchvision
+from zipfile import ZipFile
+import os
+import multiprocessing
+import math
+import numpy as np
+import random
+from io import BytesIO
+
+VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
+
+
+def check_filenames_in_zipdata(filenames, ziproot):
+ samples = []
+ for fst in ZipFile(ziproot).infolist():
+ fname = fst.filename
+ if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
+ continue
+ if os.path.splitext(fname)[1].lower() in VALID_IMAGE_TYPES:
+ samples.append((fname))
+ filenames = set(filenames)
+ samples = set(samples)
+ assert filenames.issubset(samples), 'Something wrong with your zip data'
+
+
+
+def draw_box(img, boxes):
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+ draw = ImageDraw.Draw(img)
+ for bid, box in enumerate(boxes):
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline =colors[bid % len(colors)], width=4)
+ # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
+ return img
+
+
+
+def to_valid(x0, y0, x1, y1, image_size, min_box_size):
+ valid = True
+
+    if x0 > image_size or y0 > image_size or x1 < 0 or y1 < 0:
+        valid = False # no way to make this box valid; it is completely cropped out
+ return valid, (None, None, None, None)
+
+ x0 = max(x0, 0)
+ y0 = max(y0, 0)
+ x1 = min(x1, image_size)
+ y1 = min(y1, image_size)
+
+ if (x1-x0)*(y1-y0) / (image_size*image_size) < min_box_size:
+ valid = False
+ return valid, (None, None, None, None)
+
+ return valid, (x0, y0, x1, y1)
+
+
+
+
+
+def recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, image_size, min_box_size):
+ """
+ x,y,w,h: the original annotation corresponding to the raw image size.
+ trans_info: what resizing and cropping have been applied to the raw image
+ image_size: what is the final image size
+ """
+
+ x0 = x * trans_info["performed_scale"] - trans_info['crop_x']
+ y0 = y * trans_info["performed_scale"] - trans_info['crop_y']
+ x1 = (x + w) * trans_info["performed_scale"] - trans_info['crop_x']
+ y1 = (y + h) * trans_info["performed_scale"] - trans_info['crop_y']
+
+
+    # At this point the box annotation has been recalculated for the scaling and cropping,
+    # but some coordinates may fall outside the image_size region (e.g., become negative),
+    # so we clamp them into [0, image_size]. If the whole box falls outside the image
+    # region, we consider it an invalid box.
+ valid, (x0, y0, x1, y1) = to_valid(x0, y0, x1, y1, image_size, min_box_size)
+
+ if valid:
+ # we also perform random flip.
+ # Here boxes are valid, and are based on image_size
+ if trans_info["performed_flip"]:
+ x0, x1 = image_size-x1, image_size-x0
+
+ return valid, (x0, y0, x1, y1)
+
+
+
+class BaseDataset(torch.utils.data.Dataset):
+ def __init__(self, image_root, random_crop, random_flip, image_size):
+ super().__init__()
+ self.image_root = image_root
+ self.random_crop = random_crop
+ self.random_flip = random_flip
+ self.image_size = image_size
+ self.use_zip = False
+
+        if image_root.endswith('zip'):
+ self.use_zip = True
+ self.zip_dict = {}
+
+ if self.random_crop:
+ assert False, 'NOT IMPLEMENTED'
+
+
+ def fetch_zipfile(self, ziproot):
+ pid = multiprocessing.current_process().pid # get pid of this process.
+ if pid not in self.zip_dict:
+ self.zip_dict[pid] = ZipFile(ziproot)
+ zip_file = self.zip_dict[pid]
+ return zip_file
+
+ def fetch_image(self, filename):
+ if self.use_zip:
+ zip_file = self.fetch_zipfile(self.image_root)
+ image = Image.open( BytesIO(zip_file.read(filename)) ).convert('RGB')
+ return image
+ else:
+ image = Image.open( os.path.join(self.image_root,filename) ).convert('RGB')
+ return image
+
+
+ def vis_getitem_data(self, index=None, out=None, return_tensor=False, name="res.jpg", print_caption=True):
+
+ if out is None:
+ out = self[index]
+
+ img = torchvision.transforms.functional.to_pil_image( out["image"]*0.5+0.5 )
+ canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(out["image"]) )
+ W, H = img.size
+
+ if print_caption:
+ caption = out["caption"]
+ print(caption)
+ print(" ")
+
+ boxes = []
+ for box in out["boxes"]:
+ x0,y0,x1,y1 = box
+ boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] )
+ img = draw_box(img, boxes)
+
+ if return_tensor:
+ return torchvision.transforms.functional.to_tensor(img)
+ else:
+ img.save(name)
+
+
+ def transform_image(self, pil_image):
+ if self.random_crop:
+ assert False
+ arr = random_crop_arr(pil_image, self.image_size)
+ else:
+ arr, info = center_crop_arr(pil_image, self.image_size)
+
+ info["performed_flip"] = False
+ if self.random_flip and random.random()<0.5:
+ arr = arr[:, ::-1]
+ info["performed_flip"] = True
+
+ arr = arr.astype(np.float32) / 127.5 - 1
+ arr = np.transpose(arr, [2,0,1])
+
+ return torch.tensor(arr), info
+
+
+
+def center_crop_arr(pil_image, image_size):
+ # We are not on a new enough PIL to support the `reducing_gap`
+ # argument, which uses BOX downsampling at powers of two first.
+ # Thus, we do it by hand to improve downsample quality.
+ WW, HH = pil_image.size
+
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ # at this point, the min of pil_image side is desired image_size
+ performed_scale = image_size / min(WW, HH)
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+
+ info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH}
+
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info
+
+
+def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
+
+ # We are not on a new enough PIL to support the `reducing_gap`
+ # argument, which uses BOX downsampling at powers of two first.
+ # Thus, we do it by hand to improve downsample quality.
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = smaller_dim_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
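+
+
+if __name__ == "__main__":
+    # Minimal sanity-check sketch with made-up numbers (not part of the training pipeline):
+    # it walks one COCO-style (x, y, w, h) annotation through the same
+    # scale -> crop -> clamp -> flip remapping implemented above.
+    dummy = Image.new("RGB", (1024, 768))
+    arr, info = center_crop_arr(dummy, 256)
+    print(arr.shape, info)  # (256, 256, 3) and the recorded scale/crop offsets
+
+    trans_info = {"performed_scale": 0.5, "crop_x": 10, "crop_y": 0, "performed_flip": True}
+    valid, box = recalculate_box_and_verify_if_valid(
+        100, 40, 60, 80, trans_info, image_size=256, min_box_size=0.01
+    )
+    # scale+crop maps (100, 40, w=60, h=80) to (40, 20, 70, 60); the horizontal flip then
+    # mirrors x into (186, 20, 216, 60); it is kept because (30*40)/(256*256) > 0.01.
+    assert valid and box == (186.0, 20.0, 216.0, 60.0)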
diff --git a/dataset/catalog.py b/dataset/catalog.py
new file mode 100644
index 0000000000000000000000000000000000000000..b622e477dae7cb4ba5c599fa7d2f7220b4311885
--- /dev/null
+++ b/dataset/catalog.py
@@ -0,0 +1,72 @@
+import os
+
+class DatasetCatalog:
+ def __init__(self, ROOT, which_embedder):
+ assert which_embedder in ['clip', 'bert']
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+ self.VGGrounding = {
+ "target": "dataset.tsv_dataset.TSVDataset",
+ "train_params": dict(
+ tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'),
+ )
+ }
+
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+ self.FlickrGrounding = {
+ "target": "dataset.tsv_dataset.TSVDataset",
+ "train_params":dict(
+ tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'),
+ )
+ }
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+ self.SBUGrounding = {
+ "target": "dataset.tsv_dataset.TSVDataset",
+ "train_params":dict(
+ tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'),
+ )
+ }
+
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+ self.CC3MGrounding = {
+ "target": "dataset.tsv_dataset.TSVDataset",
+ "train_params":dict(
+ tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'),
+ )
+ }
+
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+ self.CC12MGrounding = {
+ "target": "dataset.tsv_dataset.TSVDataset",
+ "train_params":dict(
+ tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'),
+ )
+ }
+
+
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+ # temp = 'category_embedding_clip.pth' if which_embedder == 'clip' else 'category_embedding_bert.pth'
+ # obj365_category_embedding_path = os.path.join(ROOT, 'OBJECTS365', temp)
+
+ self.Obj365Detection = {
+ "target": "dataset.tsv_dataset.TSVDataset",
+ "train_params":dict(
+ tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'),
+ ),
+ }
+
+
diff --git a/dataset/cd_dataset.py b/dataset/cd_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0627329bda44a15c6821fc477bbde45acfe86a2f
--- /dev/null
+++ b/dataset/cd_dataset.py
@@ -0,0 +1,250 @@
+import json, os, random, math
+from collections import defaultdict
+from copy import deepcopy
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+
+import numpy as np
+from PIL import Image
+from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
+from io import BytesIO
+
+
+
+def not_in_at_all(list1, list2):
+ for a in list1:
+ if a in list2:
+ return False
+ return True
+
+
+def clean_annotations(annotations):
+ for anno in annotations:
+ anno.pop("segmentation", None)
+ anno.pop("area", None)
+ anno.pop("iscrowd", None)
+ # anno.pop("id", None)
+
+
+def make_a_sentence(obj_names, clean=False):
+
+ if clean:
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
+
+ caption = ""
+ tokens_positive = []
+ for obj_name in obj_names:
+ start_len = len(caption)
+ caption += obj_name
+ end_len = len(caption)
+ caption += ", "
+ tokens_positive.append(
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
+ )
+ caption = caption[:-2] # remove last ", "
+
+ return caption #, tokens_positive
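+# For example, make_a_sentence(["cat", "grass-other"], clean=True) returns "cat, grass",
+# with tokens_positive (currently not returned) recording the character spans
+# [[[0, 3]], [[5, 10]]] of each phrase inside the caption.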
+
+
+def check_all_have_same_images(instances_data, stuff_data, caption_data):
+ if stuff_data is not None:
+ assert instances_data["images"] == stuff_data["images"]
+ if caption_data is not None:
+ assert instances_data["images"] == caption_data["images"]
+
+
+class CDDataset(BaseDataset):
+ "CD: Caption Detection"
+ def __init__(self,
+ image_root,
+ category_embedding_path,
+ instances_json_path = None,
+ stuff_json_path = None,
+ caption_json_path = None,
+ prob_real_caption = 0,
+ fake_caption_type = 'empty',
+ image_size=256,
+ max_images=None,
+ min_box_size=0.01,
+ max_boxes_per_image=8,
+ include_other=False,
+ random_crop = False,
+ random_flip = True,
+ ):
+        super().__init__(image_root, random_crop, random_flip, image_size)
+
+ self.image_root = image_root
+ self.category_embedding_path = category_embedding_path
+ self.instances_json_path = instances_json_path
+ self.stuff_json_path = stuff_json_path
+ self.caption_json_path = caption_json_path
+ self.prob_real_caption = prob_real_caption
+ self.fake_caption_type = fake_caption_type
+ self.max_images = max_images
+ self.min_box_size = min_box_size
+ self.max_boxes_per_image = max_boxes_per_image
+ self.include_other = include_other
+
+
+ assert fake_caption_type in ["empty", "made"]
+ if prob_real_caption > 0:
+ assert caption_json_path is not None, "caption json must be given"
+
+
+ # Load all jsons
+ with open(instances_json_path, 'r') as f:
+ instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
+ clean_annotations(instances_data["annotations"])
+ self.instances_data = instances_data
+
+ self.stuff_data = None
+ if stuff_json_path is not None:
+ with open(stuff_json_path, 'r') as f:
+ stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
+ clean_annotations(stuff_data["annotations"])
+ self.stuff_data = stuff_data
+
+ self.captions_data = None
+ if caption_json_path is not None:
+ with open(caption_json_path, 'r') as f:
+ captions_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
+ clean_annotations(captions_data["annotations"])
+ self.captions_data = captions_data
+
+
+ # Load preprocessed name embedding
+ self.category_embeddings = torch.load(category_embedding_path)
+ self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
+
+
+ # Misc
+ self.image_ids = [] # main list for selecting images
+ self.image_id_to_filename = {} # file names used to read image
+ check_all_have_same_images(self.instances_data, self.stuff_data, self.captions_data)
+ for image_data in self.instances_data['images']:
+ image_id = image_data['id']
+ filename = image_data['file_name']
+ self.image_ids.append(image_id)
+ self.image_id_to_filename[image_id] = filename
+
+
+ # All category names (including things and stuff)
+ self.object_idx_to_name = {}
+ for category_data in self.instances_data['categories']:
+ self.object_idx_to_name[category_data['id']] = category_data['name']
+ if self.stuff_data is not None:
+ for category_data in self.stuff_data['categories']:
+ self.object_idx_to_name[category_data['id']] = category_data['name']
+
+
+ # Add object data from instances and stuff
+ self.image_id_to_objects = defaultdict(list)
+ self.select_objects( self.instances_data['annotations'] )
+ if self.stuff_data is not None:
+ self.select_objects( self.stuff_data['annotations'] )
+
+ # Add caption data
+ if self.captions_data is not None:
+ self.image_id_to_captions = defaultdict(list)
+ self.select_captions( self.captions_data['annotations'] )
+
+ # Check if all filenames can be found in the zip file
+ # all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
+ # check_filenames_in_zipdata(all_filenames, image_root)
+
+
+ def select_objects(self, annotations):
+ for object_anno in annotations:
+ image_id = object_anno['image_id']
+ object_name = self.object_idx_to_name[object_anno['category_id']]
+ other_ok = object_name != 'other' or self.include_other
+ if other_ok:
+ self.image_id_to_objects[image_id].append(object_anno)
+
+
+ def select_captions(self, annotations):
+ for caption_data in annotations:
+ image_id = caption_data['image_id']
+ self.image_id_to_captions[image_id].append(caption_data)
+
+
+ def total_images(self):
+ return len(self)
+
+
+ def __getitem__(self, index):
+ if self.max_boxes_per_image > 99:
+ assert False, "Are you sure setting such large number of boxes?"
+
+ out = {}
+
+ image_id = self.image_ids[index]
+ out['id'] = image_id
+
+ # Image
+ filename = self.image_id_to_filename[image_id]
+ image = self.fetch_image(filename)
+ #WW, HH = image.size
+ image_tensor, trans_info = self.transform_image(image)
+ out["image"] = image_tensor
+
+
+ # Select valid boxes after cropping (center or random)
+ this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])
+ areas = []
+ all_obj_names = []
+ all_boxes = []
+ all_masks = []
+ all_positive_embeddings = []
+ for object_anno in this_image_obj_annos:
+
+ x, y, w, h = object_anno['bbox']
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
+
+ if valid:
+ areas.append( (x1-x0)*(y1-y0) )
+ obj_name = self.object_idx_to_name[ object_anno['category_id'] ]
+ all_obj_names.append(obj_name)
+ all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
+ all_masks.append(1)
+ all_positive_embeddings.append( self.category_embeddings[obj_name] )
+
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
+ wanted_idxs = wanted_idxs[0:self.max_boxes_per_image]
+ obj_names = [] # used for making a sentence
+ boxes = torch.zeros(self.max_boxes_per_image, 4)
+ masks = torch.zeros(self.max_boxes_per_image)
+ positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
+ for i, idx in enumerate(wanted_idxs):
+ obj_names.append( all_obj_names[idx] )
+ boxes[i] = all_boxes[idx]
+ masks[i] = all_masks[idx]
+ positive_embeddings[i] = all_positive_embeddings[idx]
+
+ # Caption
+ if random.uniform(0, 1) < self.prob_real_caption:
+ caption_data = self.image_id_to_captions[image_id]
+ idx = random.randint(0, len(caption_data)-1 )
+ caption = caption_data[idx]["caption"]
+ else:
+ if self.fake_caption_type == "empty":
+ caption = ""
+ else:
+ caption = make_a_sentence(obj_names, clean=True)
+
+
+ out["caption"] = caption
+ out["boxes"] = boxes
+ out["masks"] = masks
+ out["positive_embeddings"] = positive_embeddings
+
+ return out
+
+
+ def __len__(self):
+ if self.max_images is None:
+ return len(self.image_ids)
+ return min(len(self.image_ids), self.max_images)
+
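+
+# A hypothetical usage sketch (the paths and embedding file below are placeholders,
+# not shipped with this repo), kept as a comment to document the expected inputs/outputs:
+#
+#   dataset = CDDataset(
+#       image_root='COCO/train2017.zip',
+#       category_embedding_path='COCO/category_embedding_clip.pth',
+#       instances_json_path='COCO/annotations/instances_train2017.json',
+#       caption_json_path='COCO/annotations/captions_train2017.json',
+#       prob_real_caption=0.5, image_size=256,
+#   )
+#   item = dataset[0]  # dict with 'id', 'image', 'caption', 'boxes', 'masks', 'positive_embeddings'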
diff --git a/dataset/concat_dataset.py b/dataset/concat_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..df637663567a8c74673de9361950a6d663357fa0
--- /dev/null
+++ b/dataset/concat_dataset.py
@@ -0,0 +1,65 @@
+from .catalog import DatasetCatalog
+from ldm.util import instantiate_from_config
+import torch
+
+
+
+
+class ConCatDataset():
+ def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None):
+ self.datasets = []
+ cul_previous_dataset_length = 0
+ offset_map = []
+ which_dataset = []
+
+ if repeats is None:
+ repeats = [1] * len(dataset_name_list)
+ else:
+ assert len(repeats) == len(dataset_name_list)
+
+
+ Catalog = DatasetCatalog(ROOT, which_embedder)
+ for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()):
+ repeat = repeats[dataset_idx]
+
+ dataset_dict = getattr(Catalog, dataset_name)
+
+ target = dataset_dict['target']
+ params = dataset_dict['train_params'] if train else dataset_dict['val_params']
+ if yaml_params is not None:
+ params.update(yaml_params)
+ dataset = instantiate_from_config( dict(target=target, params=params) )
+
+ self.datasets.append(dataset)
+ for _ in range(repeat):
+ offset_map.append( torch.ones(len(dataset))*cul_previous_dataset_length )
+ which_dataset.append( torch.ones(len(dataset))*dataset_idx )
+ cul_previous_dataset_length += len(dataset)
+ offset_map = torch.cat(offset_map, dim=0).long()
+ self.total_length = cul_previous_dataset_length
+
+ self.mapping = torch.arange(self.total_length) - offset_map
+ self.which_dataset = torch.cat(which_dataset, dim=0).long()
+
+
+ def total_images(self):
+ count = 0
+ for dataset in self.datasets:
+ print(dataset.total_images())
+ count += dataset.total_images()
+ return count
+
+
+
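+    # Indexing sketch: self.which_dataset[i] selects the source dataset and
+    # self.mapping[i] the index inside it. E.g. with datasets of length 3 and 5 and
+    # repeats=[2, 1], total_length is 11 and global index 4 resolves to
+    # (dataset 0, local index 1) because the offsets run [0]*3 + [3]*3 + [6]*5.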
+ def __getitem__(self, idx):
+ dataset = self.datasets[ self.which_dataset[idx] ]
+ return dataset[ self.mapping[idx] ]
+
+
+ def __len__(self):
+ return self.total_length
+
+
+
+
+
diff --git a/dataset/grounding_dataset.py b/dataset/grounding_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b1fa74fc948466bd3d1a522025413ee5224577a
--- /dev/null
+++ b/dataset/grounding_dataset.py
@@ -0,0 +1,205 @@
+import torch
+import json
+from collections import defaultdict
+from PIL import Image, ImageDraw
+from copy import deepcopy
+import os
+import torchvision.transforms as transforms
+import torchvision
+from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
+from io import BytesIO
+import random
+
+def check_unique(images, fields):
+ for field in fields:
+ temp_list = []
+ for img_info in images:
+ temp_list.append(img_info[field])
+ assert len(set(temp_list)) == len(temp_list), field
+
+def clean_data(data):
+ for data_info in data:
+ data_info.pop("original_img_id", None)
+ data_info.pop("original_id", None)
+ data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image)
+ data_info.pop("dataset_name", None)
+ data_info.pop("data_source", None)
+ data_info["data_id"] = data_info.pop("id")
+
+
+def clean_annotations(annotations):
+ for anno_info in annotations:
+ anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco
+ anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. This is not always 1 for coco, but I do not think we need this annotation
+ anno_info.pop("area", None)
+ # anno_info.pop("id", None)
+ anno_info["data_id"] = anno_info.pop("image_id")
+
+
+def draw_box(img, boxes):
+ draw = ImageDraw.Draw(img)
+ for box in boxes:
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
+ return img
+
+
+def xyhw2xyxy(box):
+ x0, y0, w, h = box
+ return [ x0, y0, x0+w, y0+h ]
+
+
+
+class GroundingDataset(BaseDataset):
+ def __init__(self,
+ image_root,
+ json_path,
+ annotation_embedding_path,
+ prob_real_caption=1,
+ image_size=256,
+ min_box_size=0.01,
+ max_boxes_per_data=8,
+ max_images=None, # set as 30K used to eval
+ random_crop = False,
+ random_flip = True,
+ ):
+ super().__init__(image_root, random_crop, random_flip, image_size)
+ self.image_root = image_root
+ self.json_path = json_path
+ self.annotation_embedding_path = annotation_embedding_path
+ self.prob_real_caption = prob_real_caption
+ self.min_box_size = min_box_size
+ self.max_boxes_per_data = max_boxes_per_data
+ self.max_images = max_images
+
+
+ # Load raw data
+ with open(json_path, 'r') as f:
+ json_raw = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
+ self.data = json_raw["images"] # donot name it images, which is misleading
+ self.annotations = json_raw["annotations"]
+
+
+ # Load preprocessed name embedding
+ if 'bert' in annotation_embedding_path:
+ self.embedding_len = 1280
+ elif 'clip' in annotation_embedding_path:
+ self.embedding_len = 768
+ else:
+ assert False
+
+
+ # clean data and annotation
+ check_unique( self.data, ['id'] )
+ check_unique( self.annotations, ['id'] )
+ clean_data(self.data)
+ clean_annotations(self.annotations)
+ self.data_id_list = [ datum['data_id'] for datum in self.data ]
+ self.data = { datum['data_id']:datum for datum in self.data } # map self.data from a list into a dict
+
+
+ # data point to its annotation mapping
+ self.data_id_to_annos = defaultdict(list)
+ for anno in self.annotations:
+ self.data_id_to_annos[ anno["data_id"] ].append(anno)
+
+
+
+        # These are not used that often, but are useful in some cases
+ self.file_names = [] # all training images
+ self.file_name_to_data_ids = defaultdict(list) # for each image, there are multiple data points (captions)
+ for data_id in self.data_id_list:
+            file_name = self.data[data_id]["file_name"]
+            self.file_names.append(file_name)
+            self.file_name_to_data_ids[file_name].append(data_id)
+ self.file_names = list(set(self.file_names))
+
+
+ if self.max_images is not None:
+ "This is only used as COCO2017P evulation, when we set max_images as 30k"
+ assert False, 'I have commented out the following code to save cpu memory'
+ # new_data_id_list = []
+ # new_file_name_to_data_ids = defaultdict(list)
+ # self.file_names = self.file_names[0:self.max_images]
+ # for file_name in self.file_names:
+ # data_id = self.file_name_to_data_ids[file_name][0]
+ # new_data_id_list.append(data_id)
+ # new_file_name_to_data_ids[file_name].append(data_id)
+ # self.data_id_list = new_data_id_list
+ # self.file_name_to_data_ids = new_file_name_to_data_ids
+
+
+ # Check if all filenames can be found in the zip file
+ # all_filenames = [self.data[idx]['file_name'] for idx in self.data_id_list ]
+ # check_filenames_in_zipdata(all_filenames, image_root)
+
+
+ def total_images(self):
+ return len(self.file_names)
+
+
+ def __getitem__(self, index):
+ if self.max_boxes_per_data > 99:
+ assert False, "Are you sure setting such large number of boxes?"
+
+ out = {}
+
+ data_id = self.data_id_list[index]
+ out['id'] = data_id
+
+
+ # Image and caption
+ file_name = self.data[data_id]['file_name']
+ image = self.fetch_image(file_name)
+ image_tensor, trans_info = self.transform_image(image)
+ out["image"] = image_tensor
+
+ if random.uniform(0, 1) < self.prob_real_caption:
+ out["caption"] = self.data[data_id]["caption"]
+ else:
+ out["caption"] = ""
+
+
+
+ annos = deepcopy(self.data_id_to_annos[data_id])
+ areas = []
+ all_boxes = []
+ all_masks = []
+ all_positive_embeddings = []
+
+
+ for anno in annos:
+
+ x, y, w, h = anno['bbox']
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
+
+ if valid:
+ areas.append( (x1-x0)*(y1-y0) )
+ all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
+ all_masks.append(1)
+ all_positive_embeddings.append( torch.load(os.path.join(self.annotation_embedding_path,str(anno["id"])), map_location='cpu' ) )
+
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
+ wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]
+
+ boxes = torch.zeros(self.max_boxes_per_data, 4)
+ masks = torch.zeros(self.max_boxes_per_data)
+ positive_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
+ for i, idx in enumerate(wanted_idxs):
+ boxes[i] = all_boxes[idx]
+ masks[i] = all_masks[idx]
+ positive_embeddings[i] = all_positive_embeddings[idx]
+
+
+ out["boxes"] = boxes
+ out["masks"] = masks
+ out["positive_embeddings"] = positive_embeddings
+
+ return out
+
+
+
+ def __len__(self):
+ return len(self.data_id_list)
+
+
diff --git a/dataset/layout_dataset.py b/dataset/layout_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d2b4dc73e8c194e92725faeab368f0951f6f7e8
--- /dev/null
+++ b/dataset/layout_dataset.py
@@ -0,0 +1,237 @@
+import json, os, random, math
+from collections import defaultdict
+from copy import deepcopy
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+
+import numpy as np
+from PIL import Image, ImageOps
+from .base_dataset import BaseDataset, check_filenames_in_zipdata
+from io import BytesIO
+
+
+
+
+def clean_annotations(annotations):
+ for anno in annotations:
+ anno.pop("segmentation", None)
+ anno.pop("area", None)
+ anno.pop("iscrowd", None)
+ anno.pop("id", None)
+
+
+def make_a_sentence(obj_names, clean=False):
+
+ if clean:
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
+
+ caption = ""
+ tokens_positive = []
+ for obj_name in obj_names:
+ start_len = len(caption)
+ caption += obj_name
+ end_len = len(caption)
+ caption += ", "
+ tokens_positive.append(
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
+ )
+ caption = caption[:-2] # remove last ", "
+
+ return caption #, tokens_positive
+
+
+class LayoutDataset(BaseDataset):
+ """
+ Note: this dataset can somehow be achieved in cd_dataset.CDDataset
+ Since if you donot set prob_real_caption=0 in CDDataset, then that
+ dataset will only use detection annotations. However, in that dataset,
+ we do not remove images but remove boxes.
+
+ However, in layout2img works, people will just resize raw image data into 256*256,
+ thus they pre-calculate box size and apply min_box_size before min/max_boxes_per_image.
+ And then they will remove images if does not follow the rule.
+
+ These two different methods will lead to different number of training/val images.
+ Thus this dataset here is only for layout2img.
+
+ """
+ def __init__(self,
+ image_root,
+ instances_json_path,
+ stuff_json_path,
+ category_embedding_path,
+ fake_caption_type = 'empty',
+ image_size=256,
+ max_samples=None,
+ min_box_size=0.02,
+ min_boxes_per_image=3,
+ max_boxes_per_image=8,
+ include_other=False,
+ random_flip=True
+ ):
+        super().__init__(image_root, random_crop=False, random_flip=False, image_size=image_size) # we mainly reuse vis_getitem_data and the zip helpers from BaseDataset; cropping/flipping is handled in this class.
+
+ assert fake_caption_type in ['empty', 'made']
+ self.image_root = image_root
+ self.instances_json_path = instances_json_path
+ self.stuff_json_path = stuff_json_path
+ self.category_embedding_path = category_embedding_path
+ self.fake_caption_type = fake_caption_type
+ self.image_size = image_size
+ self.max_samples = max_samples
+ self.min_box_size = min_box_size
+ self.min_boxes_per_image = min_boxes_per_image
+ self.max_boxes_per_image = max_boxes_per_image
+ self.include_other = include_other
+ self.random_flip = random_flip
+
+
+ self.transform = transforms.Compose([transforms.Resize( (image_size, image_size) ),
+ transforms.ToTensor(),
+ transforms.Lambda(lambda t: (t * 2) - 1) ])
+
+ # Load all jsons
+ with open(instances_json_path, 'r') as f:
+ instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
+ clean_annotations(instances_data["annotations"])
+ self.instances_data = instances_data
+
+ with open(stuff_json_path, 'r') as f:
+ stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
+ clean_annotations(stuff_data["annotations"])
+ self.stuff_data = stuff_data
+
+
+ # Load preprocessed name embedding
+ self.category_embeddings = torch.load(category_embedding_path)
+ self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
+
+
+ # Misc
+ self.image_ids = [] # main list for selecting images
+ self.image_id_to_filename = {} # file names used to read image
+ self.image_id_to_size = {} # original size of this image
+ assert instances_data['images'] == stuff_data["images"]
+ for image_data in instances_data['images']:
+ image_id = image_data['id']
+ filename = image_data['file_name']
+ width = image_data['width']
+ height = image_data['height']
+ self.image_ids.append(image_id)
+ self.image_id_to_filename[image_id] = filename
+ self.image_id_to_size[image_id] = (width, height)
+
+ # All category names (including things and stuff)
+ self.things_id_list = []
+ self.stuff_id_list = []
+ self.object_idx_to_name = {}
+ for category_data in instances_data['categories']:
+ self.things_id_list.append( category_data['id'] )
+ self.object_idx_to_name[category_data['id']] = category_data['name']
+ for category_data in stuff_data['categories']:
+ self.stuff_id_list.append( category_data['id'] )
+ self.object_idx_to_name[category_data['id']] = category_data['name']
+ self.all_categories = [ self.object_idx_to_name.get(k, None) for k in range(183+1) ]
+
+
+ # Add object data from instances and stuff
+ self.image_id_to_objects = defaultdict(list)
+ self.select_objects( instances_data['annotations'] )
+ self.select_objects( stuff_data['annotations'] )
+
+
+ # Prune images that have too few or too many objects
+ new_image_ids = []
+ for image_id in self.image_ids:
+ num_objs = len(self.image_id_to_objects[image_id])
+ if self.min_boxes_per_image <= num_objs <= self.max_boxes_per_image:
+ new_image_ids.append(image_id)
+ self.image_ids = new_image_ids
+
+
+ # Check if all filenames can be found in the zip file
+ all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids]
+ check_filenames_in_zipdata(all_filenames, image_root)
+
+
+
+ def select_objects(self, annotations):
+ for object_anno in annotations:
+ image_id = object_anno['image_id']
+ _, _, w, h = object_anno['bbox']
+ W, H = self.image_id_to_size[image_id]
+ box_area = (w * h) / (W * H)
+ box_ok = box_area > self.min_box_size
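+ # e.g., a 32x32 box in a 640x480 image has relative area 1024/307200 ≈ 0.0033,
+ # which is below the default min_box_size of 0.02 and would therefore be dropped.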
+ object_name = self.object_idx_to_name[object_anno['category_id']]
+ other_ok = object_name != 'other' or self.include_other
+ if box_ok and other_ok:
+ self.image_id_to_objects[image_id].append(object_anno)
+
+
+ def total_images(self):
+ return len(self)
+
+
+ def __getitem__(self, index):
+ if self.max_boxes_per_image > 99:
+ assert False, "Are you sure you want to set such a large number of boxes?"
+
+ out = {}
+
+ image_id = self.image_ids[index]
+ out['id'] = image_id
+
+ flip = self.random_flip and random.random()<0.5
+
+ # Image
+ filename = self.image_id_to_filename[image_id]
+ zip_file = self.fetch_zipfile(self.image_root)
+ image = Image.open(BytesIO(zip_file.read(filename))).convert('RGB')
+ WW, HH = image.size
+ if flip:
+ image = ImageOps.mirror(image)
+ out["image"] = self.transform(image)
+
+ this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id])
+
+ # Make a sentence
+ obj_names = [] # used for make a sentence
+ boxes = torch.zeros(self.max_boxes_per_image, 4)
+ masks = torch.zeros(self.max_boxes_per_image)
+ positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len)
+ for idx, object_anno in enumerate(this_image_obj_annos):
+ obj_name = self.object_idx_to_name[ object_anno['category_id'] ]
+ obj_names.append(obj_name)
+ x, y, w, h = object_anno['bbox']
+ x0 = x / WW
+ y0 = y / HH
+ x1 = (x + w) / WW
+ y1 = (y + h) / HH
+ if flip:
+ x0, x1 = 1-x1, 1-x0
+ boxes[idx] = torch.tensor([x0,y0,x1,y1])
+ masks[idx] = 1
+ positive_embeddings[idx] = self.category_embeddings[obj_name]
+
+ if self.fake_caption_type == 'empty':
+ caption = ""
+ else:
+ caption = make_a_sentence(obj_names, clean=True)
+
+ out["caption"] = caption
+ out["boxes"] = boxes
+ out["masks"] = masks
+ out["positive_embeddings"] = positive_embeddings
+
+
+ return out
+
+
+ def __len__(self):
+ if self.max_samples is None:
+ return len(self.image_ids)
+ return min(len(self.image_ids), self.max_samples)
+
+
diff --git a/dataset/tsv.py b/dataset/tsv.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfbc4c4d0c1ac4b833b8229a952c1e2fe03bd6f1
--- /dev/null
+++ b/dataset/tsv.py
@@ -0,0 +1,212 @@
+import os
+import os.path as op
+import gc
+import json
+from typing import List
+import logging
+
+try:
+ from .blob_storage import BlobStorage, disk_usage
+except ImportError:
+ # Fallback stubs so this module can still be imported without blob storage support.
+ class BlobStorage:
+ pass
+
+ def disk_usage(path):
+ return 0.0
+
+
+def generate_lineidx(filein: str, idxout: str) -> None:
+ idxout_tmp = idxout + '.tmp'
+ with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
+ fsize = os.fstat(tsvin.fileno()).st_size
+ fpos = 0
+ while fpos != fsize:
+ tsvout.write(str(fpos) + "\n")
+ tsvin.readline()
+ fpos = tsvin.tell()
+ os.rename(idxout_tmp, idxout)
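+# The generated .lineidx file stores one byte offset per TSV row (e.g. "0\n57\n118\n"
+# for a three-row file, with hypothetical offsets); TSVFile.seek() below relies on it.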
+
+
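+# Reads from the current file position in 32-character chunks until the delimiter c is
+# found, returning everything before it; seek_first_column() uses this to fetch only
+# the key column without reading the whole row.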
+def read_to_character(fp, c):
+ result = []
+ while True:
+ s = fp.read(32)
+ assert s != ''
+ if c in s:
+ result.append(s[: s.index(c)])
+ break
+ else:
+ result.append(s)
+ return ''.join(result)
+
+
+class TSVFile(object):
+ def __init__(self,
+ tsv_file: str,
+ if_generate_lineidx: bool = False,
+ lineidx: str = None,
+ class_selector: List[str] = None,
+ blob_storage: BlobStorage = None):
+ self.tsv_file = tsv_file
+ self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
+ if not lineidx else lineidx
+ self.linelist = op.splitext(tsv_file)[0] + '.linelist'
+ self.chunks = op.splitext(tsv_file)[0] + '.chunks'
+ self._fp = None
+ self._lineidx = None
+ self._sample_indices = None
+ self._class_boundaries = None
+ self._class_selector = class_selector
+ self._blob_storage = blob_storage
+ self._len = None
+ # remember the pid of the process that opened the file;
+ # if the current pid differs, the file will be re-opened.
+ self.pid = None
+ # generate lineidx if not exist
+ if not op.isfile(self.lineidx) and if_generate_lineidx:
+ generate_lineidx(self.tsv_file, self.lineidx)
+
+ def __del__(self):
+ self.gcidx()
+ if self._fp:
+ self._fp.close()
+ # physically remove the tsv file if it is retrieved by BlobStorage
+ if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
+ try:
+ original_usage = disk_usage('/')
+ os.remove(self.tsv_file)
+ logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
+ (self.tsv_file, original_usage, disk_usage('/') * 100))
+ except:
+ # Known issue: multiple threads attempting to delete the file will raise a FileNotFoundError.
+ # TODO: use threading.Lock to better handle the race condition
+ pass
+
+ def __str__(self):
+ return "TSVFile(tsv_file='{}')".format(self.tsv_file)
+
+ def __repr__(self):
+ return str(self)
+
+ def gcidx(self):
+ logging.debug('Run gc collect')
+ self._lineidx = None
+ self._sample_indices = None
+ #self._class_boundaries = None
+ return gc.collect()
+
+ def get_class_boundaries(self):
+ return self._class_boundaries
+
+ def num_rows(self, gcf=False):
+ if (self._len is None):
+ self._ensure_lineidx_loaded()
+ retval = len(self._sample_indices)
+
+ if (gcf):
+ self.gcidx()
+
+ self._len = retval
+
+ return self._len
+
+ def seek(self, idx: int):
+ self._ensure_tsv_opened()
+ self._ensure_lineidx_loaded()
+ try:
+ pos = self._lineidx[self._sample_indices[idx]]
+ except:
+ logging.info('=> {}-{}'.format(self.tsv_file, idx))
+ raise
+ self._fp.seek(pos)
+ return [s.strip() for s in self._fp.readline().split('\t')]
+
+ def seek_first_column(self, idx: int):
+ self._ensure_tsv_opened()
+ self._ensure_lineidx_loaded()
+ pos = self._lineidx[idx]
+ self._fp.seek(pos)
+ return read_to_character(self._fp, '\t')
+
+ def get_key(self, idx: int):
+ return self.seek_first_column(idx)
+
+ def __getitem__(self, index: int):
+ return self.seek(index)
+
+ def __len__(self):
+ return self.num_rows()
+
+ def _ensure_lineidx_loaded(self):
+ if self._lineidx is None:
+ logging.debug('=> loading lineidx: {}'.format(self.lineidx))
+ with open(self.lineidx, 'r') as fp:
+ lines = fp.readlines()
+ lines = [line.strip() for line in lines]
+ self._lineidx = [int(line) for line in lines]
+
+ # read the line list if exists
+ linelist = None
+ if op.isfile(self.linelist):
+ with open(self.linelist, 'r') as fp:
+ linelist = sorted(
+ [
+ int(line.strip())
+ for line in fp.readlines()
+ ]
+ )
+
+ if op.isfile(self.chunks):
+ self._sample_indices = []
+ self._class_boundaries = []
+ class_boundaries = json.load(open(self.chunks, 'r'))
+ for class_name, boundary in class_boundaries.items():
+ start = len(self._sample_indices)
+ if class_name in self._class_selector:
+ for idx in range(boundary[0], boundary[1] + 1):
+ # NOTE: potentially slow when linelist is long, try to speed it up
+ if linelist and idx not in linelist:
+ continue
+ self._sample_indices.append(idx)
+ end = len(self._sample_indices)
+ self._class_boundaries.append((start, end))
+ else:
+ if linelist:
+ self._sample_indices = linelist
+ else:
+ self._sample_indices = list(range(len(self._lineidx)))
+
+ def _ensure_tsv_opened(self):
+ if self._fp is None:
+ if self._blob_storage:
+ self._fp = self._blob_storage.open(self.tsv_file)
+ else:
+ self._fp = open(self.tsv_file, 'r')
+ self.pid = os.getpid()
+
+ if self.pid != os.getpid():
+ logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
+ self._fp = open(self.tsv_file, 'r')
+ self.pid = os.getpid()
+
+
+class TSVWriter(object):
+ def __init__(self, tsv_file):
+ self.tsv_file = tsv_file
+ self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
+ self.tsv_file_tmp = self.tsv_file + '.tmp'
+ self.lineidx_file_tmp = self.lineidx_file + '.tmp'
+
+ self.tsv_fp = open(self.tsv_file_tmp, 'w')
+ self.lineidx_fp = open(self.lineidx_file_tmp, 'w')
+
+ self.idx = 0
+
+ def write(self, values, sep='\t'):
+ v = '{0}\n'.format(sep.join(map(str, values)))
+ self.tsv_fp.write(v)
+ self.lineidx_fp.write(str(self.idx) + '\n')
+ self.idx = self.idx + len(v)
+
+ def close(self):
+ self.tsv_fp.close()
+ self.lineidx_fp.close()
+ os.rename(self.tsv_file_tmp, self.tsv_file)
+ os.rename(self.lineidx_file_tmp, self.lineidx_file)
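+# Minimal usage sketch (illustrative; 'data.tsv' is a hypothetical file name):
+#   writer = TSVWriter('data.tsv')
+#   writer.write(['key0', 'value0'])   # columns are tab-joined; the byte offset goes to data.lineidx
+#   writer.close()                     # renames the .tmp files into place
+#   key, value = TSVFile('data.tsv')[0]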
diff --git a/dataset/tsv_dataset.py b/dataset/tsv_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2db59faf1254970b35d2fc8dec78afde4f6918
--- /dev/null
+++ b/dataset/tsv_dataset.py
@@ -0,0 +1,326 @@
+import torch
+import json
+from collections import defaultdict
+from PIL import Image, ImageDraw
+from copy import deepcopy
+import os
+import torchvision.transforms as transforms
+import torchvision
+from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
+from io import BytesIO
+import random
+
+from .tsv import TSVFile
+
+import base64
+import numpy as np
+
+
+def decode_base64_to_pillow(image_b64):
+ return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB')
+
+def decode_tensor_from_string(arr_str, use_tensor=True):
+ arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32')
+ if use_tensor:
+ arr = torch.from_numpy(arr)
+ return arr
+
+def decode_item(item):
+ item = json.loads(item)
+ item['image'] = decode_base64_to_pillow(item['image'])
+
+ for anno in item['annos']:
+ anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before'])
+ anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before'])
+ anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after'])
+ anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after'])
+ return item
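+# Expected row payload (schema inferred from how TSVDataset consumes it below): a JSON
+# object with a base64-encoded 'image', a 'caption', a 'data_id', optionally 'is_det',
+# and an 'annos' list whose entries carry 'bbox', 'category_name' (for detection data)
+# and the four base64-encoded embeddings decoded above.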
+
+def check_unique(images, fields):
+ for field in fields:
+ temp_list = []
+ for img_info in images:
+ temp_list.append(img_info[field])
+ assert len(set(temp_list)) == len(temp_list), field
+
+def clean_data(data):
+ for data_info in data:
+ data_info.pop("original_img_id", None)
+ data_info.pop("original_id", None)
+ data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image)
+ data_info.pop("dataset_name", None)
+ data_info.pop("data_source", None)
+ data_info["data_id"] = data_info.pop("id")
+
+
+def clean_annotations(annotations):
+ for anno_info in annotations:
+ anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco
+ anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. This is not always 1 for coco, but I do not think we need this annotation
+ anno_info.pop("area", None)
+ # anno_info.pop("id", None)
+ anno_info["data_id"] = anno_info.pop("image_id")
+
+
+def draw_box(img, boxes):
+ draw = ImageDraw.Draw(img)
+ for box in boxes:
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
+ return img
+
+
+def xyhw2xyxy(box):
+ x0, y0, w, h = box
+ return [ x0, y0, x0+w, y0+h ]
+
+
+def make_a_sentence(obj_names, clean=False):
+
+ if clean:
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
+
+ caption = ""
+ tokens_positive = []
+ for obj_name in obj_names:
+ start_len = len(caption)
+ caption += obj_name
+ end_len = len(caption)
+ caption += ", "
+ tokens_positive.append(
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
+ )
+ caption = caption[:-2] # remove last ", "
+
+ return caption #, tokens_positive
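+# Example: make_a_sentence(['dog', 'sky-other'], clean=True) returns "dog, sky";
+# the (currently unused) tokens_positive would be [[[0, 3]], [[5, 8]]].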
+
+
+def mask_for_random_drop_text_or_image_feature(masks, random_drop_embedding):
+ """
+ input masks tell how many valid grounding tokens for this image
+ e.g., 1,1,1,1,0,0,0,0,0,0...
+
+ If random_drop_embedding=both. we will random drop either image or
+ text feature for each token,
+ but we always make sure there is at least one feature used.
+ In other words, the following masks are not valid
+ (because for the second obj, no feature at all):
+ image: 1,0,1,1,0,0,0,0,0
+ text: 1,0,0,0,0,0,0,0,0
+
+ if random_drop_embedding=image. we will random drop image feature
+ and always keep the text one.
+
+ """
+ N = masks.shape[0]
+
+ if random_drop_embedding=='both':
+ temp_mask = torch.ones(2,N)
+ for i in range(N):
+ if random.uniform(0, 1) < 0.5: # else keep both features
+ idx = random.sample([0,1], 1)[0] # randomly choose to drop image or text feature
+ temp_mask[idx,i] = 0
+ image_masks = temp_mask[0]*masks
+ text_masks = temp_mask[1]*masks
+
+ if random_drop_embedding=='image':
+ image_masks = masks*(torch.rand(N)>0.5)*1
+ text_masks = masks
+
+ return image_masks, text_masks
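+# Example with masks = [1,1,1,0,...]: for 'image', text_masks stays [1,1,1,0,...] while
+# image_masks zeroes each valid entry independently with probability 0.5; for 'both',
+# at most one of the image/text features is dropped per valid token, never both.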
+
+
+
+
+
+def project(x, projection_matrix):
+ """
+ x (Batch*768) should be the penultimate feature of CLIP (before projection)
+ projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer
+ defined in CLIP (out_dim, in_dim), thus we need to apply transpose below.
+ this function will return the CLIP feature (without normalziation)
+ """
+ return x@torch.transpose(projection_matrix, 0, 1)
+
+
+def inv_project(y, projection_matrix):
+ """
+ y (Batch*768) should be the CLIP feature (after projection)
+ projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer
+ defined in CLIP (out_dim, in_dim).
+ this function will return the CLIP penultimate feature.
+
+ Note: to make sure getting the correct penultimate feature, the input y should not be normalized.
+ If it is normalized, then the result will be scaled by CLIP feature norm, which is unknown.
+ """
+ return y@torch.transpose(torch.linalg.inv(projection_matrix), 0, 1)
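+# Sanity check (illustrative): project(x, W) = x @ W.T and inv_project(y, W) = y @ inv(W).T,
+# so inv_project(project(x, W), W) recovers x up to numerical error when W is invertible.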
+
+
+
+
+class TSVDataset(BaseDataset):
+ def __init__(self,
+ tsv_path,
+ which_embedder='clip',
+ which_layer=['after','after'], # text and image
+ prob_use_caption=1,
+ random_drop_embedding='none',
+ image_size=256,
+ min_box_size=0.01,
+ max_boxes_per_data=8,
+ max_images=None, # set as 30K used to eval
+ random_crop = False,
+ random_flip = True,
+ ):
+ image_root = "a placeholder path as we are using tsv here"
+ super().__init__(image_root, random_crop, random_flip, image_size)
+ self.tsv_path = tsv_path
+ self.which_embedder = which_embedder
+ self.prob_use_caption = prob_use_caption
+ self.random_drop_embedding = random_drop_embedding
+ self.min_box_size = min_box_size
+ self.max_boxes_per_data = max_boxes_per_data
+ self.max_images = max_images
+
+ assert which_layer in [ ['after','after'], ['before','after_renorm'], ['before','after_reproject'] ]
+ assert random_drop_embedding in ['none', 'both', 'image']
+ self.which_layer_text = which_layer[0]
+ self.which_layer_image = which_layer[1]
+
+ #self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
+ self.projection_matrix = torch.load('projection_matrix.pth')
+
+ # Load tsv data
+ self.tsv_file = TSVFile(self.tsv_path)
+
+
+ # Load preprocessed name embedding
+ if which_embedder == 'bert':
+ self.embedding_len = 1280
+ elif which_embedder == 'clip':
+ self.embedding_len = 768
+ else:
+ assert False
+
+ def total_images(self):
+ return len(self)
+
+ def get_item_from_tsv(self, index):
+ _, item = self.tsv_file[index]
+ item = decode_item(item)
+ return item
+
+
+ def mapping(self, image_embedding):
+ if self.which_layer_image == 'after':
+ # both text and image use the CLIP-aligned (post-projection) feature
+ return image_embedding
+ elif self.which_layer_image == 'after_renorm':
+ # text uses the pre-projection feature, while the image uses the post-projection feature scaled to norm 28.7
+ return image_embedding*28.7
+ elif self.which_layer_image == 'after_reproject':
+ image_embedding = project( image_embedding.unsqueeze(0), self.projection_matrix.T )
+ image_embedding = image_embedding.squeeze(0)
+ image_embedding = image_embedding / image_embedding.norm()
+ image_embedding = image_embedding * 28.7
+ return image_embedding
+
+
+
+ def __getitem__(self, index):
+ if self.max_boxes_per_data > 99:
+ assert False, "Are you sure you want to set such a large number of boxes?"
+
+ raw_item = self.get_item_from_tsv(index)
+ is_det = raw_item.get('is_det', False) # if it is from detection (such as o365), then we will make a caption
+
+ out = {}
+
+ # -------------------- id and image ------------------- #
+ out['id'] = raw_item['data_id']
+ image = raw_item['image']
+ image_tensor, trans_info = self.transform_image(image)
+ out["image"] = image_tensor
+
+
+
+ # -------------------- grounding token ------------------- #
+ annos = raw_item['annos']
+
+ areas = []
+ all_boxes = []
+ all_masks = []
+ all_text_embeddings = []
+ all_image_embeddings = []
+ if is_det:
+ all_category_names = []
+
+ text_embedding_name = 'text_embedding_before' if self.which_layer_text == 'before' else 'text_embedding_after'
+ image_embedding_name = 'image_embedding_after'
+
+ for anno in annos:
+ x, y, w, h = anno['bbox']
+ valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
+
+ if valid:
+ areas.append( (x1-x0)*(y1-y0) )
+ all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1
+ all_masks.append(1)
+ all_text_embeddings.append(anno[text_embedding_name])
+ all_image_embeddings.append( self.mapping(anno[image_embedding_name]) )
+ if is_det:
+ all_category_names.append(anno["category_name"])
+
+
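+ # keep at most max_boxes_per_data boxes, preferring the largest ones by area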
+ wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
+ wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]
+
+ boxes = torch.zeros(self.max_boxes_per_data, 4)
+ masks = torch.zeros(self.max_boxes_per_data)
+ text_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
+ image_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
+ if is_det:
+ category_names = []
+ for i, idx in enumerate(wanted_idxs):
+ boxes[i] = all_boxes[idx]
+ masks[i] = all_masks[idx]
+ text_embeddings[i] = all_text_embeddings[idx]
+ image_embeddings[i] = all_image_embeddings[idx]
+ if is_det:
+ category_names.append(all_category_names[idx])
+
+ if self.random_drop_embedding != 'none':
+ image_masks, text_masks = mask_for_random_drop_text_or_image_feature(masks, self.random_drop_embedding)
+ else:
+ image_masks = masks
+ text_masks = masks
+
+
+ out["boxes"] = boxes
+ out["masks"] = masks
+ out["image_masks"] = image_masks
+ out["text_masks"] = text_masks
+ out["text_embeddings"] = text_embeddings
+ out["image_embeddings"] = image_embeddings
+
+
+
+ # -------------------- caption ------------------- #
+ if random.uniform(0, 1) < self.prob_use_caption:
+ if is_det:
+ out["caption"] = make_a_sentence(category_names)
+ else:
+ out["caption"] = raw_item["caption"]
+ else:
+ out["caption"] = ""
+
+ return out
+
+
+
+ def __len__(self):
+ return len(self.tsv_file)
+
+
diff --git a/dataset/utils.py b/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ceafd04bc6860eaccfe5a480fb452f00792dac4
--- /dev/null
+++ b/dataset/utils.py
@@ -0,0 +1,116 @@
+#!/usr/bin/python
+#
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import PIL
+import torch
+import torchvision.transforms as T
+
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN]
+INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD]
+
+
+def imagenet_preprocess():
+ return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+
+
+def rescale(x):
+ lo, hi = x.min(), x.max()
+ return x.sub(lo).div(hi - lo)
+
+
+def imagenet_deprocess(rescale_image=True):
+ transforms = [
+ T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD),
+ T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]),
+ ]
+ if rescale_image:
+ transforms.append(rescale)
+ return T.Compose(transforms)
+
+
+def imagenet_deprocess_batch(imgs, rescale=True):
+ """
+ Input:
+ - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images
+
+ Output:
+ - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images
+ in the range [0, 255]
+ """
+ if isinstance(imgs, torch.autograd.Variable):
+ imgs = imgs.data
+ imgs = imgs.cpu().clone()
+ deprocess_fn = imagenet_deprocess(rescale_image=rescale)
+ imgs_de = []
+ for i in range(imgs.size(0)):
+ img_de = deprocess_fn(imgs[i])[None]
+ img_de = img_de.mul(255).clamp(0, 255).byte()
+ imgs_de.append(img_de)
+ imgs_de = torch.cat(imgs_de, dim=0)
+ return imgs_de
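+# e.g. imagenet_deprocess_batch(pred_imgs) converts a normalized (N, C, H, W) float batch
+# back into uint8 images in [0, 255] for visualization.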
+
+
+class Resize(object):
+ def __init__(self, size, interp=PIL.Image.BILINEAR):
+ if isinstance(size, tuple):
+ H, W = size
+ self.size = (W, H)
+ else:
+ self.size = (size, size)
+ self.interp = interp
+
+ def __call__(self, img):
+ return img.resize(self.size, self.interp)
+
+
+def unpack_var(v):
+ if isinstance(v, torch.autograd.Variable):
+ return v.data
+ return v
+
+
+def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img):
+ triples = unpack_var(triples)
+ obj_data = [unpack_var(o) for o in obj_data]
+ obj_to_img = unpack_var(obj_to_img)
+ triple_to_img = unpack_var(triple_to_img)
+
+ triples_out = []
+ obj_data_out = [[] for _ in obj_data]
+ obj_offset = 0
+ N = obj_to_img.max() + 1
+ for i in range(N):
+ o_idxs = (obj_to_img == i).nonzero().view(-1)
+ t_idxs = (triple_to_img == i).nonzero().view(-1)
+
+ cur_triples = triples[t_idxs].clone()
+ cur_triples[:, 0] -= obj_offset
+ cur_triples[:, 2] -= obj_offset
+ triples_out.append(cur_triples)
+
+ for j, o_data in enumerate(obj_data):
+ cur_o_data = None
+ if o_data is not None:
+ cur_o_data = o_data[o_idxs]
+ obj_data_out[j].append(cur_o_data)
+
+ obj_offset += o_idxs.size(0)
+
+ return triples_out, obj_data_out
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6fa931c37c2b460f4de1959a0e8fc5d777cf71c4
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,29 @@
+name: gligen_demo
+channels:
+ - xformers/label/dev
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.10.8
+ - pip=22.2.2
+ - cudatoolkit=11.3
+ - pytorch=1.12.1
+ - torchvision=0.13.1
+ - numpy=1.23.1
+ - xformers
+ - pip:
+ - omegaconf==2.1.1
+ - albumentations==1.3.0
+ - opencv-python
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.4.2
+ - test-tube>=0.7.5
+ - streamlit==1.12.1
+ - einops==0.3.0
+ - git+https://github.com/openai/CLIP.git
+ - protobuf~=3.20.1
+ - torchmetrics==0.6.0
+ - transformers==4.19.2
+ - kornia==0.6.0
+ - gradio==3.16.0
\ No newline at end of file
diff --git a/example_component.py b/example_component.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fceb0d8abb853da6d66901201c0784930be8fe
--- /dev/null
+++ b/example_component.py
@@ -0,0 +1,805 @@
+"""
+Defines helper methods useful for loading and caching Interface examples.
+"""
+from __future__ import annotations
+
+import ast
+import csv
+import inspect
+import os
+import subprocess
+import tempfile
+import threading
+import warnings
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import PIL.Image
+
+from gradio import components, processing_utils, routes, utils
+from gradio.context import Context
+from gradio.documentation import document, set_documentation_group
+from gradio.flagging import CSVLogger
+
+if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
+ from gradio.components import IOComponent
+
+CACHED_FOLDER = "gradio_cached_examples"
+LOG_FILE = "log.csv"
+
+set_documentation_group("helpers")
+
+
+def create_examples(
+ examples: List[Any] | List[List[Any]] | str,
+ inputs: IOComponent | List[IOComponent],
+ outputs: IOComponent | List[IOComponent] | None = None,
+ fn: Callable | None = None,
+ cache_examples: bool = False,
+ examples_per_page: int = 10,
+ _api_mode: bool = False,
+ label: str | None = None,
+ elem_id: str | None = None,
+ run_on_click: bool = False,
+ preprocess: bool = True,
+ postprocess: bool = True,
+ batch: bool = False,
+):
+ """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
+ examples_obj = Examples(
+ examples=examples,
+ inputs=inputs,
+ outputs=outputs,
+ fn=fn,
+ cache_examples=cache_examples,
+ examples_per_page=examples_per_page,
+ _api_mode=_api_mode,
+ label=label,
+ elem_id=elem_id,
+ run_on_click=run_on_click,
+ preprocess=preprocess,
+ postprocess=postprocess,
+ batch=batch,
+ _initiated_directly=False,
+ )
+ utils.synchronize_async(examples_obj.create)
+ return examples_obj
+
+
+class Examples:
+ """
+ This class is a wrapper over the Dataset component and can be used to create Examples
+ for Blocks / Interfaces. Populates the Dataset component with examples and
+ assigns event listener so that clicking on an example populates the input/output
+ components. Optionally handles example caching for fast inference.
+
+ Demos: blocks_inputs, fake_gan
+ Guides: more_on_examples_and_flagging, using_hugging_face_integrations, image_classification_in_pytorch, image_classification_in_tensorflow, image_classification_with_vision_transformers, create_your_own_friends_with_a_gan
+ """
+
+ def __init__(
+ self,
+ examples: List[Any] | List[List[Any]] | str,
+ inputs: IOComponent | List[IOComponent],
+ outputs: IOComponent | List[IOComponent] | None = None,
+ fn: Callable | None = None,
+ cache_examples: bool = False,
+ examples_per_page: int = 10,
+ _api_mode: bool = False,
+ label: str | None = "Examples",
+ elem_id: str | None = None,
+ run_on_click: bool = False,
+ preprocess: bool = True,
+ postprocess: bool = True,
+ batch: bool = False,
+ _initiated_directly: bool = True,
+ ):
+ """
+ Parameters:
+ examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
+ inputs: the component or list of components corresponding to the examples
+ outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
+ fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
+ cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
+ examples_per_page: how many examples to show per page.
+ label: the label to use for the examples component (by default, "Examples")
+ elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
+ run_on_click: if cache_examples is False, clicking on an example does not run the function by default; set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
+ preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
+ postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
+ batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
+ """
+ if _initiated_directly:
+ warnings.warn(
+ "Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
+ )
+
+ if cache_examples and (fn is None or outputs is None):
+ raise ValueError("If caching examples, `fn` and `outputs` must be provided")
+
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+ if outputs and not isinstance(outputs, list):
+ outputs = [outputs]
+
+ working_directory = Path().absolute()
+
+ if examples is None:
+ raise ValueError("The parameter `examples` cannot be None")
+ elif isinstance(examples, list) and (
+ len(examples) == 0 or isinstance(examples[0], list)
+ ):
+ pass
+ elif (
+ isinstance(examples, list) and len(inputs) == 1
+ ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
+ examples = [[e] for e in examples]
+ elif isinstance(examples, str):
+ if not Path(examples).exists():
+ raise FileNotFoundError(
+ "Could not find examples directory: " + examples
+ )
+ working_directory = examples
+ if not (Path(examples) / LOG_FILE).exists():
+ if len(inputs) == 1:
+ examples = [[e] for e in os.listdir(examples)]
+ else:
+ raise FileNotFoundError(
+ "Could not find log file (required for multiple inputs): "
+ + LOG_FILE
+ )
+ else:
+ with open(Path(examples) / LOG_FILE) as logs:
+ examples = list(csv.reader(logs))
+ examples = [
+ examples[i][: len(inputs)] for i in range(1, len(examples))
+ ] # remove header and unnecessary columns
+
+ else:
+ raise ValueError(
+ "The parameter `examples` must either be a string directory or a list"
+ "(if there is only 1 input component) or (more generally), a nested "
+ "list, where each sublist represents a set of inputs."
+ )
+
+ input_has_examples = [False] * len(inputs)
+ for example in examples:
+ for idx, example_for_input in enumerate(example):
+ if not (example_for_input is None):
+ try:
+ input_has_examples[idx] = True
+ except IndexError:
+ pass # If there are more example components than inputs, ignore. This can sometimes be intentional (e.g. loading from a log file where outputs and timestamps are also logged)
+
+ inputs_with_examples = [
+ inp for (inp, keep) in zip(inputs, input_has_examples) if keep
+ ]
+ non_none_examples = [
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
+ for example in examples
+ ]
+
+ self.examples = examples
+ self.non_none_examples = non_none_examples
+ self.inputs = inputs
+ self.inputs_with_examples = inputs_with_examples
+ self.outputs = outputs
+ self.fn = fn
+ self.cache_examples = cache_examples
+ self._api_mode = _api_mode
+ self.preprocess = preprocess
+ self.postprocess = postprocess
+ self.batch = batch
+
+ with utils.set_directory(working_directory):
+ self.processed_examples = [
+ [
+ component.postprocess(sample)
+ for component, sample in zip(inputs, example)
+ ]
+ for example in examples
+ ]
+ self.non_none_processed_examples = [
+ [ex for (ex, keep) in zip(example, input_has_examples) if keep]
+ for example in self.processed_examples
+ ]
+ if cache_examples:
+ for example in self.examples:
+ if len([ex for ex in example if ex is not None]) != len(self.inputs):
+ warnings.warn(
+ "Examples are being cached but not all input components have "
+ "example values. This may result in an exception being thrown by "
+ "your function. If you do get an error while caching examples, make "
+ "sure all of your inputs have example values for all of your examples "
+ "or you provide default values for those particular parameters in your function."
+ )
+ break
+
+ with utils.set_directory(working_directory):
+ self.dataset = components.Dataset(
+ components=inputs_with_examples,
+ samples=non_none_examples,
+ type="index",
+ label=label,
+ samples_per_page=examples_per_page,
+ elem_id=elem_id,
+ )
+
+ self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
+ self.cached_file = Path(self.cached_folder) / "log.csv"
+ self.cache_examples = cache_examples
+ self.run_on_click = run_on_click
+
+ async def create(self) -> None:
+ """Caches the examples if self.cache_examples is True and creates the Dataset
+ component to hold the examples"""
+
+ async def load_example(example_id):
+ # import pdb; pdb.set_trace()
+ if self.cache_examples:
+ processed_example = self.non_none_processed_examples[
+ example_id
+ ] + await self.load_from_cache(example_id)
+ else:
+ processed_example = self.non_none_processed_examples[example_id]
+ return utils.resolve_singleton(processed_example)
+
+ if Context.root_block:
+ if self.cache_examples and self.outputs:
+ targets = self.inputs_with_examples + self.outputs
+ else:
+ targets = self.inputs_with_examples
+ self.dataset.click(
+ load_example,
+ inputs=[self.dataset],
+ outputs=targets, # type: ignore
+ postprocess=False,
+ queue=False,
+ )
+ self.dataset.click(
+ self.fn,
+ inputs=[self.dataset],
+ outputs=targets, # type: ignore
+ postprocess=False,
+ queue=False,
+ )
+ # if self.run_on_click and not self.cache_examples:
+ # if self.fn is None:
+ # raise ValueError("Cannot run_on_click if no function is provided")
+ # self.dataset.click(
+ # self.fn,
+ # inputs=self.inputs, # type: ignore
+ # outputs=self.outputs, # type: ignore
+ # )
+
+ if self.cache_examples:
+ await self.cache()
+
+ async def cache(self) -> None:
+ """
+ Caches all of the examples so that their predictions can be shown immediately.
+ """
+ if Path(self.cached_file).exists():
+ print(
+ f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
+ )
+ else:
+ if Context.root_block is None:
+ raise ValueError("Cannot cache examples if not in a Blocks context")
+
+ print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
+ cache_logger = CSVLogger()
+
+ # create a fake dependency to process the examples and get the predictions
+ dependency = Context.root_block.set_event_trigger(
+ event_name="fake_event",
+ fn=self.fn,
+ inputs=self.inputs_with_examples, # type: ignore
+ outputs=self.outputs, # type: ignore
+ preprocess=self.preprocess and not self._api_mode,
+ postprocess=self.postprocess and not self._api_mode,
+ batch=self.batch,
+ )
+
+ fn_index = Context.root_block.dependencies.index(dependency)
+ assert self.outputs is not None
+ cache_logger.setup(self.outputs, self.cached_folder)
+ for example_id, _ in enumerate(self.examples):
+ processed_input = self.processed_examples[example_id]
+ if self.batch:
+ processed_input = [[value] for value in processed_input]
+ prediction = await Context.root_block.process_api(
+ fn_index=fn_index, inputs=processed_input, request=None, state={}
+ )
+ output = prediction["data"]
+ if self.batch:
+ output = [value[0] for value in output]
+ cache_logger.flag(output)
+ # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
+ Context.root_block.dependencies.remove(dependency)
+ Context.root_block.fns.pop(fn_index)
+
+ async def load_from_cache(self, example_id: int) -> List[Any]:
+ """Loads a particular cached example for the interface.
+ Parameters:
+ example_id: The id of the example to process (zero-indexed).
+ """
+ # import pdb; pdb.set_trace()
+ with open(self.cached_file, encoding="utf-8") as cache:
+ examples = list(csv.reader(cache))
+ example = examples[example_id + 1] # +1 to adjust for header
+ output = []
+ assert self.outputs is not None
+ for component, value in zip(self.outputs, example):
+ try:
+ value_as_dict = ast.literal_eval(value)
+ assert utils.is_update(value_as_dict)
+ output.append(value_as_dict)
+ except (ValueError, TypeError, SyntaxError, AssertionError):
+ output.append(component.serialize(value, self.cached_folder))
+ return output
+
+
+class TrackedIterable:
+ def __init__(
+ self,
+ iterable: Iterable | None,
+ index: int | None,
+ length: int | None,
+ desc: str | None,
+ unit: str | None,
+ _tqdm=None,
+ progress: float | None = None,
+ ) -> None:
+ self.iterable = iterable
+ self.index = index
+ self.length = length
+ self.desc = desc
+ self.unit = unit
+ self._tqdm = _tqdm
+ self.progress = progress
+
+
+@document("__call__", "tqdm")
+class Progress(Iterable):
+ """
+ The Progress class provides a custom progress tracker that is used in a function signature.
+ To attach a Progress tracker to a function, simply add a parameter right after the input parameters that has a default value set to a `gradio.Progress()` instance.
+ The Progress tracker can then be updated in the function by calling the Progress object or using the `tqdm` method on an Iterable.
+ The Progress tracker is currently only available with `queue()`.
+ Example:
+ import gradio as gr
+ import time
+ def my_function(x, progress=gr.Progress()):
+ progress(0, desc="Starting...")
+ time.sleep(1)
+ for i in progress.tqdm(range(100)):
+ time.sleep(0.1)
+ return x
+ gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue().launch()
+ Demos: progress
+ """
+
+ def __init__(
+ self,
+ track_tqdm: bool = False,
+ _callback: Callable | None = None, # for internal use only
+ _event_id: str | None = None,
+ ):
+ """
+ Parameters:
+ track_tqdm: If True, the Progress object will track any tqdm.tqdm iterations with the tqdm library in the function.
+ """
+ self.track_tqdm = track_tqdm
+ self._callback = _callback
+ self._event_id = _event_id
+ self.iterables: List[TrackedIterable] = []
+
+ def __len__(self):
+ return self.iterables[-1].length
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ """
+ Updates progress tracker with next item in iterable.
+ """
+ if self._callback:
+ current_iterable = self.iterables[-1]
+ while (
+ not hasattr(current_iterable.iterable, "__next__")
+ and len(self.iterables) > 0
+ ):
+ current_iterable = self.iterables.pop()
+ self._callback(
+ event_id=self._event_id,
+ iterables=self.iterables,
+ )
+ assert current_iterable.index is not None, "Index not set."
+ current_iterable.index += 1
+ try:
+ return next(current_iterable.iterable) # type: ignore
+ except StopIteration:
+ self.iterables.pop()
+ raise StopIteration
+ else:
+ return self
+
+ def __call__(
+ self,
+ progress: float | Tuple[int, int | None] | None,
+ desc: str | None = None,
+ total: int | None = None,
+ unit: str = "steps",
+ _tqdm=None,
+ ):
+ """
+ Updates progress tracker with progress and message text.
+ Parameters:
+ progress: If float, should be between 0 and 1 representing completion. If Tuple, first number represents steps completed, and second value represents total steps or None if unknown. If None, hides progress bar.
+ desc: description to display.
+ total: estimated total number of steps.
+ unit: unit of iterations.
+ """
+ if self._callback:
+ if isinstance(progress, tuple):
+ index, total = progress
+ progress = None
+ else:
+ index = None
+ self._callback(
+ event_id=self._event_id,
+ iterables=self.iterables
+ + [TrackedIterable(None, index, total, desc, unit, _tqdm, progress)],
+ )
+ else:
+ return progress
+
+ def tqdm(
+ self,
+ iterable: Iterable | None,
+ desc: str | None = None,
+ total: int | None = None,
+ unit: str = "steps",
+ _tqdm=None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Attaches progress tracker to iterable, like tqdm.
+ Parameters:
+ iterable: iterable to attach progress tracker to.
+ desc: description to display.
+ total: estimated total number of steps.
+ unit: unit of iterations.
+ """
+ if self._callback:
+ if iterable is None:
+ new_iterable = TrackedIterable(None, 0, total, desc, unit, _tqdm)
+ self.iterables.append(new_iterable)
+ self._callback(event_id=self._event_id, iterables=self.iterables)
+ return self
+ length = len(iterable) if hasattr(iterable, "__len__") else None # type: ignore
+ self.iterables.append(
+ TrackedIterable(iter(iterable), 0, length, desc, unit, _tqdm)
+ )
+ return self
+
+ def update(self, n=1):
+ """
+ Increases latest iterable with specified number of steps.
+ Parameters:
+ n: number of steps completed.
+ """
+ if self._callback and len(self.iterables) > 0:
+ current_iterable = self.iterables[-1]
+ assert current_iterable.index is not None, "Index not set."
+ current_iterable.index += n
+ self._callback(
+ event_id=self._event_id,
+ iterables=self.iterables,
+ )
+ else:
+ return
+
+ def close(self, _tqdm):
+ """
+ Removes iterable with given _tqdm.
+ """
+ if self._callback:
+ for i in range(len(self.iterables)):
+ if id(self.iterables[i]._tqdm) == id(_tqdm):
+ self.iterables.pop(i)
+ break
+ self._callback(
+ event_id=self._event_id,
+ iterables=self.iterables,
+ )
+ else:
+ return
+
+
+def create_tracker(root_blocks, event_id, fn, track_tqdm):
+
+ progress = Progress(_callback=root_blocks._queue.set_progress, _event_id=event_id)
+ if not track_tqdm:
+ return progress, fn
+
+ try:
+ _tqdm = __import__("tqdm")
+ except ModuleNotFoundError:
+ return progress, fn
+ if not hasattr(root_blocks, "_progress_tracker_per_thread"):
+ root_blocks._progress_tracker_per_thread = {}
+
+ def init_tqdm(self, iterable=None, desc=None, *args, **kwargs):
+ self._progress = root_blocks._progress_tracker_per_thread.get(
+ threading.get_ident()
+ )
+ if self._progress is not None:
+ self._progress.event_id = event_id
+ self._progress.tqdm(iterable, desc, _tqdm=self, *args, **kwargs)
+ kwargs["file"] = open(os.devnull, "w")
+ self.__init__orig__(iterable, desc, *args, **kwargs)
+
+ def iter_tqdm(self):
+ if self._progress is not None:
+ return self._progress
+ else:
+ return self.__iter__orig__()
+
+ def update_tqdm(self, n=1):
+ if self._progress is not None:
+ self._progress.update(n)
+ return self.__update__orig__(n)
+
+ def close_tqdm(self):
+ if self._progress is not None:
+ self._progress.close(self)
+ return self.__close__orig__()
+
+ def exit_tqdm(self, exc_type, exc_value, traceback):
+ if self._progress is not None:
+ self._progress.close(self)
+ return self.__exit__orig__(exc_type, exc_value, traceback)
+
+ if not hasattr(_tqdm.tqdm, "__init__orig__"):
+ _tqdm.tqdm.__init__orig__ = _tqdm.tqdm.__init__
+ _tqdm.tqdm.__init__ = init_tqdm
+ if not hasattr(_tqdm.tqdm, "__update__orig__"):
+ _tqdm.tqdm.__update__orig__ = _tqdm.tqdm.update
+ _tqdm.tqdm.update = update_tqdm
+ if not hasattr(_tqdm.tqdm, "__close__orig__"):
+ _tqdm.tqdm.__close__orig__ = _tqdm.tqdm.close
+ _tqdm.tqdm.close = close_tqdm
+ if not hasattr(_tqdm.tqdm, "__exit__orig__"):
+ _tqdm.tqdm.__exit__orig__ = _tqdm.tqdm.__exit__
+ _tqdm.tqdm.__exit__ = exit_tqdm
+ if not hasattr(_tqdm.tqdm, "__iter__orig__"):
+ _tqdm.tqdm.__iter__orig__ = _tqdm.tqdm.__iter__
+ _tqdm.tqdm.__iter__ = iter_tqdm
+ if hasattr(_tqdm, "auto") and hasattr(_tqdm.auto, "tqdm"):
+ _tqdm.auto.tqdm = _tqdm.tqdm
+
+ def tracked_fn(*args):
+ thread_id = threading.get_ident()
+ root_blocks._progress_tracker_per_thread[thread_id] = progress
+ response = fn(*args)
+ del root_blocks._progress_tracker_per_thread[thread_id]
+ return response
+
+ return progress, tracked_fn
+
+
+def special_args(
+ fn: Callable,
+ inputs: List[Any] | None = None,
+ request: routes.Request | None = None,
+):
+ """
+ Checks if function has special arguments Request (via annotation) or Progress (via default value).
+ If inputs is provided, these values will be loaded into the inputs array.
+ Parameters:
+ fn: function to check.
+ inputs: array to load special arguments into.
+ request: request to load into inputs.
+ Returns:
+ updated inputs, progress index
+ """
+ signature = inspect.signature(fn)
+ positional_args = []
+ for i, param in enumerate(signature.parameters.values()):
+ if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
+ break
+ positional_args.append(param)
+ progress_index = None
+ for i, param in enumerate(positional_args):
+ if isinstance(param.default, Progress):
+ progress_index = i
+ if inputs is not None:
+ inputs.insert(i, param.default)
+ elif param.annotation == routes.Request:
+ if inputs is not None:
+ inputs.insert(i, request)
+ if inputs is not None:
+ while len(inputs) < len(positional_args):
+ i = len(inputs)
+ param = positional_args[i]
+ if param.default == param.empty:
+ warnings.warn("Unexpected argument. Filling with None.")
+ inputs.append(None)
+ else:
+ inputs.append(param.default)
+ return inputs or [], progress_index
+
+
+@document()
+def update(**kwargs) -> dict:
+ """
+ Updates component properties. When a function passed into a Gradio Interface or a Blocks events returns a typical value, it updates the value of the output component. But it is also possible to update the properties of an output component (such as the number of lines of a `Textbox` or the visibility of an `Image`) by returning the component's `update()` function, which takes as parameters any of the constructor parameters for that component.
+ This is a shorthand for using the update method on a component.
+ For example, rather than using gr.Number.update(...) you can just use gr.update(...).
+ Note that your editor's autocompletion will suggest proper parameters
+ if you use the update method on the component.
+ Demos: blocks_essay, blocks_update, blocks_essay_update
+
+ Parameters:
+ kwargs: Key-word arguments used to update the component's properties.
+ Example:
+ # Blocks Example
+ import gradio as gr
+ with gr.Blocks() as demo:
+ radio = gr.Radio([1, 2, 4], label="Set the value of the number")
+ number = gr.Number(value=2, interactive=True)
+ radio.change(fn=lambda value: gr.update(value=value), inputs=radio, outputs=number)
+ demo.launch()
+
+ # Interface example
+ import gradio as gr
+ def change_textbox(choice):
+ if choice == "short":
+ return gr.Textbox.update(lines=2, visible=True)
+ elif choice == "long":
+ return gr.Textbox.update(lines=8, visible=True)
+ else:
+ return gr.Textbox.update(visible=False)
+ gr.Interface(
+ change_textbox,
+ gr.Radio(
+ ["short", "long", "none"], label="What kind of essay would you like to write?"
+ ),
+ gr.Textbox(lines=2),
+ live=True,
+ ).launch()
+ """
+ kwargs["__type__"] = "generic_update"
+ return kwargs
+
+
+def skip() -> dict:
+ return update()
+
+
+@document()
+def make_waveform(
+ audio: str | Tuple[int, np.ndarray],
+ *,
+ bg_color: str = "#f3f4f6",
+ bg_image: str | None = None,
+ fg_alpha: float = 0.75,
+ bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
+ bar_count: int = 50,
+ bar_width: float = 0.6,
+):
+ """
+ Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
+ Parameters:
+ audio: Audio file path or tuple of (sample_rate, audio_data)
+ bg_color: Background color of waveform (ignored if bg_image is provided)
+ bg_image: Background image of waveform
+ fg_alpha: Opacity of foreground waveform
+ bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
+ bar_count: Number of bars in waveform
+ bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
+ Returns:
+ A filepath to the output video.
+ """
+ if isinstance(audio, str):
+ audio_file = audio
+ audio = processing_utils.audio_from_file(audio)
+ else:
+ tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+ processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name)
+ audio_file = tmp_wav.name
+ duration = round(len(audio[1]) / audio[0], 4)
+
+ # Helper methods to create waveform
+ def hex_to_RGB(hex_str):
+ return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
+
+ def get_color_gradient(c1, c2, n):
+ assert n > 1
+ c1_rgb = np.array(hex_to_RGB(c1)) / 255
+ c2_rgb = np.array(hex_to_RGB(c2)) / 255
+ mix_pcts = [x / (n - 1) for x in range(n)]
+ rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
+ return [
+ "#" + "".join([format(int(round(val * 255)), "02x") for val in item])
+ for item in rgb_colors
+ ]
+
+ # Reshape audio to have a fixed number of bars
+ samples = audio[1]
+ if len(samples.shape) > 1:
+ samples = np.mean(samples, 1)
+ bins_to_pad = bar_count - (len(samples) % bar_count)
+ samples = np.pad(samples, [(0, bins_to_pad)])
+ samples = np.reshape(samples, (bar_count, -1))
+ samples = np.abs(samples)
+ samples = np.max(samples, 1)
+
+ matplotlib.use("Agg")
+ plt.clf()
+ # Plot waveform
+ color = (
+ bars_color
+ if isinstance(bars_color, str)
+ else get_color_gradient(bars_color[0], bars_color[1], bar_count)
+ )
+ plt.bar(
+ np.arange(0, bar_count),
+ samples * 2,
+ bottom=(-1 * samples),
+ width=bar_width,
+ color=color,
+ )
+ plt.axis("off")
+ plt.margins(x=0)
+ tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+ savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"}
+ if bg_image is not None:
+ savefig_kwargs["transparent"] = True
+ else:
+ savefig_kwargs["facecolor"] = bg_color
+ plt.savefig(tmp_img.name, **savefig_kwargs)
+ waveform_img = PIL.Image.open(tmp_img.name)
+ waveform_img = waveform_img.resize((1000, 200))
+
+ # Composite waveform with background image
+ if bg_image is not None:
+ waveform_array = np.array(waveform_img)
+ waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
+ waveform_img = PIL.Image.fromarray(waveform_array)
+
+ bg_img = PIL.Image.open(bg_image)
+ waveform_width, waveform_height = waveform_img.size
+ bg_width, bg_height = bg_img.size
+ if waveform_width != bg_width:
+ bg_img = bg_img.resize(
+ (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2))
+ )
+ bg_width, bg_height = bg_img.size
+ composite_height = max(bg_height, waveform_height)
+ composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF")
+ composite.paste(bg_img, (0, composite_height - bg_height))
+ composite.paste(
+ waveform_img, (0, composite_height - waveform_height), waveform_img
+ )
+ composite.save(tmp_img.name)
+ img_width, img_height = composite.size
+ else:
+ img_width, img_height = waveform_img.size
+ waveform_img.save(tmp_img.name)
+
+ # Convert waveform to video with ffmpeg
+ output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+
+ ffmpeg_cmd = f"""ffmpeg -loop 1 -i {tmp_img.name} -i {audio_file} -vf "color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1" -t {duration} -y {output_mp4.name}"""
+
+ subprocess.call(ffmpeg_cmd, shell=True)
+ return output_mp4.name
diff --git a/gligen/.DS_Store b/gligen/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..b880aeb4826dd12db4035e8c2abbda8457d64eaa
Binary files /dev/null and b/gligen/.DS_Store differ
diff --git a/gligen/SD_input_conv_weight_bias.pth b/gligen/SD_input_conv_weight_bias.pth
new file mode 100644
index 0000000000000000000000000000000000000000..76eed06e176fec68ff2d1c4a3fd179cf620a7d7d
--- /dev/null
+++ b/gligen/SD_input_conv_weight_bias.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5a0efad69747a766158304f39091c2b6a24cafb5f833d174f32bee8e864a562
+size 130
diff --git a/gligen/__init__.py b/gligen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..67cf72156e8a5586636f0af71bb47be11a7db307
--- /dev/null
+++ b/gligen/__init__.py
@@ -0,0 +1,10 @@
+
+import os, sys
+sys.path.append(os.path.dirname(__file__))
+sys.path.append(os.path.join(os.path.dirname(__file__), "ldm"))
+
+import gligen.evaluator as evaluator
+import gligen.trainer as trainer
+
+
+# import gligen.ldm as ldm
\ No newline at end of file
diff --git a/gligen/__pycache__/__init__.cpython-38.pyc b/gligen/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dddf0891f9b86a5a4c19aad9273cdac089f3782
Binary files /dev/null and b/gligen/__pycache__/__init__.cpython-38.pyc differ
diff --git a/gligen/__pycache__/distributed.cpython-38.pyc b/gligen/__pycache__/distributed.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7e01e94be05a8a6f8afe0291682765163cf72ac
Binary files /dev/null and b/gligen/__pycache__/distributed.cpython-38.pyc differ
diff --git a/gligen/__pycache__/evaluator.cpython-38.pyc b/gligen/__pycache__/evaluator.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b775a0d708154f3ead33f5fd96f7d51fcf266103
Binary files /dev/null and b/gligen/__pycache__/evaluator.cpython-38.pyc differ
diff --git a/gligen/__pycache__/task_grounded_generation.cpython-38.pyc b/gligen/__pycache__/task_grounded_generation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d5b36ff062326f901f0a31f3f547ce95993e4fd
Binary files /dev/null and b/gligen/__pycache__/task_grounded_generation.cpython-38.pyc differ
diff --git a/gligen/__pycache__/trainer.cpython-38.pyc b/gligen/__pycache__/trainer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16ac8ae5d0a4e8c3ade40746a342320c7aab222b
Binary files /dev/null and b/gligen/__pycache__/trainer.cpython-38.pyc differ
diff --git a/gligen/create_meta.py b/gligen/create_meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..7512c6d377df98db7e17515a7143b7a4ef7d5f32
--- /dev/null
+++ b/gligen/create_meta.py
@@ -0,0 +1,170 @@
+CKPTS = [
+
+ dict(
+ path="/home/chunyl/azure_mount/yuhengdb/fine_tune_ldm/version5_branch6_output/GoldG+SBU+CC3M+CC12M+O365/second_stage_drop_both/tag01/checkpoint_00450001.pth",
+ feature_type=['before','after_reproject'],
+ save_folder_name="v5b6_drop_both",
+ ),
+
+
+ # dict(
+ # path="/home/v-yuhengli/blobfuse/output/fine_tune_ldm/version5_branch6_output/GoldG+SBU+CC3M+CC12M+O365/second_stage_drop_none/tag00/checkpoint_00165001.pth",
+ # feature_type=['before','after_reproject'],
+ # save_folder_name="v5b6_drop_none",
+ # ),
+
+
+
+
+
+]
+
+
+
+# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
+
+
+
+
+
+
+
+
+ # if meta["has_image_mask"] == 0:
+ # image_embeddings = text_embeddings
+ # if meta["has_text_mask"] == 0:
+ # text_embeddings = image_embeddings
+
+ # out = {
+ # "boxes" : boxes.unsqueeze(0).repeat(batch,1,1),
+ # "masks" : masks.unsqueeze(0).repeat(batch,1),
+ # "text_masks" : masks.unsqueeze(0).repeat(batch,1),
+ # "image_masks" : masks.unsqueeze(0).repeat(batch,1),
+ # "text_embeddings" : text_embeddings.unsqueeze(0).repeat(batch,1,1),
+ # "image_embeddings" : image_embeddings.unsqueeze(0).repeat(batch,1,1)
+ # }
+
+
+
+
+
+
+
+META = [
+
+
+ dict(
+ prompt = "a teddy bear sitting next to a red bird",
+ phrases = ['a teddy bear', 'a red bird'],
+ images = ['images/teddy.jpg', 'images/red_bird.jpg'],
+ locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ alpha_type = [1.0, 0, 0.0],
+ has_text_mask = 1,
+ has_image_mask = 0,
+ save_folder_name="teddy_bird_1_1"
+ ),
+
+
+ # dict(
+ # prompt = "a teddy bear sitting next to a bird",
+ # phrases = ['a teddy bear', 'a bird'],
+ # images = ['images/teddy.jpg', 'images/red_bird.jpg'],
+ # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ # alpha_type = [1.0, 0, 0.0],
+ # has_text_mask = 1,
+ # has_image_mask = 1,
+ # save_folder_name="teddy_bird_1_1"
+ # ),
+
+
+ # dict(
+ # prompt = "a teddy bear sitting next to a bird",
+ # phrases = ['a teddy bear', 'a bird'],
+ # images = ['images/teddy.jpg', 'images/red_bird.jpg'],
+ # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ # alpha_type = [0.5, 0, 0.5],
+ # has_text_mask = 1,
+ # has_image_mask = 0,
+ # save_folder_name="teddy_bird_1_0"
+ # ),
+
+ # dict(
+ # prompt = "",
+ # phrases = ['a teddy bear', 'an umbrella'],
+ # images = ['images/teddy.jpg', 'images/umbrella.png'],
+ # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ # alpha_type = [1.0, 0, 0.0],
+ # has_text_mask = 1,
+ # has_image_mask = 1,
+ # save_folder_name="empty_teddy_umbrella_1_1"
+ # ),
+
+ # dict(
+ # prompt = "hello kitty and bird hybrid",
+ # phrases = ['a hello kitty', 'a hello kitty'],
+ # images = ['images/red_bird.jpg', 'images/red_bird.jpg'],
+ # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ # has_text_mask = 1,
+ # has_image_mask = 1,
+ # save_folder_name="hello+bird_1_1"
+ # ),
+
+ # dict(
+ # prompt = "hello kitty and teddy bear hybrid",
+ # phrases = ['a hello kitty', 'a hello kitty'],
+ # images = ['images/teddy.jpg', 'images/teddy.jpg'],
+ # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ # has_text_mask = 1,
+ # has_image_mask = 1,
+ # save_folder_name="hello+teddy_1_1"
+ # ),
+
+ # dict(
+ # prompt = "bird and hello kitty hybrid",
+ # phrases = ['a bird', 'a bird'],
+ # images = ['images/hello.jpg', 'images/hello.jpg'],
+ # locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
+ # alpha_type = [1.0, 0, 0.0],
+ # has_text_mask = 1,
+ # has_image_mask = 0.5,
+ # save_folder_name="bird+hello_1_1"
+ # ),
+
+
+
+ # dict(
+ # prompt = "a deer standing in front of a brick house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k",
+ # phrases = ['a deer'],
+ # images = ['images/sky.jpg'],
+ # locations = [ [0.0,0.5,0.5,0.9] ],
+ # alpha_type = [1, 0, 0],
+ # has_text_mask = 1,
+ # has_image_mask = 1,
+ # save_folder_name="deer_sky"
+ # ),
+
+
+ # dict(
+ # prompt = "A woman sitting in a restaurant with a slice of pizza in front of her",
+ # phrases = ['dining table', 'pizza', 'person', 'wall', 'car', 'paper', 'chair', 'window', 'bottle', 'cup'],
+ # images = ['images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg','images/hello.jpg'],
+ # locations = [ [0.0030, 0.3589, 1.0000, 1.0000],
+ # [0.0779, 0.6744, 0.9768, 1.0000],
+ # [0.2236, 0.0000, 0.7809, 0.4352],
+ # [0.0000, 0.0000, 0.4313, 0.4505],
+ # [0.6275, 0.1050, 0.9444, 0.2497],
+ # [0.0000, 0.3859, 0.1250, 0.6922],
+ # [0.7137, 0.2389, 0.8540, 0.4549],
+ # [0.0000, 0.0000, 0.4667, 0.0630],
+ # [0.3822, 0.4235, 0.4932, 0.6575],
+ # [0.6616, 0.3617, 0.7880, 0.5165] ],
+ # alpha_type = [0.0, 0, 1.0],
+ # has_text_mask = 1,
+ # has_image_mask = 0,
+ # save_folder_name="pizza_1_0"
+ # ),
+
+
+
+
+]
\ No newline at end of file
diff --git a/gligen/distributed.py b/gligen/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..b39bc6e92f74fc46c6ec316e1e41859744a91b7a
--- /dev/null
+++ b/gligen/distributed.py
@@ -0,0 +1,122 @@
+import math
+import pickle
+
+import torch
+from torch import distributed as dist
+from torch.utils.data.sampler import Sampler
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+
+ if not dist.is_initialized():
+ return 0
+
+ return dist.get_rank()
+
+
+def synchronize():
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+
+ dist.barrier()
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def reduce_sum(tensor):
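+    # Return the element-wise sum of the tensor across all ranks; in non-distributed runs the tensor is returned as-is.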
+ if not dist.is_available():
+ return tensor
+
+ if not dist.is_initialized():
+ return tensor
+
+ tensor = tensor.clone()
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+ return tensor
+
+
+def gather_grad(params):
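+    # Average gradients across ranks in place (no-op when running on a single process).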
+ world_size = get_world_size()
+
+ if world_size == 1:
+ return
+
+ for param in params:
+ if param.grad is not None:
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+ param.grad.data.div_(world_size)
+
+
+def all_gather(data):
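+    # Gather arbitrary picklable objects from every rank by serializing them to byte tensors padded to a common length.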
+ world_size = get_world_size()
+
+ if world_size == 1:
+ return [data]
+
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to('cuda')
+
+ local_size = torch.IntTensor([tensor.numel()]).to('cuda')
+ size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
+
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
+ tensor = torch.cat((tensor, padding), 0)
+
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_loss_dict(loss_dict):
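+    # Reduce each loss onto rank 0 and average it there; other ranks keep un-averaged values.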
+ world_size = get_world_size()
+
+ if world_size < 2:
+ return loss_dict
+
+ with torch.no_grad():
+ keys = []
+ losses = []
+
+ for k in sorted(loss_dict.keys()):
+ keys.append(k)
+ losses.append(loss_dict[k])
+
+ losses = torch.stack(losses, 0)
+ dist.reduce(losses, dst=0)
+
+ if dist.get_rank() == 0:
+ losses /= world_size
+
+ reduced_losses = {k: v for k, v in zip(keys, losses)}
+
+ return reduced_losses
diff --git a/gligen/evaluator.py b/gligen/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..436c3d9b1c733bf3a3cc1ff027eb08d03b2d2fed
--- /dev/null
+++ b/gligen/evaluator.py
@@ -0,0 +1,225 @@
+import torch
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
+import numpy as np
+import random
+from dataset.concat_dataset import ConCatDataset #, collate_fn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+import os
+from tqdm import tqdm
+from distributed import get_rank, synchronize, get_world_size
+from trainer import read_official_ckpt, batch_to_device, ImageCaptionSaver, wrap_loader #, get_padded_boxes
+from PIL import Image
+import math
+import json
+
+def draw_masks_from_boxes(boxes,size):
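+    # Build one inpainting mask per sample: regions covered by the boxes are set to 0, everything else stays 1.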
+
+ image_masks = []
+ for box in boxes:
+ image_mask = torch.ones(size[0],size[1])
+ for bx in box:
+ x0, x1 = bx[0]*size[0], bx[2]*size[0]
+ y0, y1 = bx[1]*size[1], bx[3]*size[1]
+ image_mask[int(y0):int(y1), int(x0):int(x1)] = 0
+ image_masks.append(image_mask)
+ return torch.stack(image_masks).unsqueeze(1)
+
+
+
+def set_alpha_scale(model, alpha_scale):
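+    # Set the gating scale on every gated self-/cross-attention module of the model.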
+ from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense
+ for module in model.modules():
+ if type(module) == GatedCrossAttentionDense or type(module) == GatedSelfAttentionDense:
+ module.scale = alpha_scale
+ # print("scale: ", alpha_scale)
+ # print("attn: ", module.alpha_attn)
+ # print("dense: ", module.alpha_dense)
+ # print(' ')
+ # print(' ')
+
+
+def save_images(samples, image_ids, folder, to256):
+ for sample, image_id in zip(samples, image_ids):
+ sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
+ sample = sample.cpu().numpy().transpose(1,2,0) * 255
+ img_name = str(int(image_id))+'.png'
+ img = Image.fromarray(sample.astype(np.uint8))
+ if to256:
+ img = img.resize( (256,256), Image.BICUBIC)
+ img.save(os.path.join(folder,img_name))
+
+
+def ckpt_to_folder_name(basename):
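+    # Turn a checkpoint filename such as "checkpoint_00450001.pth" into a folder name like "450.0k".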
+ name=""
+ for s in basename:
+ if s.isdigit():
+ name+=s
+ seen = round( int(name)/1000, 1 )
+ return str(seen).ljust(4,'0')+'k'
+
+
+class Evaluator:
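+    # Runs a trained checkpoint over the eval set and saves real/fake images to disk for later metrics such as FID.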
+ def __init__(self, config):
+
+ self.config = config
+ self.device = torch.device("cuda")
+
+
+ # = = = = = create model and diffusion = = = = = #
+ if self.config.ckpt != "real":
+
+ self.model = instantiate_from_config(config.model).to(self.device)
+ self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
+ self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
+ self.diffusion = instantiate_from_config(config.diffusion).to(self.device)
+
+            # do not need to load the official ckpt for self.model here, since we will load from our own ckpt
+ state_dict = read_official_ckpt( os.path.join(config.DATA_ROOT, config.official_ckpt_name) )
+ self.autoencoder.load_state_dict( state_dict["autoencoder"] )
+ self.text_encoder.load_state_dict( state_dict["text_encoder"] )
+ self.diffusion.load_state_dict( state_dict["diffusion"] )
+
+
+ # = = = = = load from our ckpt = = = = = #
+ if self.config.ckpt == "real":
+ print("Saving all real images...")
+ self.just_save_real = True
+ else:
+ checkpoint = torch.load(self.config.ckpt, map_location="cpu")
+ which_state = 'ema' if 'ema' in checkpoint else "model"
+ which_state = which_state if config.which_state is None else config.which_state
+ self.model.load_state_dict(checkpoint[which_state])
+ print("ckpt is loaded")
+ self.just_save_real = False
+ set_alpha_scale(self.model, self.config.alpha_scale)
+
+ self.autoencoder.eval()
+ self.model.eval()
+ self.text_encoder.eval()
+
+
+ # = = = = = create data = = = = = #
+ self.dataset_eval = ConCatDataset(config.val_dataset_names, config.DATA_ROOT, config.which_embedder, train=False)
+ print("total eval images: ", len(self.dataset_eval))
+ sampler = DistributedSampler(self.dataset_eval,shuffle=False) if config.distributed else None
+ loader_eval = DataLoader( self.dataset_eval,batch_size=config.batch_size,
+ num_workers=config.workers,
+ pin_memory=True,
+ sampler=sampler,
+ drop_last=False) # shuffle default is False
+ self.loader_eval = loader_eval
+
+
+ # = = = = = create output folder = = = = = #
+ folder_name = ckpt_to_folder_name(os.path.basename(config.ckpt))
+ self.outdir = os.path.join(config.OUTPUT_ROOT, folder_name)
+ self.outdir_real = os.path.join(self.outdir,'real')
+ self.outdir_fake = os.path.join(self.outdir,'fake')
+ if config.to256:
+ self.outdir_real256 = os.path.join(self.outdir,'real256')
+ self.outdir_fake256 = os.path.join(self.outdir,'fake256')
+        synchronize() # if rank0 is faster, it may mkdir before the other ranks call os.listdir()
+ if get_rank() == 0:
+ os.makedirs(self.outdir, exist_ok=True)
+ os.makedirs(self.outdir_real, exist_ok=True)
+ os.makedirs(self.outdir_fake, exist_ok=True)
+ if config.to256:
+ os.makedirs(self.outdir_real256, exist_ok=True)
+ os.makedirs(self.outdir_fake256, exist_ok=True)
+ print(self.outdir) # double check
+
+ self.evaluation_finished = False
+ if os.path.exists( os.path.join(self.outdir,'score.txt') ):
+ self.evaluation_finished = True
+
+
+    def already_saved_this_batch(self, batch):
+ existing_real_files = os.listdir( self.outdir_real )
+ existing_fake_files = os.listdir( self.outdir_fake )
+ status = []
+ for image_id in batch["id"]:
+ img_name = str(int(image_id))+'.png'
+ status.append(img_name in existing_real_files)
+ status.append(img_name in existing_fake_files)
+ return all(status)
+
+
+ @torch.no_grad()
+ def start_evaluating(self):
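+        # For each eval batch: save the real images and, unless only real images are requested, sample fakes with PLMS/DDIM and save them as well.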
+
+ iterator = tqdm( self.loader_eval, desc='Evaluating progress')
+ for batch in iterator:
+
+            #if not self.already_saved_this_batch(batch):
+ if True:
+
+ batch_to_device(batch, self.device)
+ batch_size = batch["image"].shape[0]
+ samples_real = batch["image"]
+
+ if self.just_save_real:
+ samples_fake = None
+ else:
+ uc = self.text_encoder.encode( batch_size*[""] )
+ context = self.text_encoder.encode( batch["caption"] )
+
+ image_mask = x0 = None
+ if self.config.inpaint:
+ image_mask = draw_masks_from_boxes( batch['boxes'], self.model.image_size ).cuda()
+ x0 = self.autoencoder.encode( batch["image"] )
+
+ shape = (batch_size, self.model.in_channels, self.model.image_size, self.model.image_size)
+ if self.config.no_plms:
+ sampler = DDIMSampler(self.diffusion, self.model)
+ steps = 250
+ else:
+ sampler = PLMSSampler(self.diffusion, self.model)
+ steps = 50
+
+ input = dict( x=None, timesteps=None, context=context, boxes=batch['boxes'], masks=batch['masks'], positive_embeddings=batch["positive_embeddings"] )
+ samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc, guidance_scale=self.config.guidance_scale, mask=image_mask, x0=x0)
+ samples_fake = self.autoencoder.decode(samples_fake)
+
+
+ save_images(samples_real, batch['id'], self.outdir_real, to256=False )
+ if self.config.to256:
+ save_images(samples_real, batch['id'], self.outdir_real256, to256=True )
+
+ if samples_fake is not None:
+ save_images(samples_fake, batch['id'], self.outdir_fake, to256=False )
+ if self.config.to256:
+ save_images(samples_fake, batch['id'], self.outdir_fake256, to256=True )
+
+
+ def fire_fid(self):
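+        # Only assembles the real/fake output folders (256px variants if enabled); the FID computation itself is not part of this file.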
+ paths = [self.outdir_real, self.outdir_fake]
+ if self.config.to256:
+ paths = [self.outdir_real256, self.outdir_fake256]
+
diff --git a/gligen/ldm/.DS_Store b/gligen/ldm/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..cc8962d7b5d196908d8d336eb5b39dc4d7ee7b02
Binary files /dev/null and b/gligen/ldm/.DS_Store differ
diff --git a/gligen/ldm/__pycache__/util.cpython-38.pyc b/gligen/ldm/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c49bacc41c35909fac61e2ecc1c916fc1ffb7605
Binary files /dev/null and b/gligen/ldm/__pycache__/util.cpython-38.pyc differ
diff --git a/gligen/ldm/data/.DS_Store b/gligen/ldm/data/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/gligen/ldm/data/.DS_Store differ
diff --git a/gligen/ldm/data/__init__.py b/gligen/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/gligen/ldm/data/base.py b/gligen/ldm/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b196c2f7aa583a3e8bc4aad9f943df0c4dae0da7
--- /dev/null
+++ b/gligen/ldm/data/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass
\ No newline at end of file
diff --git a/gligen/ldm/data/imagenet.py b/gligen/ldm/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec7c09fe3e32a634cb47dbd8b5c23aaeb4580b4a
--- /dev/null
+++ b/gligen/ldm/data/imagenet.py
@@ -0,0 +1,394 @@
+import os, yaml, pickle, shutil, tarfile, glob
+import cv2
+import albumentations
+import PIL
+import numpy as np
+import torchvision.transforms.functional as TF
+from omegaconf import OmegaConf
+from functools import partial
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset, Subset
+
+import taming.data.utils as tdu
+from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from taming.data.imagenet import ImagePaths
+
+from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+
+
+def synset2idx(path_to_yaml="ldm/data/index_synset.yaml"):
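+    # Invert the class-index -> synset mapping from the YAML file into a synset -> class-index dict.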
+ with open(path_to_yaml) as f:
+        di2s = yaml.safe_load(f)
+ return dict((v,k) for k,v in di2s.items())
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+ if not type(self.config)==dict:
+ self.config = OmegaConf.to_container(self.config)
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._prepare_human_to_integer_label()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _prepare_human_to_integer_label(self):
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
+ if (not os.path.exists(self.human2integer)):
+ download(URL, self.human2integer)
+ with open(self.human2integer, "r") as f:
+ lines = f.read().splitlines()
+ assert len(lines) == 1000
+ self.human2integer_dict = dict()
+ for line in lines:
+ value, key = line.split(":")
+ self.human2integer_dict[key] = int(value)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ if not self.keep_orig_class_label:
+ self.class_labels = [class_dict[s] for s in self.synsets]
+ else:
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+
+ if self.process_images:
+ self.size = retrieve(self.config, "size", default=256)
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=self.size,
+ random_crop=self.random_crop,
+ )
+ else:
+ self.data = self.abspaths
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.process_images = process_images
+ self.data_root = data_root
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.data_root = data_root
+ self.process_images = process_images
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+
+class ImageNetSR(Dataset):
+ def __init__(self, size=None,
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
+ random_crop=True):
+ """
+ Imagenet Superresolution Dataloader
+        Performs the following ops in order:
+        1.  crops a square patch of side length s from the image, either as a random or a center crop
+        2.  resizes the crop to size with cv2 area interpolation
+        3.  degrades the resized crop with degradation_fn
+
+        :param size: target side length after cropping and resizing
+        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
+        :param downscale_f: low-resolution downsample factor
+        :param min_crop_f: lower bound for the crop factor c; the crop side length is
+            s = c * min_img_side_len with c sampled uniformly from (min_crop_f, max_crop_f)
+        :param max_crop_f: upper bound for the crop factor c (see min_crop_f)
+        :param random_crop: if True take random crops, otherwise center crops
+ """
+ self.base = self.get_base()
+ assert size
+ assert (size / downscale_f).is_integer()
+ self.size = size
+ self.LR_size = int(size / downscale_f)
+ self.min_crop_f = min_crop_f
+ self.max_crop_f = max_crop_f
+ assert(max_crop_f <= 1.)
+ self.center_crop = not random_crop
+
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
+
+        self.pil_interpolation = False # gets reset below if the interpolation op comes from Pillow
+
+ if degradation == "bsrgan":
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
+
+ elif degradation == "bsrgan_light":
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
+
+ else:
+ interpolation_fn = {
+ "cv_nearest": cv2.INTER_NEAREST,
+ "cv_bilinear": cv2.INTER_LINEAR,
+ "cv_bicubic": cv2.INTER_CUBIC,
+ "cv_area": cv2.INTER_AREA,
+ "cv_lanczos": cv2.INTER_LANCZOS4,
+ "pil_nearest": PIL.Image.NEAREST,
+ "pil_bilinear": PIL.Image.BILINEAR,
+ "pil_bicubic": PIL.Image.BICUBIC,
+ "pil_box": PIL.Image.BOX,
+ "pil_hamming": PIL.Image.HAMMING,
+ "pil_lanczos": PIL.Image.LANCZOS,
+ }[degradation]
+
+ self.pil_interpolation = degradation.startswith("pil_")
+
+ if self.pil_interpolation:
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
+
+ else:
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
+ interpolation=interpolation_fn)
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = Image.open(example["file_path_"])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ image = np.array(image).astype(np.uint8)
+
+ min_side_len = min(image.shape[:2])
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+ crop_side_len = int(crop_side_len)
+
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+
+ else:
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+ image = self.cropper(image=image)["image"]
+ image = self.image_rescaler(image=image)["image"]
+
+ if self.pil_interpolation:
+ image_pil = PIL.Image.fromarray(image)
+ LR_image = self.degradation_process(image_pil)
+ LR_image = np.array(LR_image).astype(np.uint8)
+
+ else:
+ LR_image = self.degradation_process(image=image)["image"]
+
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+
+ return example
+
+
+class ImageNetSRTrain(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("ldm/data/imagenet_train_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetTrain(process_images=False,)
+ return Subset(dset, indices)
+
+
+class ImageNetSRValidation(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("ldm/data/imagenet_val_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetValidation(process_images=False,)
+ return Subset(dset, indices)
diff --git a/gligen/ldm/data/imagenet_clsidx_to_label.txt b/gligen/ldm/data/imagenet_clsidx_to_label.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e2fe435526be7e0dd6675885c6c74b2f9276459b
--- /dev/null
+++ b/gligen/ldm/data/imagenet_clsidx_to_label.txt
@@ -0,0 +1,1000 @@
+ 0: 'tench, Tinca tinca',
+ 1: 'goldfish, Carassius auratus',
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
+ 3: 'tiger shark, Galeocerdo cuvieri',
+ 4: 'hammerhead, hammerhead shark',
+ 5: 'electric ray, crampfish, numbfish, torpedo',
+ 6: 'stingray',
+ 7: 'cock',
+ 8: 'hen',
+ 9: 'ostrich, Struthio camelus',
+ 10: 'brambling, Fringilla montifringilla',
+ 11: 'goldfinch, Carduelis carduelis',
+ 12: 'house finch, linnet, Carpodacus mexicanus',
+ 13: 'junco, snowbird',
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
+ 15: 'robin, American robin, Turdus migratorius',
+ 16: 'bulbul',
+ 17: 'jay',
+ 18: 'magpie',
+ 19: 'chickadee',
+ 20: 'water ouzel, dipper',
+ 21: 'kite',
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
+ 23: 'vulture',
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
+ 25: 'European fire salamander, Salamandra salamandra',
+ 26: 'common newt, Triturus vulgaris',
+ 27: 'eft',
+ 28: 'spotted salamander, Ambystoma maculatum',
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
+ 30: 'bullfrog, Rana catesbeiana',
+ 31: 'tree frog, tree-frog',
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
+ 35: 'mud turtle',
+ 36: 'terrapin',
+ 37: 'box turtle, box tortoise',
+ 38: 'banded gecko',
+ 39: 'common iguana, iguana, Iguana iguana',
+ 40: 'American chameleon, anole, Anolis carolinensis',
+ 41: 'whiptail, whiptail lizard',
+ 42: 'agama',
+ 43: 'frilled lizard, Chlamydosaurus kingi',
+ 44: 'alligator lizard',
+ 45: 'Gila monster, Heloderma suspectum',
+ 46: 'green lizard, Lacerta viridis',
+ 47: 'African chameleon, Chamaeleo chamaeleon',
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
+ 50: 'American alligator, Alligator mississipiensis',
+ 51: 'triceratops',
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
+ 53: 'ringneck snake, ring-necked snake, ring snake',
+ 54: 'hognose snake, puff adder, sand viper',
+ 55: 'green snake, grass snake',
+ 56: 'king snake, kingsnake',
+ 57: 'garter snake, grass snake',
+ 58: 'water snake',
+ 59: 'vine snake',
+ 60: 'night snake, Hypsiglena torquata',
+ 61: 'boa constrictor, Constrictor constrictor',
+ 62: 'rock python, rock snake, Python sebae',
+ 63: 'Indian cobra, Naja naja',
+ 64: 'green mamba',
+ 65: 'sea snake',
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
+ 69: 'trilobite',
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
+ 71: 'scorpion',
+ 72: 'black and gold garden spider, Argiope aurantia',
+ 73: 'barn spider, Araneus cavaticus',
+ 74: 'garden spider, Aranea diademata',
+ 75: 'black widow, Latrodectus mactans',
+ 76: 'tarantula',
+ 77: 'wolf spider, hunting spider',
+ 78: 'tick',
+ 79: 'centipede',
+ 80: 'black grouse',
+ 81: 'ptarmigan',
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
+ 84: 'peacock',
+ 85: 'quail',
+ 86: 'partridge',
+ 87: 'African grey, African gray, Psittacus erithacus',
+ 88: 'macaw',
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
+ 90: 'lorikeet',
+ 91: 'coucal',
+ 92: 'bee eater',
+ 93: 'hornbill',
+ 94: 'hummingbird',
+ 95: 'jacamar',
+ 96: 'toucan',
+ 97: 'drake',
+ 98: 'red-breasted merganser, Mergus serrator',
+ 99: 'goose',
+ 100: 'black swan, Cygnus atratus',
+ 101: 'tusker',
+ 102: 'echidna, spiny anteater, anteater',
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
+ 104: 'wallaby, brush kangaroo',
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
+ 106: 'wombat',
+ 107: 'jellyfish',
+ 108: 'sea anemone, anemone',
+ 109: 'brain coral',
+ 110: 'flatworm, platyhelminth',
+ 111: 'nematode, nematode worm, roundworm',
+ 112: 'conch',
+ 113: 'snail',
+ 114: 'slug',
+ 115: 'sea slug, nudibranch',
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
+ 118: 'Dungeness crab, Cancer magister',
+ 119: 'rock crab, Cancer irroratus',
+ 120: 'fiddler crab',
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
+ 125: 'hermit crab',
+ 126: 'isopod',
+ 127: 'white stork, Ciconia ciconia',
+ 128: 'black stork, Ciconia nigra',
+ 129: 'spoonbill',
+ 130: 'flamingo',
+ 131: 'little blue heron, Egretta caerulea',
+ 132: 'American egret, great white heron, Egretta albus',
+ 133: 'bittern',
+ 134: 'crane',
+ 135: 'limpkin, Aramus pictus',
+ 136: 'European gallinule, Porphyrio porphyrio',
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
+ 138: 'bustard',
+ 139: 'ruddy turnstone, Arenaria interpres',
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
+ 141: 'redshank, Tringa totanus',
+ 142: 'dowitcher',
+ 143: 'oystercatcher, oyster catcher',
+ 144: 'pelican',
+ 145: 'king penguin, Aptenodytes patagonica',
+ 146: 'albatross, mollymawk',
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
+ 149: 'dugong, Dugong dugon',
+ 150: 'sea lion',
+ 151: 'Chihuahua',
+ 152: 'Japanese spaniel',
+ 153: 'Maltese dog, Maltese terrier, Maltese',
+ 154: 'Pekinese, Pekingese, Peke',
+ 155: 'Shih-Tzu',
+ 156: 'Blenheim spaniel',
+ 157: 'papillon',
+ 158: 'toy terrier',
+ 159: 'Rhodesian ridgeback',
+ 160: 'Afghan hound, Afghan',
+ 161: 'basset, basset hound',
+ 162: 'beagle',
+ 163: 'bloodhound, sleuthhound',
+ 164: 'bluetick',
+ 165: 'black-and-tan coonhound',
+ 166: 'Walker hound, Walker foxhound',
+ 167: 'English foxhound',
+ 168: 'redbone',
+ 169: 'borzoi, Russian wolfhound',
+ 170: 'Irish wolfhound',
+ 171: 'Italian greyhound',
+ 172: 'whippet',
+ 173: 'Ibizan hound, Ibizan Podenco',
+ 174: 'Norwegian elkhound, elkhound',
+ 175: 'otterhound, otter hound',
+ 176: 'Saluki, gazelle hound',
+ 177: 'Scottish deerhound, deerhound',
+ 178: 'Weimaraner',
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
+ 181: 'Bedlington terrier',
+ 182: 'Border terrier',
+ 183: 'Kerry blue terrier',
+ 184: 'Irish terrier',
+ 185: 'Norfolk terrier',
+ 186: 'Norwich terrier',
+ 187: 'Yorkshire terrier',
+ 188: 'wire-haired fox terrier',
+ 189: 'Lakeland terrier',
+ 190: 'Sealyham terrier, Sealyham',
+ 191: 'Airedale, Airedale terrier',
+ 192: 'cairn, cairn terrier',
+ 193: 'Australian terrier',
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
+ 195: 'Boston bull, Boston terrier',
+ 196: 'miniature schnauzer',
+ 197: 'giant schnauzer',
+ 198: 'standard schnauzer',
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
+ 200: 'Tibetan terrier, chrysanthemum dog',
+ 201: 'silky terrier, Sydney silky',
+ 202: 'soft-coated wheaten terrier',
+ 203: 'West Highland white terrier',
+ 204: 'Lhasa, Lhasa apso',
+ 205: 'flat-coated retriever',
+ 206: 'curly-coated retriever',
+ 207: 'golden retriever',
+ 208: 'Labrador retriever',
+ 209: 'Chesapeake Bay retriever',
+ 210: 'German short-haired pointer',
+ 211: 'vizsla, Hungarian pointer',
+ 212: 'English setter',
+ 213: 'Irish setter, red setter',
+ 214: 'Gordon setter',
+ 215: 'Brittany spaniel',
+ 216: 'clumber, clumber spaniel',
+ 217: 'English springer, English springer spaniel',
+ 218: 'Welsh springer spaniel',
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
+ 220: 'Sussex spaniel',
+ 221: 'Irish water spaniel',
+ 222: 'kuvasz',
+ 223: 'schipperke',
+ 224: 'groenendael',
+ 225: 'malinois',
+ 226: 'briard',
+ 227: 'kelpie',
+ 228: 'komondor',
+ 229: 'Old English sheepdog, bobtail',
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
+ 231: 'collie',
+ 232: 'Border collie',
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
+ 234: 'Rottweiler',
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
+ 236: 'Doberman, Doberman pinscher',
+ 237: 'miniature pinscher',
+ 238: 'Greater Swiss Mountain dog',
+ 239: 'Bernese mountain dog',
+ 240: 'Appenzeller',
+ 241: 'EntleBucher',
+ 242: 'boxer',
+ 243: 'bull mastiff',
+ 244: 'Tibetan mastiff',
+ 245: 'French bulldog',
+ 246: 'Great Dane',
+ 247: 'Saint Bernard, St Bernard',
+ 248: 'Eskimo dog, husky',
+ 249: 'malamute, malemute, Alaskan malamute',
+ 250: 'Siberian husky',
+ 251: 'dalmatian, coach dog, carriage dog',
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
+ 253: 'basenji',
+ 254: 'pug, pug-dog',
+ 255: 'Leonberg',
+ 256: 'Newfoundland, Newfoundland dog',
+ 257: 'Great Pyrenees',
+ 258: 'Samoyed, Samoyede',
+ 259: 'Pomeranian',
+ 260: 'chow, chow chow',
+ 261: 'keeshond',
+ 262: 'Brabancon griffon',
+ 263: 'Pembroke, Pembroke Welsh corgi',
+ 264: 'Cardigan, Cardigan Welsh corgi',
+ 265: 'toy poodle',
+ 266: 'miniature poodle',
+ 267: 'standard poodle',
+ 268: 'Mexican hairless',
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
+ 273: 'dingo, warrigal, warragal, Canis dingo',
+ 274: 'dhole, Cuon alpinus',
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
+ 276: 'hyena, hyaena',
+ 277: 'red fox, Vulpes vulpes',
+ 278: 'kit fox, Vulpes macrotis',
+ 279: 'Arctic fox, white fox, Alopex lagopus',
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
+ 281: 'tabby, tabby cat',
+ 282: 'tiger cat',
+ 283: 'Persian cat',
+ 284: 'Siamese cat, Siamese',
+ 285: 'Egyptian cat',
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
+ 287: 'lynx, catamount',
+ 288: 'leopard, Panthera pardus',
+ 289: 'snow leopard, ounce, Panthera uncia',
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
+ 291: 'lion, king of beasts, Panthera leo',
+ 292: 'tiger, Panthera tigris',
+ 293: 'cheetah, chetah, Acinonyx jubatus',
+ 294: 'brown bear, bruin, Ursus arctos',
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
+ 298: 'mongoose',
+ 299: 'meerkat, mierkat',
+ 300: 'tiger beetle',
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
+ 302: 'ground beetle, carabid beetle',
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
+ 304: 'leaf beetle, chrysomelid',
+ 305: 'dung beetle',
+ 306: 'rhinoceros beetle',
+ 307: 'weevil',
+ 308: 'fly',
+ 309: 'bee',
+ 310: 'ant, emmet, pismire',
+ 311: 'grasshopper, hopper',
+ 312: 'cricket',
+ 313: 'walking stick, walkingstick, stick insect',
+ 314: 'cockroach, roach',
+ 315: 'mantis, mantid',
+ 316: 'cicada, cicala',
+ 317: 'leafhopper',
+ 318: 'lacewing, lacewing fly',
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
+ 320: 'damselfly',
+ 321: 'admiral',
+ 322: 'ringlet, ringlet butterfly',
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
+ 324: 'cabbage butterfly',
+ 325: 'sulphur butterfly, sulfur butterfly',
+ 326: 'lycaenid, lycaenid butterfly',
+ 327: 'starfish, sea star',
+ 328: 'sea urchin',
+ 329: 'sea cucumber, holothurian',
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
+ 331: 'hare',
+ 332: 'Angora, Angora rabbit',
+ 333: 'hamster',
+ 334: 'porcupine, hedgehog',
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
+ 336: 'marmot',
+ 337: 'beaver',
+ 338: 'guinea pig, Cavia cobaya',
+ 339: 'sorrel',
+ 340: 'zebra',
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
+ 342: 'wild boar, boar, Sus scrofa',
+ 343: 'warthog',
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
+ 345: 'ox',
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
+ 347: 'bison',
+ 348: 'ram, tup',
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
+ 350: 'ibex, Capra ibex',
+ 351: 'hartebeest',
+ 352: 'impala, Aepyceros melampus',
+ 353: 'gazelle',
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
+ 355: 'llama',
+ 356: 'weasel',
+ 357: 'mink',
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
+ 360: 'otter',
+ 361: 'skunk, polecat, wood pussy',
+ 362: 'badger',
+ 363: 'armadillo',
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
+ 366: 'gorilla, Gorilla gorilla',
+ 367: 'chimpanzee, chimp, Pan troglodytes',
+ 368: 'gibbon, Hylobates lar',
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
+ 370: 'guenon, guenon monkey',
+ 371: 'patas, hussar monkey, Erythrocebus patas',
+ 372: 'baboon',
+ 373: 'macaque',
+ 374: 'langur',
+ 375: 'colobus, colobus monkey',
+ 376: 'proboscis monkey, Nasalis larvatus',
+ 377: 'marmoset',
+ 378: 'capuchin, ringtail, Cebus capucinus',
+ 379: 'howler monkey, howler',
+ 380: 'titi, titi monkey',
+ 381: 'spider monkey, Ateles geoffroyi',
+ 382: 'squirrel monkey, Saimiri sciureus',
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
+ 385: 'Indian elephant, Elephas maximus',
+ 386: 'African elephant, Loxodonta africana',
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
+ 389: 'barracouta, snoek',
+ 390: 'eel',
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
+ 392: 'rock beauty, Holocanthus tricolor',
+ 393: 'anemone fish',
+ 394: 'sturgeon',
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
+ 396: 'lionfish',
+ 397: 'puffer, pufferfish, blowfish, globefish',
+ 398: 'abacus',
+ 399: 'abaya',
+ 400: "academic gown, academic robe, judge's robe",
+ 401: 'accordion, piano accordion, squeeze box',
+ 402: 'acoustic guitar',
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
+ 404: 'airliner',
+ 405: 'airship, dirigible',
+ 406: 'altar',
+ 407: 'ambulance',
+ 408: 'amphibian, amphibious vehicle',
+ 409: 'analog clock',
+ 410: 'apiary, bee house',
+ 411: 'apron',
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
+ 413: 'assault rifle, assault gun',
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
+ 415: 'bakery, bakeshop, bakehouse',
+ 416: 'balance beam, beam',
+ 417: 'balloon',
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
+ 419: 'Band Aid',
+ 420: 'banjo',
+ 421: 'bannister, banister, balustrade, balusters, handrail',
+ 422: 'barbell',
+ 423: 'barber chair',
+ 424: 'barbershop',
+ 425: 'barn',
+ 426: 'barometer',
+ 427: 'barrel, cask',
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
+ 429: 'baseball',
+ 430: 'basketball',
+ 431: 'bassinet',
+ 432: 'bassoon',
+ 433: 'bathing cap, swimming cap',
+ 434: 'bath towel',
+ 435: 'bathtub, bathing tub, bath, tub',
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
+ 437: 'beacon, lighthouse, beacon light, pharos',
+ 438: 'beaker',
+ 439: 'bearskin, busby, shako',
+ 440: 'beer bottle',
+ 441: 'beer glass',
+ 442: 'bell cote, bell cot',
+ 443: 'bib',
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
+ 445: 'bikini, two-piece',
+ 446: 'binder, ring-binder',
+ 447: 'binoculars, field glasses, opera glasses',
+ 448: 'birdhouse',
+ 449: 'boathouse',
+ 450: 'bobsled, bobsleigh, bob',
+ 451: 'bolo tie, bolo, bola tie, bola',
+ 452: 'bonnet, poke bonnet',
+ 453: 'bookcase',
+ 454: 'bookshop, bookstore, bookstall',
+ 455: 'bottlecap',
+ 456: 'bow',
+ 457: 'bow tie, bow-tie, bowtie',
+ 458: 'brass, memorial tablet, plaque',
+ 459: 'brassiere, bra, bandeau',
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
+ 461: 'breastplate, aegis, egis',
+ 462: 'broom',
+ 463: 'bucket, pail',
+ 464: 'buckle',
+ 465: 'bulletproof vest',
+ 466: 'bullet train, bullet',
+ 467: 'butcher shop, meat market',
+ 468: 'cab, hack, taxi, taxicab',
+ 469: 'caldron, cauldron',
+ 470: 'candle, taper, wax light',
+ 471: 'cannon',
+ 472: 'canoe',
+ 473: 'can opener, tin opener',
+ 474: 'cardigan',
+ 475: 'car mirror',
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
+ 477: "carpenter's kit, tool kit",
+ 478: 'carton',
+ 479: 'car wheel',
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
+ 481: 'cassette',
+ 482: 'cassette player',
+ 483: 'castle',
+ 484: 'catamaran',
+ 485: 'CD player',
+ 486: 'cello, violoncello',
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
+ 488: 'chain',
+ 489: 'chainlink fence',
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
+ 491: 'chain saw, chainsaw',
+ 492: 'chest',
+ 493: 'chiffonier, commode',
+ 494: 'chime, bell, gong',
+ 495: 'china cabinet, china closet',
+ 496: 'Christmas stocking',
+ 497: 'church, church building',
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
+ 499: 'cleaver, meat cleaver, chopper',
+ 500: 'cliff dwelling',
+ 501: 'cloak',
+ 502: 'clog, geta, patten, sabot',
+ 503: 'cocktail shaker',
+ 504: 'coffee mug',
+ 505: 'coffeepot',
+ 506: 'coil, spiral, volute, whorl, helix',
+ 507: 'combination lock',
+ 508: 'computer keyboard, keypad',
+ 509: 'confectionery, confectionary, candy store',
+ 510: 'container ship, containership, container vessel',
+ 511: 'convertible',
+ 512: 'corkscrew, bottle screw',
+ 513: 'cornet, horn, trumpet, trump',
+ 514: 'cowboy boot',
+ 515: 'cowboy hat, ten-gallon hat',
+ 516: 'cradle',
+ 517: 'crane',
+ 518: 'crash helmet',
+ 519: 'crate',
+ 520: 'crib, cot',
+ 521: 'Crock Pot',
+ 522: 'croquet ball',
+ 523: 'crutch',
+ 524: 'cuirass',
+ 525: 'dam, dike, dyke',
+ 526: 'desk',
+ 527: 'desktop computer',
+ 528: 'dial telephone, dial phone',
+ 529: 'diaper, nappy, napkin',
+ 530: 'digital clock',
+ 531: 'digital watch',
+ 532: 'dining table, board',
+ 533: 'dishrag, dishcloth',
+ 534: 'dishwasher, dish washer, dishwashing machine',
+ 535: 'disk brake, disc brake',
+ 536: 'dock, dockage, docking facility',
+ 537: 'dogsled, dog sled, dog sleigh',
+ 538: 'dome',
+ 539: 'doormat, welcome mat',
+ 540: 'drilling platform, offshore rig',
+ 541: 'drum, membranophone, tympan',
+ 542: 'drumstick',
+ 543: 'dumbbell',
+ 544: 'Dutch oven',
+ 545: 'electric fan, blower',
+ 546: 'electric guitar',
+ 547: 'electric locomotive',
+ 548: 'entertainment center',
+ 549: 'envelope',
+ 550: 'espresso maker',
+ 551: 'face powder',
+ 552: 'feather boa, boa',
+ 553: 'file, file cabinet, filing cabinet',
+ 554: 'fireboat',
+ 555: 'fire engine, fire truck',
+ 556: 'fire screen, fireguard',
+ 557: 'flagpole, flagstaff',
+ 558: 'flute, transverse flute',
+ 559: 'folding chair',
+ 560: 'football helmet',
+ 561: 'forklift',
+ 562: 'fountain',
+ 563: 'fountain pen',
+ 564: 'four-poster',
+ 565: 'freight car',
+ 566: 'French horn, horn',
+ 567: 'frying pan, frypan, skillet',
+ 568: 'fur coat',
+ 569: 'garbage truck, dustcart',
+ 570: 'gasmask, respirator, gas helmet',
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
+ 572: 'goblet',
+ 573: 'go-kart',
+ 574: 'golf ball',
+ 575: 'golfcart, golf cart',
+ 576: 'gondola',
+ 577: 'gong, tam-tam',
+ 578: 'gown',
+ 579: 'grand piano, grand',
+ 580: 'greenhouse, nursery, glasshouse',
+ 581: 'grille, radiator grille',
+ 582: 'grocery store, grocery, food market, market',
+ 583: 'guillotine',
+ 584: 'hair slide',
+ 585: 'hair spray',
+ 586: 'half track',
+ 587: 'hammer',
+ 588: 'hamper',
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
+ 590: 'hand-held computer, hand-held microcomputer',
+ 591: 'handkerchief, hankie, hanky, hankey',
+ 592: 'hard disc, hard disk, fixed disk',
+ 593: 'harmonica, mouth organ, harp, mouth harp',
+ 594: 'harp',
+ 595: 'harvester, reaper',
+ 596: 'hatchet',
+ 597: 'holster',
+ 598: 'home theater, home theatre',
+ 599: 'honeycomb',
+ 600: 'hook, claw',
+ 601: 'hoopskirt, crinoline',
+ 602: 'horizontal bar, high bar',
+ 603: 'horse cart, horse-cart',
+ 604: 'hourglass',
+ 605: 'iPod',
+ 606: 'iron, smoothing iron',
+ 607: "jack-o'-lantern",
+ 608: 'jean, blue jean, denim',
+ 609: 'jeep, landrover',
+ 610: 'jersey, T-shirt, tee shirt',
+ 611: 'jigsaw puzzle',
+ 612: 'jinrikisha, ricksha, rickshaw',
+ 613: 'joystick',
+ 614: 'kimono',
+ 615: 'knee pad',
+ 616: 'knot',
+ 617: 'lab coat, laboratory coat',
+ 618: 'ladle',
+ 619: 'lampshade, lamp shade',
+ 620: 'laptop, laptop computer',
+ 621: 'lawn mower, mower',
+ 622: 'lens cap, lens cover',
+ 623: 'letter opener, paper knife, paperknife',
+ 624: 'library',
+ 625: 'lifeboat',
+ 626: 'lighter, light, igniter, ignitor',
+ 627: 'limousine, limo',
+ 628: 'liner, ocean liner',
+ 629: 'lipstick, lip rouge',
+ 630: 'Loafer',
+ 631: 'lotion',
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
+ 633: "loupe, jeweler's loupe",
+ 634: 'lumbermill, sawmill',
+ 635: 'magnetic compass',
+ 636: 'mailbag, postbag',
+ 637: 'mailbox, letter box',
+ 638: 'maillot',
+ 639: 'maillot, tank suit',
+ 640: 'manhole cover',
+ 641: 'maraca',
+ 642: 'marimba, xylophone',
+ 643: 'mask',
+ 644: 'matchstick',
+ 645: 'maypole',
+ 646: 'maze, labyrinth',
+ 647: 'measuring cup',
+ 648: 'medicine chest, medicine cabinet',
+ 649: 'megalith, megalithic structure',
+ 650: 'microphone, mike',
+ 651: 'microwave, microwave oven',
+ 652: 'military uniform',
+ 653: 'milk can',
+ 654: 'minibus',
+ 655: 'miniskirt, mini',
+ 656: 'minivan',
+ 657: 'missile',
+ 658: 'mitten',
+ 659: 'mixing bowl',
+ 660: 'mobile home, manufactured home',
+ 661: 'Model T',
+ 662: 'modem',
+ 663: 'monastery',
+ 664: 'monitor',
+ 665: 'moped',
+ 666: 'mortar',
+ 667: 'mortarboard',
+ 668: 'mosque',
+ 669: 'mosquito net',
+ 670: 'motor scooter, scooter',
+ 671: 'mountain bike, all-terrain bike, off-roader',
+ 672: 'mountain tent',
+ 673: 'mouse, computer mouse',
+ 674: 'mousetrap',
+ 675: 'moving van',
+ 676: 'muzzle',
+ 677: 'nail',
+ 678: 'neck brace',
+ 679: 'necklace',
+ 680: 'nipple',
+ 681: 'notebook, notebook computer',
+ 682: 'obelisk',
+ 683: 'oboe, hautboy, hautbois',
+ 684: 'ocarina, sweet potato',
+ 685: 'odometer, hodometer, mileometer, milometer',
+ 686: 'oil filter',
+ 687: 'organ, pipe organ',
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
+ 689: 'overskirt',
+ 690: 'oxcart',
+ 691: 'oxygen mask',
+ 692: 'packet',
+ 693: 'paddle, boat paddle',
+ 694: 'paddlewheel, paddle wheel',
+ 695: 'padlock',
+ 696: 'paintbrush',
+ 697: "pajama, pyjama, pj's, jammies",
+ 698: 'palace',
+ 699: 'panpipe, pandean pipe, syrinx',
+ 700: 'paper towel',
+ 701: 'parachute, chute',
+ 702: 'parallel bars, bars',
+ 703: 'park bench',
+ 704: 'parking meter',
+ 705: 'passenger car, coach, carriage',
+ 706: 'patio, terrace',
+ 707: 'pay-phone, pay-station',
+ 708: 'pedestal, plinth, footstall',
+ 709: 'pencil box, pencil case',
+ 710: 'pencil sharpener',
+ 711: 'perfume, essence',
+ 712: 'Petri dish',
+ 713: 'photocopier',
+ 714: 'pick, plectrum, plectron',
+ 715: 'pickelhaube',
+ 716: 'picket fence, paling',
+ 717: 'pickup, pickup truck',
+ 718: 'pier',
+ 719: 'piggy bank, penny bank',
+ 720: 'pill bottle',
+ 721: 'pillow',
+ 722: 'ping-pong ball',
+ 723: 'pinwheel',
+ 724: 'pirate, pirate ship',
+ 725: 'pitcher, ewer',
+ 726: "plane, carpenter's plane, woodworking plane",
+ 727: 'planetarium',
+ 728: 'plastic bag',
+ 729: 'plate rack',
+ 730: 'plow, plough',
+ 731: "plunger, plumber's helper",
+ 732: 'Polaroid camera, Polaroid Land camera',
+ 733: 'pole',
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
+ 735: 'poncho',
+ 736: 'pool table, billiard table, snooker table',
+ 737: 'pop bottle, soda bottle',
+ 738: 'pot, flowerpot',
+ 739: "potter's wheel",
+ 740: 'power drill',
+ 741: 'prayer rug, prayer mat',
+ 742: 'printer',
+ 743: 'prison, prison house',
+ 744: 'projectile, missile',
+ 745: 'projector',
+ 746: 'puck, hockey puck',
+ 747: 'punching bag, punch bag, punching ball, punchball',
+ 748: 'purse',
+ 749: 'quill, quill pen',
+ 750: 'quilt, comforter, comfort, puff',
+ 751: 'racer, race car, racing car',
+ 752: 'racket, racquet',
+ 753: 'radiator',
+ 754: 'radio, wireless',
+ 755: 'radio telescope, radio reflector',
+ 756: 'rain barrel',
+ 757: 'recreational vehicle, RV, R.V.',
+ 758: 'reel',
+ 759: 'reflex camera',
+ 760: 'refrigerator, icebox',
+ 761: 'remote control, remote',
+ 762: 'restaurant, eating house, eating place, eatery',
+ 763: 'revolver, six-gun, six-shooter',
+ 764: 'rifle',
+ 765: 'rocking chair, rocker',
+ 766: 'rotisserie',
+ 767: 'rubber eraser, rubber, pencil eraser',
+ 768: 'rugby ball',
+ 769: 'rule, ruler',
+ 770: 'running shoe',
+ 771: 'safe',
+ 772: 'safety pin',
+ 773: 'saltshaker, salt shaker',
+ 774: 'sandal',
+ 775: 'sarong',
+ 776: 'sax, saxophone',
+ 777: 'scabbard',
+ 778: 'scale, weighing machine',
+ 779: 'school bus',
+ 780: 'schooner',
+ 781: 'scoreboard',
+ 782: 'screen, CRT screen',
+ 783: 'screw',
+ 784: 'screwdriver',
+ 785: 'seat belt, seatbelt',
+ 786: 'sewing machine',
+ 787: 'shield, buckler',
+ 788: 'shoe shop, shoe-shop, shoe store',
+ 789: 'shoji',
+ 790: 'shopping basket',
+ 791: 'shopping cart',
+ 792: 'shovel',
+ 793: 'shower cap',
+ 794: 'shower curtain',
+ 795: 'ski',
+ 796: 'ski mask',
+ 797: 'sleeping bag',
+ 798: 'slide rule, slipstick',
+ 799: 'sliding door',
+ 800: 'slot, one-armed bandit',
+ 801: 'snorkel',
+ 802: 'snowmobile',
+ 803: 'snowplow, snowplough',
+ 804: 'soap dispenser',
+ 805: 'soccer ball',
+ 806: 'sock',
+ 807: 'solar dish, solar collector, solar furnace',
+ 808: 'sombrero',
+ 809: 'soup bowl',
+ 810: 'space bar',
+ 811: 'space heater',
+ 812: 'space shuttle',
+ 813: 'spatula',
+ 814: 'speedboat',
+ 815: "spider web, spider's web",
+ 816: 'spindle',
+ 817: 'sports car, sport car',
+ 818: 'spotlight, spot',
+ 819: 'stage',
+ 820: 'steam locomotive',
+ 821: 'steel arch bridge',
+ 822: 'steel drum',
+ 823: 'stethoscope',
+ 824: 'stole',
+ 825: 'stone wall',
+ 826: 'stopwatch, stop watch',
+ 827: 'stove',
+ 828: 'strainer',
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
+ 830: 'stretcher',
+ 831: 'studio couch, day bed',
+ 832: 'stupa, tope',
+ 833: 'submarine, pigboat, sub, U-boat',
+ 834: 'suit, suit of clothes',
+ 835: 'sundial',
+ 836: 'sunglass',
+ 837: 'sunglasses, dark glasses, shades',
+ 838: 'sunscreen, sunblock, sun blocker',
+ 839: 'suspension bridge',
+ 840: 'swab, swob, mop',
+ 841: 'sweatshirt',
+ 842: 'swimming trunks, bathing trunks',
+ 843: 'swing',
+ 844: 'switch, electric switch, electrical switch',
+ 845: 'syringe',
+ 846: 'table lamp',
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
+ 848: 'tape player',
+ 849: 'teapot',
+ 850: 'teddy, teddy bear',
+ 851: 'television, television system',
+ 852: 'tennis ball',
+ 853: 'thatch, thatched roof',
+ 854: 'theater curtain, theatre curtain',
+ 855: 'thimble',
+ 856: 'thresher, thrasher, threshing machine',
+ 857: 'throne',
+ 858: 'tile roof',
+ 859: 'toaster',
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
+ 861: 'toilet seat',
+ 862: 'torch',
+ 863: 'totem pole',
+ 864: 'tow truck, tow car, wrecker',
+ 865: 'toyshop',
+ 866: 'tractor',
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
+ 868: 'tray',
+ 869: 'trench coat',
+ 870: 'tricycle, trike, velocipede',
+ 871: 'trimaran',
+ 872: 'tripod',
+ 873: 'triumphal arch',
+ 874: 'trolleybus, trolley coach, trackless trolley',
+ 875: 'trombone',
+ 876: 'tub, vat',
+ 877: 'turnstile',
+ 878: 'typewriter keyboard',
+ 879: 'umbrella',
+ 880: 'unicycle, monocycle',
+ 881: 'upright, upright piano',
+ 882: 'vacuum, vacuum cleaner',
+ 883: 'vase',
+ 884: 'vault',
+ 885: 'velvet',
+ 886: 'vending machine',
+ 887: 'vestment',
+ 888: 'viaduct',
+ 889: 'violin, fiddle',
+ 890: 'volleyball',
+ 891: 'waffle iron',
+ 892: 'wall clock',
+ 893: 'wallet, billfold, notecase, pocketbook',
+ 894: 'wardrobe, closet, press',
+ 895: 'warplane, military plane',
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
+ 897: 'washer, automatic washer, washing machine',
+ 898: 'water bottle',
+ 899: 'water jug',
+ 900: 'water tower',
+ 901: 'whiskey jug',
+ 902: 'whistle',
+ 903: 'wig',
+ 904: 'window screen',
+ 905: 'window shade',
+ 906: 'Windsor tie',
+ 907: 'wine bottle',
+ 908: 'wing',
+ 909: 'wok',
+ 910: 'wooden spoon',
+ 911: 'wool, woolen, woollen',
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
+ 913: 'wreck',
+ 914: 'yawl',
+ 915: 'yurt',
+ 916: 'web site, website, internet site, site',
+ 917: 'comic book',
+ 918: 'crossword puzzle, crossword',
+ 919: 'street sign',
+ 920: 'traffic light, traffic signal, stoplight',
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
+ 922: 'menu',
+ 923: 'plate',
+ 924: 'guacamole',
+ 925: 'consomme',
+ 926: 'hot pot, hotpot',
+ 927: 'trifle',
+ 928: 'ice cream, icecream',
+ 929: 'ice lolly, lolly, lollipop, popsicle',
+ 930: 'French loaf',
+ 931: 'bagel, beigel',
+ 932: 'pretzel',
+ 933: 'cheeseburger',
+ 934: 'hotdog, hot dog, red hot',
+ 935: 'mashed potato',
+ 936: 'head cabbage',
+ 937: 'broccoli',
+ 938: 'cauliflower',
+ 939: 'zucchini, courgette',
+ 940: 'spaghetti squash',
+ 941: 'acorn squash',
+ 942: 'butternut squash',
+ 943: 'cucumber, cuke',
+ 944: 'artichoke, globe artichoke',
+ 945: 'bell pepper',
+ 946: 'cardoon',
+ 947: 'mushroom',
+ 948: 'Granny Smith',
+ 949: 'strawberry',
+ 950: 'orange',
+ 951: 'lemon',
+ 952: 'fig',
+ 953: 'pineapple, ananas',
+ 954: 'banana',
+ 955: 'jackfruit, jak, jack',
+ 956: 'custard apple',
+ 957: 'pomegranate',
+ 958: 'hay',
+ 959: 'carbonara',
+ 960: 'chocolate sauce, chocolate syrup',
+ 961: 'dough',
+ 962: 'meat loaf, meatloaf',
+ 963: 'pizza, pizza pie',
+ 964: 'potpie',
+ 965: 'burrito',
+ 966: 'red wine',
+ 967: 'espresso',
+ 968: 'cup',
+ 969: 'eggnog',
+ 970: 'alp',
+ 971: 'bubble',
+ 972: 'cliff, drop, drop-off',
+ 973: 'coral reef',
+ 974: 'geyser',
+ 975: 'lakeside, lakeshore',
+ 976: 'promontory, headland, head, foreland',
+ 977: 'sandbar, sand bar',
+ 978: 'seashore, coast, seacoast, sea-coast',
+ 979: 'valley, vale',
+ 980: 'volcano',
+ 981: 'ballplayer, baseball player',
+ 982: 'groom, bridegroom',
+ 983: 'scuba diver',
+ 984: 'rapeseed',
+ 985: 'daisy',
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
+ 987: 'corn',
+ 988: 'acorn',
+ 989: 'hip, rose hip, rosehip',
+ 990: 'buckeye, horse chestnut, conker',
+ 991: 'coral fungus',
+ 992: 'agaric',
+ 993: 'gyromitra',
+ 994: 'stinkhorn, carrion fungus',
+ 995: 'earthstar',
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
+ 997: 'bolete',
+ 998: 'ear, spike, capitulum',
+ 999: 'toilet tissue, toilet paper, bathroom tissue'
\ No newline at end of file
diff --git a/gligen/ldm/data/imagenet_train_hr_indices.p b/gligen/ldm/data/imagenet_train_hr_indices.p
new file mode 100644
index 0000000000000000000000000000000000000000..f55f631aa0c1ae1a805896d42f133bacd3f7139b
--- /dev/null
+++ b/gligen/ldm/data/imagenet_train_hr_indices.p
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f86ea1924a1522b20bc0f709a069cc65f09d5fc617a7a31af7aaa3839a5a4d73
+size 132
diff --git a/gligen/ldm/data/imagenet_val_hr_indices.p b/gligen/ldm/data/imagenet_val_hr_indices.p
new file mode 100644
index 0000000000000000000000000000000000000000..93e8f10adc6c89e445f6b3f7af9d5c7d2c0da3df
--- /dev/null
+++ b/gligen/ldm/data/imagenet_val_hr_indices.p
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff1f5eb275a93c0fb53e227679f323ea1d024c87db296453296cebeef86fc0f4
+size 131
diff --git a/gligen/ldm/data/index_synset.yaml b/gligen/ldm/data/index_synset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..635ea71a0da40d42072fee110143520daa203ce6
--- /dev/null
+++ b/gligen/ldm/data/index_synset.yaml
@@ -0,0 +1,1000 @@
+0: n01440764
+1: n01443537
+2: n01484850
+3: n01491361
+4: n01494475
+5: n01496331
+6: n01498041
+7: n01514668
+8: n07646067
+9: n01518878
+10: n01530575
+11: n01531178
+12: n01532829
+13: n01534433
+14: n01537544
+15: n01558993
+16: n01560419
+17: n01580077
+18: n01582220
+19: n01592084
+20: n01601694
+21: n13382471
+22: n01614925
+23: n01616318
+24: n01622779
+25: n01629819
+26: n01630670
+27: n01631663
+28: n01632458
+29: n01632777
+30: n01641577
+31: n01644373
+32: n01644900
+33: n01664065
+34: n01665541
+35: n01667114
+36: n01667778
+37: n01669191
+38: n01675722
+39: n01677366
+40: n01682714
+41: n01685808
+42: n01687978
+43: n01688243
+44: n01689811
+45: n01692333
+46: n01693334
+47: n01694178
+48: n01695060
+49: n01697457
+50: n01698640
+51: n01704323
+52: n01728572
+53: n01728920
+54: n01729322
+55: n01729977
+56: n01734418
+57: n01735189
+58: n01737021
+59: n01739381
+60: n01740131
+61: n01742172
+62: n01744401
+63: n01748264
+64: n01749939
+65: n01751748
+66: n01753488
+67: n01755581
+68: n01756291
+69: n01768244
+70: n01770081
+71: n01770393
+72: n01773157
+73: n01773549
+74: n01773797
+75: n01774384
+76: n01774750
+77: n01775062
+78: n04432308
+79: n01784675
+80: n01795545
+81: n01796340
+82: n01797886
+83: n01798484
+84: n01806143
+85: n07647321
+86: n07647496
+87: n01817953
+88: n01818515
+89: n01819313
+90: n01820546
+91: n01824575
+92: n01828970
+93: n01829413
+94: n01833805
+95: n01843065
+96: n01843383
+97: n01847000
+98: n01855032
+99: n07646821
+100: n01860187
+101: n01871265
+102: n01872772
+103: n01873310
+104: n01877812
+105: n01882714
+106: n01883070
+107: n01910747
+108: n01914609
+109: n01917289
+110: n01924916
+111: n01930112
+112: n01943899
+113: n01944390
+114: n13719102
+115: n01950731
+116: n01955084
+117: n01968897
+118: n01978287
+119: n01978455
+120: n01980166
+121: n01981276
+122: n01983481
+123: n01984695
+124: n01985128
+125: n01986214
+126: n01990800
+127: n02002556
+128: n02002724
+129: n02006656
+130: n02007558
+131: n02009229
+132: n02009912
+133: n02011460
+134: n03126707
+135: n02013706
+136: n02017213
+137: n02018207
+138: n02018795
+139: n02025239
+140: n02027492
+141: n02028035
+142: n02033041
+143: n02037110
+144: n02051845
+145: n02056570
+146: n02058221
+147: n02066245
+148: n02071294
+149: n02074367
+150: n02077923
+151: n08742578
+152: n02085782
+153: n02085936
+154: n02086079
+155: n02086240
+156: n02086646
+157: n02086910
+158: n02087046
+159: n02087394
+160: n02088094
+161: n02088238
+162: n02088364
+163: n02088466
+164: n02088632
+165: n02089078
+166: n02089867
+167: n02089973
+168: n02090379
+169: n02090622
+170: n02090721
+171: n02091032
+172: n02091134
+173: n02091244
+174: n02091467
+175: n02091635
+176: n02091831
+177: n02092002
+178: n02092339
+179: n02093256
+180: n02093428
+181: n02093647
+182: n02093754
+183: n02093859
+184: n02093991
+185: n02094114
+186: n02094258
+187: n02094433
+188: n02095314
+189: n02095570
+190: n02095889
+191: n02096051
+192: n02096177
+193: n02096294
+194: n02096437
+195: n02096585
+196: n02097047
+197: n02097130
+198: n02097209
+199: n02097298
+200: n02097474
+201: n02097658
+202: n02098105
+203: n02098286
+204: n02098413
+205: n02099267
+206: n02099429
+207: n02099601
+208: n02099712
+209: n02099849
+210: n02100236
+211: n02100583
+212: n02100735
+213: n02100877
+214: n02101006
+215: n02101388
+216: n02101556
+217: n02102040
+218: n02102177
+219: n02102318
+220: n02102480
+221: n02102973
+222: n02104029
+223: n02104365
+224: n02105056
+225: n02105162
+226: n02105251
+227: n02105412
+228: n02105505
+229: n02105641
+230: n02105855
+231: n02106030
+232: n02106166
+233: n02106382
+234: n02106550
+235: n02106662
+236: n02107142
+237: n02107312
+238: n02107574
+239: n02107683
+240: n02107908
+241: n02108000
+242: n02108089
+243: n02108422
+244: n02108551
+245: n02108915
+246: n02109047
+247: n02109525
+248: n02109961
+249: n02110063
+250: n02110185
+251: n02110341
+252: n02110627
+253: n02110806
+254: n02110958
+255: n02111129
+256: n02111277
+257: n02111500
+258: n02111889
+259: n02112018
+260: n02112137
+261: n02112350
+262: n02112706
+263: n02113023
+264: n02113186
+265: n02113624
+266: n02113712
+267: n02113799
+268: n02113978
+269: n02114367
+270: n02114548
+271: n02114712
+272: n02114855
+273: n02115641
+274: n02115913
+275: n02116738
+276: n02117135
+277: n02119022
+278: n02119789
+279: n02120079
+280: n02120505
+281: n02123045
+282: n02123159
+283: n02123394
+284: n02123597
+285: n02124075
+286: n02125311
+287: n02127052
+288: n02128385
+289: n02128757
+290: n02128925
+291: n02129165
+292: n02129604
+293: n02130308
+294: n02132136
+295: n02133161
+296: n02134084
+297: n02134418
+298: n02137549
+299: n02138441
+300: n02165105
+301: n02165456
+302: n02167151
+303: n02168699
+304: n02169497
+305: n02172182
+306: n02174001
+307: n02177972
+308: n03373237
+309: n07975909
+310: n02219486
+311: n02226429
+312: n02229544
+313: n02231487
+314: n02233338
+315: n02236044
+316: n02256656
+317: n02259212
+318: n02264363
+319: n02268443
+320: n02268853
+321: n02276258
+322: n02277742
+323: n02279972
+324: n02280649
+325: n02281406
+326: n02281787
+327: n02317335
+328: n02319095
+329: n02321529
+330: n02325366
+331: n02326432
+332: n02328150
+333: n02342885
+334: n02346627
+335: n02356798
+336: n02361337
+337: n05262120
+338: n02364673
+339: n02389026
+340: n02391049
+341: n02395406
+342: n02396427
+343: n02397096
+344: n02398521
+345: n02403003
+346: n02408429
+347: n02410509
+348: n02412080
+349: n02415577
+350: n02417914
+351: n02422106
+352: n02422699
+353: n02423022
+354: n02437312
+355: n02437616
+356: n10771990
+357: n14765497
+358: n02443114
+359: n02443484
+360: n14765785
+361: n02445715
+362: n02447366
+363: n02454379
+364: n02457408
+365: n02480495
+366: n02480855
+367: n02481823
+368: n02483362
+369: n02483708
+370: n02484975
+371: n02486261
+372: n02486410
+373: n02487347
+374: n02488291
+375: n02488702
+376: n02489166
+377: n02490219
+378: n02492035
+379: n02492660
+380: n02493509
+381: n02493793
+382: n02494079
+383: n02497673
+384: n02500267
+385: n02504013
+386: n02504458
+387: n02509815
+388: n02510455
+389: n02514041
+390: n07783967
+391: n02536864
+392: n02606052
+393: n02607072
+394: n02640242
+395: n02641379
+396: n02643566
+397: n02655020
+398: n02666347
+399: n02667093
+400: n02669723
+401: n02672831
+402: n02676566
+403: n02687172
+404: n02690373
+405: n02692877
+406: n02699494
+407: n02701002
+408: n02704792
+409: n02708093
+410: n02727426
+411: n08496334
+412: n02747177
+413: n02749479
+414: n02769748
+415: n02776631
+416: n02777292
+417: n02782329
+418: n02783161
+419: n02786058
+420: n02787622
+421: n02788148
+422: n02790996
+423: n02791124
+424: n02791270
+425: n02793495
+426: n02794156
+427: n02795169
+428: n02797295
+429: n02799071
+430: n02802426
+431: n02804515
+432: n02804610
+433: n02807133
+434: n02808304
+435: n02808440
+436: n02814533
+437: n02814860
+438: n02815834
+439: n02817516
+440: n02823428
+441: n02823750
+442: n02825657
+443: n02834397
+444: n02835271
+445: n02837789
+446: n02840245
+447: n02841315
+448: n02843684
+449: n02859443
+450: n02860847
+451: n02865351
+452: n02869837
+453: n02870880
+454: n02871525
+455: n02877765
+456: n02880308
+457: n02883205
+458: n02892201
+459: n02892767
+460: n02894605
+461: n02895154
+462: n12520864
+463: n02909870
+464: n02910353
+465: n02916936
+466: n02917067
+467: n02927161
+468: n02930766
+469: n02939185
+470: n02948072
+471: n02950826
+472: n02951358
+473: n02951585
+474: n02963159
+475: n02965783
+476: n02966193
+477: n02966687
+478: n02971356
+479: n02974003
+480: n02977058
+481: n02978881
+482: n02979186
+483: n02980441
+484: n02981792
+485: n02988304
+486: n02992211
+487: n02992529
+488: n13652994
+489: n03000134
+490: n03000247
+491: n03000684
+492: n03014705
+493: n03016953
+494: n03017168
+495: n03018349
+496: n03026506
+497: n03028079
+498: n03032252
+499: n03041632
+500: n03042490
+501: n03045698
+502: n03047690
+503: n03062245
+504: n03063599
+505: n03063689
+506: n03065424
+507: n03075370
+508: n03085013
+509: n03089624
+510: n03095699
+511: n03100240
+512: n03109150
+513: n03110669
+514: n03124043
+515: n03124170
+516: n15142452
+517: n03126707
+518: n03127747
+519: n03127925
+520: n03131574
+521: n03133878
+522: n03134739
+523: n03141823
+524: n03146219
+525: n03160309
+526: n03179701
+527: n03180011
+528: n03187595
+529: n03188531
+530: n03196217
+531: n03197337
+532: n03201208
+533: n03207743
+534: n03207941
+535: n03208938
+536: n03216828
+537: n03218198
+538: n13872072
+539: n03223299
+540: n03240683
+541: n03249569
+542: n07647870
+543: n03255030
+544: n03259401
+545: n03271574
+546: n03272010
+547: n03272562
+548: n03290653
+549: n13869788
+550: n03297495
+551: n03314780
+552: n03325584
+553: n03337140
+554: n03344393
+555: n03345487
+556: n03347037
+557: n03355925
+558: n03372029
+559: n03376595
+560: n03379051
+561: n03384352
+562: n03388043
+563: n03388183
+564: n03388549
+565: n03393912
+566: n03394916
+567: n03400231
+568: n03404251
+569: n03417042
+570: n03424325
+571: n03425413
+572: n03443371
+573: n03444034
+574: n03445777
+575: n03445924
+576: n03447447
+577: n03447721
+578: n08286342
+579: n03452741
+580: n03457902
+581: n03459775
+582: n03461385
+583: n03467068
+584: n03476684
+585: n03476991
+586: n03478589
+587: n03482001
+588: n03482405
+589: n03483316
+590: n03485407
+591: n03485794
+592: n03492542
+593: n03494278
+594: n03495570
+595: n10161363
+596: n03498962
+597: n03527565
+598: n03529860
+599: n09218315
+600: n03532672
+601: n03534580
+602: n03535780
+603: n03538406
+604: n03544143
+605: n03584254
+606: n03584829
+607: n03590841
+608: n03594734
+609: n03594945
+610: n03595614
+611: n03598930
+612: n03599486
+613: n03602883
+614: n03617480
+615: n03623198
+616: n15102712
+617: n03630383
+618: n03633091
+619: n03637318
+620: n03642806
+621: n03649909
+622: n03657121
+623: n03658185
+624: n07977870
+625: n03662601
+626: n03666591
+627: n03670208
+628: n03673027
+629: n03676483
+630: n03680355
+631: n03690938
+632: n03691459
+633: n03692522
+634: n03697007
+635: n03706229
+636: n03709823
+637: n03710193
+638: n03710637
+639: n03710721
+640: n03717622
+641: n03720891
+642: n03721384
+643: n03725035
+644: n03729826
+645: n03733131
+646: n03733281
+647: n03733805
+648: n03742115
+649: n03743016
+650: n03759954
+651: n03761084
+652: n03763968
+653: n03764736
+654: n03769881
+655: n03770439
+656: n03770679
+657: n03773504
+658: n03775071
+659: n03775546
+660: n03776460
+661: n03777568
+662: n03777754
+663: n03781244
+664: n03782006
+665: n03785016
+666: n14955889
+667: n03787032
+668: n03788195
+669: n03788365
+670: n03791053
+671: n03792782
+672: n03792972
+673: n03793489
+674: n03794056
+675: n03796401
+676: n03803284
+677: n13652335
+678: n03814639
+679: n03814906
+680: n03825788
+681: n03832673
+682: n03837869
+683: n03838899
+684: n03840681
+685: n03841143
+686: n03843555
+687: n03854065
+688: n03857828
+689: n03866082
+690: n03868242
+691: n03868863
+692: n07281099
+693: n03873416
+694: n03874293
+695: n03874599
+696: n03876231
+697: n03877472
+698: n08053121
+699: n03884397
+700: n03887697
+701: n03888257
+702: n03888605
+703: n03891251
+704: n03891332
+705: n03895866
+706: n03899768
+707: n03902125
+708: n03903868
+709: n03908618
+710: n03908714
+711: n03916031
+712: n03920288
+713: n03924679
+714: n03929660
+715: n03929855
+716: n03930313
+717: n03930630
+718: n03934042
+719: n03935335
+720: n03937543
+721: n03938244
+722: n03942813
+723: n03944341
+724: n03947888
+725: n03950228
+726: n03954731
+727: n03956157
+728: n03958227
+729: n03961711
+730: n03967562
+731: n03970156
+732: n03976467
+733: n08620881
+734: n03977966
+735: n03980874
+736: n03982430
+737: n03983396
+738: n03991062
+739: n03992509
+740: n03995372
+741: n03998194
+742: n04004767
+743: n13937284
+744: n04008634
+745: n04009801
+746: n04019541
+747: n04023962
+748: n13413294
+749: n04033901
+750: n04033995
+751: n04037443
+752: n04039381
+753: n09403211
+754: n04041544
+755: n04044716
+756: n04049303
+757: n04065272
+758: n07056680
+759: n04069434
+760: n04070727
+761: n04074963
+762: n04081281
+763: n04086273
+764: n04090263
+765: n04099969
+766: n04111531
+767: n04116512
+768: n04118538
+769: n04118776
+770: n04120489
+771: n04125116
+772: n04127249
+773: n04131690
+774: n04133789
+775: n04136333
+776: n04141076
+777: n04141327
+778: n04141975
+779: n04146614
+780: n04147291
+781: n04149813
+782: n04152593
+783: n04154340
+784: n07917272
+785: n04162706
+786: n04179913
+787: n04192698
+788: n04200800
+789: n04201297
+790: n04204238
+791: n04204347
+792: n04208427
+793: n04209133
+794: n04209239
+795: n04228054
+796: n04229816
+797: n04235860
+798: n04238763
+799: n04239074
+800: n04243546
+801: n04251144
+802: n04252077
+803: n04252225
+804: n04254120
+805: n04254680
+806: n04254777
+807: n04258138
+808: n04259630
+809: n04263257
+810: n04264628
+811: n04265275
+812: n04266014
+813: n04270147
+814: n04273569
+815: n04275363
+816: n05605498
+817: n04285008
+818: n04286575
+819: n08646566
+820: n04310018
+821: n04311004
+822: n04311174
+823: n04317175
+824: n04325704
+825: n04326547
+826: n04328186
+827: n04330267
+828: n04332243
+829: n04335435
+830: n04337157
+831: n04344873
+832: n04346328
+833: n04347754
+834: n04350905
+835: n04355338
+836: n04355933
+837: n04356056
+838: n04357314
+839: n04366367
+840: n04367480
+841: n04370456
+842: n04371430
+843: n14009946
+844: n04372370
+845: n04376876
+846: n04380533
+847: n04389033
+848: n04392985
+849: n04398044
+850: n04399382
+851: n04404412
+852: n04409515
+853: n04417672
+854: n04418357
+855: n04423845
+856: n04428191
+857: n04429376
+858: n04435653
+859: n04442312
+860: n04443257
+861: n04447861
+862: n04456115
+863: n04458633
+864: n04461696
+865: n04462240
+866: n04465666
+867: n04467665
+868: n04476259
+869: n04479046
+870: n04482393
+871: n04483307
+872: n04485082
+873: n04486054
+874: n04487081
+875: n04487394
+876: n04493381
+877: n04501370
+878: n04505470
+879: n04507155
+880: n04509417
+881: n04515003
+882: n04517823
+883: n04522168
+884: n04523525
+885: n04525038
+886: n04525305
+887: n04532106
+888: n04532670
+889: n04536866
+890: n04540053
+891: n04542943
+892: n04548280
+893: n04548362
+894: n04550184
+895: n04552348
+896: n04553703
+897: n04554684
+898: n04557648
+899: n04560804
+900: n04562935
+901: n04579145
+902: n04579667
+903: n04584207
+904: n04589890
+905: n04590129
+906: n04591157
+907: n04591713
+908: n10782135
+909: n04596742
+910: n04598010
+911: n04599235
+912: n04604644
+913: n14423870
+914: n04612504
+915: n04613696
+916: n06359193
+917: n06596364
+918: n06785654
+919: n06794110
+920: n06874185
+921: n07248320
+922: n07565083
+923: n07657664
+924: n07583066
+925: n07584110
+926: n07590611
+927: n07613480
+928: n07614500
+929: n07615774
+930: n07684084
+931: n07693725
+932: n07695742
+933: n07697313
+934: n07697537
+935: n07711569
+936: n07714571
+937: n07714990
+938: n07715103
+939: n12159804
+940: n12160303
+941: n12160857
+942: n07717556
+943: n07718472
+944: n07718747
+945: n07720875
+946: n07730033
+947: n13001041
+948: n07742313
+949: n12630144
+950: n14991210
+951: n07749582
+952: n07753113
+953: n07753275
+954: n07753592
+955: n07754684
+956: n07760859
+957: n07768694
+958: n07802026
+959: n07831146
+960: n07836838
+961: n07860988
+962: n07871810
+963: n07873807
+964: n07875152
+965: n07880968
+966: n07892512
+967: n07920052
+968: n13904665
+969: n07932039
+970: n09193705
+971: n09229709
+972: n09246464
+973: n09256479
+974: n09288635
+975: n09332890
+976: n09399592
+977: n09421951
+978: n09428293
+979: n09468604
+980: n09472597
+981: n09835506
+982: n10148035
+983: n10565667
+984: n11879895
+985: n11939491
+986: n12057211
+987: n12144580
+988: n12267677
+989: n12620546
+990: n12768682
+991: n12985857
+992: n12998815
+993: n13037406
+994: n13040303
+995: n13044778
+996: n13052670
+997: n13054560
+998: n13133613
+999: n15075141
diff --git a/gligen/ldm/data/lsun.py b/gligen/ldm/data/lsun.py
new file mode 100644
index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e
--- /dev/null
+++ b/gligen/ldm/data/lsun.py
@@ -0,0 +1,92 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+class LSUNBase(Dataset):
+ def __init__(self,
+ txt_file,
+ data_root,
+ size=None,
+ interpolation="bicubic",
+ flip_p=0.5
+ ):
+ self.data_paths = txt_file
+ self.data_root = data_root
+ with open(self.data_paths, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ }
+
+ self.size = size
+ self.interpolation = {"linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ }[interpolation]
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+ crop = min(img.shape[0], img.shape[1])
+ h, w = img.shape[0], img.shape[1]
+ img = img[(h - crop) // 2:(h + crop) // 2,
+ (w - crop) // 2:(w + crop) // 2]
+
+ image = Image.fromarray(img)
+ if self.size is not None:
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip(image)
+ image = np.array(image).astype(np.uint8)
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+ return example
+
+
+class LSUNChurchesTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
+
+
+class LSUNChurchesValidation(LSUNBase):
+ def __init__(self, flip_p=0., **kwargs):
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
+ flip_p=flip_p, **kwargs)
+
+
+class LSUNBedroomsTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
+
+
+class LSUNBedroomsValidation(LSUNBase):
+ def __init__(self, flip_p=0.0, **kwargs):
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
+ flip_p=flip_p, **kwargs)
+
+
+class LSUNCatsTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
+
+
+class LSUNCatsValidation(LSUNBase):
+ def __init__(self, flip_p=0., **kwargs):
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
+ flip_p=flip_p, **kwargs)
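+
+
+if __name__ == "__main__":
+    # Editor's illustrative sketch, not part of the original GLIGEN file: build a
+    # tiny throw-away dataset to show the sample format LSUNBase returns. The paths
+    # below are temporary placeholders; the real subclasses above expect the LSUN
+    # txt lists and images under data/lsun/.
+    import tempfile
+
+    tmp_dir = tempfile.mkdtemp()
+    Image.new("RGB", (64, 48), color=(255, 0, 0)).save(os.path.join(tmp_dir, "example.jpg"))
+    with open(os.path.join(tmp_dir, "files.txt"), "w") as f:
+        f.write("example.jpg\n")
+
+    dataset = LSUNBase(txt_file=os.path.join(tmp_dir, "files.txt"), data_root=tmp_dir, size=32)
+    sample = dataset[0]
+    # "image" is an HWC float32 array rescaled to [-1, 1]; here its shape is (32, 32, 3).
+    print(sample["image"].shape, sample["image"].dtype)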
diff --git a/gligen/ldm/lr_scheduler.py b/gligen/ldm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade
--- /dev/null
+++ b/gligen/ldm/lr_scheduler.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n,**kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
+
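+if __name__ == "__main__":
+    # Editor's illustrative sketch, not part of the original GLIGEN file: these
+    # schedulers return learning-rate *multipliers* (hence the "use with a base_lr
+    # of 1.0" note above), so they are meant to be wrapped in
+    # torch.optim.lr_scheduler.LambdaLR. The model, optimizer and step counts below
+    # are placeholders.
+    import torch
+
+    model = torch.nn.Linear(8, 8)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
+    schedule_fn = LambdaWarmUpCosineScheduler(
+        warm_up_steps=100, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=1000
+    )
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule_fn)
+
+    for _ in range(200):
+        optimizer.step()
+        scheduler.step()
+    # The multiplier warms up linearly for 100 steps, then follows a cosine decay to lr_min.
+    print(schedule_fn.last_lr)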
diff --git a/gligen/ldm/models/autoencoder.py b/gligen/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1163e72dd063ee6773fe3e3c586c43b0663da4c9
--- /dev/null
+++ b/gligen/ldm/models/autoencoder.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+#import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+# from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+
+
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(self,
+ ddconfig,
+ embed_dim,
+ scale_factor=1
+ ):
+ super().__init__()
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ self.scale_factor = scale_factor
+
+
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior.sample() * self.scale_factor
+
+ def decode(self, z):
+ z = 1. / self.scale_factor * z
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+
+
+
+
+
+
+
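+if __name__ == "__main__":
+    # Editor's illustrative sketch, not part of the original GLIGEN file: a smoke
+    # test of the encode/decode round trip. The ddconfig keys follow the usual LDM
+    # Encoder/Decoder configuration; the small values below are placeholders chosen
+    # to keep the example light, not the settings GLIGEN is trained with.
+    ddconfig = dict(
+        double_z=True, z_channels=4, resolution=64, in_channels=3, out_ch=3,
+        ch=32, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[], dropout=0.0,
+    )
+    ae = AutoencoderKL(ddconfig, embed_dim=4, scale_factor=0.18215)
+    x = torch.randn(1, 3, 64, 64)
+    z = ae.encode(x)      # sampled latent, already multiplied by scale_factor
+    rec = ae.decode(z)    # reconstruction back in image space
+    print(z.shape, rec.shape)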
diff --git a/gligen/ldm/models/diffusion/__init__.py b/gligen/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/gligen/ldm/models/diffusion/classifier.py b/gligen/ldm/models/diffusion/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef
--- /dev/null
+++ b/gligen/ldm/models/diffusion/classifier.py
@@ -0,0 +1,267 @@
+import os
+import torch
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LambdaLR
+from copy import deepcopy
+from einops import rearrange
+from glob import glob
+from natsort import natsorted
+
+from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
+from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+
+__models__ = {
+ 'class_label': EncoderUNetModel,
+ 'segmentation': UNetModel
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class NoisyLatentImageClassifier(pl.LightningModule):
+
+ def __init__(self,
+ diffusion_path,
+ num_classes,
+ ckpt_path=None,
+ pool='attention',
+ label_key=None,
+ diffusion_ckpt_path=None,
+ scheduler_config=None,
+ weight_decay=1.e-2,
+ log_steps=10,
+ monitor='val/loss',
+ *args,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.num_classes = num_classes
+ # get latest config of diffusion model
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
+ self.load_diffusion()
+
+ self.monitor = monitor
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
+ self.log_steps = log_steps
+
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
+ else self.diffusion_model.cond_stage_key
+
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
+
+ if self.label_key not in __models__:
+ raise NotImplementedError()
+
+ self.load_classifier(ckpt_path, pool)
+
+ self.scheduler_config = scheduler_config
+ self.use_scheduler = self.scheduler_config is not None
+ self.weight_decay = weight_decay
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def load_diffusion(self):
+ model = instantiate_from_config(self.diffusion_config)
+ self.diffusion_model = model.eval()
+ self.diffusion_model.train = disabled_train
+ for param in self.diffusion_model.parameters():
+ param.requires_grad = False
+
+ def load_classifier(self, ckpt_path, pool):
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
+ model_config.out_channels = self.num_classes
+ if self.label_key == 'class_label':
+ model_config.pool = pool
+
+ self.model = __models__[self.label_key](**model_config)
+ if ckpt_path is not None:
+ print('#####################################################################')
+ print(f'load from ckpt "{ckpt_path}"')
+ print('#####################################################################')
+ self.init_from_ckpt(ckpt_path)
+
+ @torch.no_grad()
+ def get_x_noisy(self, x, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x))
+ continuous_sqrt_alpha_cumprod = None
+ if self.diffusion_model.use_continuous_noise:
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
+ # todo: make sure t+1 is correct here
+
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
+
+ def forward(self, x_noisy, t, *args, **kwargs):
+ return self.model(x_noisy, t)
+
+ @torch.no_grad()
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ @torch.no_grad()
+ def get_conditioning(self, batch, k=None):
+ if k is None:
+ k = self.label_key
+ assert k is not None, 'Needs to provide label key'
+
+ targets = batch[k].to(self.device)
+
+ if self.label_key == 'segmentation':
+ targets = rearrange(targets, 'b h w c -> b c h w')
+ for down in range(self.numd):
+ h, w = targets.shape[-2:]
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
+
+ # targets = rearrange(targets,'b c h w -> b h w c')
+
+ return targets
+
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
+ _, top_ks = torch.topk(logits, k, dim=1)
+ if reduction == "mean":
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
+ elif reduction == "none":
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
+
+ def on_train_epoch_start(self):
+ # save some memory
+ self.diffusion_model.model.to('cpu')
+
+ @torch.no_grad()
+ def write_logs(self, loss, logits, targets):
+ log_prefix = 'train' if self.training else 'val'
+ log = {}
+ log[f"{log_prefix}/loss"] = loss.mean()
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
+ logits, targets, k=1, reduction="mean"
+ )
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
+ logits, targets, k=5, reduction="mean"
+ )
+
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
+
+ def shared_step(self, batch, t=None):
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
+ targets = self.get_conditioning(batch)
+ if targets.dim() == 4:
+ targets = targets.argmax(dim=1)
+ if t is None:
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
+ else:
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
+ x_noisy = self.get_x_noisy(x, t)
+ logits = self(x_noisy, t)
+
+ loss = F.cross_entropy(logits, targets, reduction='none')
+
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
+
+ loss = loss.mean()
+ return loss, logits, x_noisy, targets
+
+ def training_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+ return loss
+
+ def reset_noise_accs(self):
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
+
+ def on_validation_start(self):
+ self.reset_noise_accs()
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+
+ for t in self.noisy_acc:
+ _, logits, _, targets = self.shared_step(batch, t)
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
+
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
+
+ if self.use_scheduler:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [optimizer], scheduler
+
+ return optimizer
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, *args, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
+ log['inputs'] = x
+
+ y = self.get_conditioning(batch)
+
+ if self.label_key == 'class_label':
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['labels'] = y
+
+ if ismap(y):
+ log['labels'] = self.diffusion_model.to_rgb(y)
+
+ for step in range(self.log_steps):
+ current_time = step * self.log_time_interval
+
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
+
+ log[f'inputs@t{current_time}'] = x_noisy
+
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
+ pred = rearrange(pred, 'b h w c -> b c h w')
+
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
+
+ for key in log:
+ log[key] = log[key][:N]
+
+ return log
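+
+
+# Editor's note (illustrative, not part of the original file): compute_top_k above
+# reports the fraction of samples whose true label appears among the k largest
+# logits. For example, with
+#     logits = torch.tensor([[0.1, 0.7, 0.2],
+#                            [0.6, 0.3, 0.1]])
+#     labels = torch.tensor([1, 2])
+# the top-1 predictions are classes [1, 0], so acc@1 = 0.5; the top-2 sets are
+# {1, 2} and {0, 1}, so acc@2 is still 0.5, and acc@3 = 1.0.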
diff --git a/gligen/ldm/models/diffusion/ddim.py b/gligen/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..7db86661e94319b54bec15bf521097bb7b7faf87
--- /dev/null
+++ b/gligen/ldm/models/diffusion/ddim.py
@@ -0,0 +1,134 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+
+
+class DDIMSampler(object):
+ def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
+ super().__init__()
+ self.diffusion = diffusion
+ self.model = model
+ self.device = diffusion.betas.device
+ self.ddpm_num_timesteps = diffusion.num_timesteps
+ self.schedule = schedule
+ self.alpha_generator_func = alpha_generator_func
+ self.set_alpha_scale = set_alpha_scale
+
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ attr = attr.to(self.device)
+ setattr(self, name, attr)
+
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=False)
+ alphas_cumprod = self.diffusion.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
+
+ self.register_buffer('betas', to_torch(self.diffusion.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=False)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+
+ @torch.no_grad()
+ def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None):
+ self.make_schedule(ddim_num_steps=S)
+ return self.ddim_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0)
+
+
+ @torch.no_grad()
+ def ddim_sampling(self, shape, input, uc, guidance_scale=1, mask=None, x0=None):
+ b = shape[0]
+
+ img = input["x"]
+ if img is None:
+ img = torch.randn(shape, device=self.device)
+ input["x"] = img
+
+
+ time_range = np.flip(self.ddim_timesteps)
+ total_steps = self.ddim_timesteps.shape[0]
+
+ #iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+ iterator = time_range
+
+ if self.alpha_generator_func is not None:
+ alphas = self.alpha_generator_func(len(iterator))
+
+
+ for i, step in enumerate(iterator):
+
+ # set alpha
+ if self.alpha_generator_func is not None:
+ self.set_alpha_scale(self.model, alphas[i])
+ if alphas[i] == 0:
+ self.model.restore_first_conv_from_SD()
+
+ # run
+ index = total_steps - i - 1
+ input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.diffusion.q_sample( x0, input["timesteps"] )
+ img = img_orig * mask + (1. - mask) * img
+ input["x"] = img
+
+ img, pred_x0 = self.p_sample_ddim(input, index=index, uc=uc, guidance_scale=guidance_scale)
+ input["x"] = img
+
+ return img
+
+
+ @torch.no_grad()
+ def p_sample_ddim(self, input, index, uc=None, guidance_scale=1):
+
+
+ e_t = self.model(input)
+ if uc is not None and guidance_scale != 1:
+ unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input'])
+ e_t_uncond = self.model( unconditional_input )
+ e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
+
+ # select parameters corresponding to the currently considered timestep
+ b = input["x"].shape[0]
+ a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
+ a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
+ sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
+
+ # current prediction for x_0
+ pred_x0 = (input["x"] - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * torch.randn_like( input["x"] )
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+ return x_prev, pred_x0
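+
+
+# Editor's note (illustrative sketch, not part of the original file): sample() drives
+# the loop above from a dict prepared by the caller. Judging from the keys referenced
+# in ddim_sampling/p_sample_ddim, that dict looks roughly like
+#     input = {
+#         "x": None or a (B, C, H, W) start latent,    # filled with noise when None
+#         "timesteps": None,                            # overwritten at every step
+#         "context": conditioning (e.g. text/grounding tokens),
+#         "inpainting_extra_input": ...,                # task-specific extras
+#         "grounding_extra_input": ...,
+#     }
+# and a typical call would be
+#     sampler = DDIMSampler(diffusion, model, alpha_generator_func=..., set_alpha_scale=...)
+#     latent = sampler.sample(S=50, shape=(B, C, H, W), input=input, uc=uc, guidance_scale=7.5)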
diff --git a/gligen/ldm/models/diffusion/ddpm.py b/gligen/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3feeabf55dbc0cf6fd112195bcebd7fddbec41
--- /dev/null
+++ b/gligen/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+from ldm.modules.diffusionmodules.util import make_beta_schedule
+
+
+
+
+
+class DDPM(nn.Module):
+ def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().__init__()
+
+ self.v_posterior = 0
+ self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
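+# Editor's note (not part of the original file): the buffers registered above are the
+# standard DDPM quantities. Writing alpha_t = 1 - beta_t and alpha_bar_t for the
+# cumulative product, they implement
+#     q(x_t | x_0)          = N( sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I )
+#     q(x_{t-1} | x_t, x_0) = N( coef1 * x_0 + coef2 * x_t, posterior_variance_t * I )
+# with coef1 = beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t) and
+#      coef2 = (1 - alpha_bar_{t-1}) * sqrt(alpha_t) / (1 - alpha_bar_t),
+# which is exactly what posterior_mean_coef1 / posterior_mean_coef2 store.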
diff --git a/gligen/ldm/models/diffusion/gaussian_smoothing.py b/gligen/ldm/models/diffusion/gaussian_smoothing.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec81e48b935ae1d3111f2c71d8d9c430bf8c19c
--- /dev/null
+++ b/gligen/ldm/models/diffusion/gaussian_smoothing.py
@@ -0,0 +1,119 @@
+import math
+import numbers
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class GaussianSmoothing(nn.Module):
+ """
+ Apply gaussian smoothing on a
+ 1d, 2d or 3d tensor. Filtering is performed separately for each channel
+ in the input using a depthwise convolution.
+ Arguments:
+ channels (int, sequence): Number of channels of the input tensors. Output will
+ have this number of channels as well.
+ kernel_size (int, sequence): Size of the gaussian kernel.
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
+ dim (int, optional): The number of dimensions of the data.
+ Default value is 2 (spatial).
+ """
+ def __init__(self, channels, kernel_size, sigma, dim=2):
+ super(GaussianSmoothing, self).__init__()
+ if isinstance(kernel_size, numbers.Number):
+ kernel_size = [kernel_size] * dim
+ if isinstance(sigma, numbers.Number):
+ sigma = [sigma] * dim
+
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
+ torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+
+ if dim == 1:
+ self.conv = F.conv1d
+ elif dim == 2:
+ self.conv = F.conv2d
+ elif dim == 3:
+ self.conv = F.conv3d
+ else:
+ raise RuntimeError(
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
+ )
+
+ def forward(self, input):
+ """
+ Apply gaussian filter to input.
+ Arguments:
+ input (torch.Tensor): Input to apply gaussian filter on.
+ Returns:
+ filtered (torch.Tensor): Filtered output.
+ """
+ return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
+
+
+class AverageSmoothing(nn.Module):
+ """
+ Apply average smoothing on a
+ 1d, 2d or 3d tensor. Filtering is performed separately for each channel
+ in the input using a depthwise convolution.
+ Arguments:
+ channels (int, sequence): Number of channels of the input tensors. Output will
+ have this number of channels as well.
+ kernel_size (int, sequence): Size of the average kernel.
+ dim (int, optional): The number of dimensions of the data.
+ Default value is 2 (spatial).
+ """
+ def __init__(self, channels, kernel_size, dim=2):
+ super(AverageSmoothing, self).__init__()
+
+ # Build a uniform averaging kernel whose values sum to 1.
+ kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size)
+
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+
+ if dim == 1:
+ self.conv = F.conv1d
+ elif dim == 2:
+ self.conv = F.conv2d
+ elif dim == 3:
+ self.conv = F.conv3d
+ else:
+ raise RuntimeError(
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
+ )
+
+ def forward(self, input):
+ """
+ Apply average filter to input.
+ Arguments:
+ input (torch.Tensor): Input to apply average filter on.
+ Returns:
+ filtered (torch.Tensor): Filtered output.
+ """
+ return self.conv(input, weight=self.weight, groups=self.groups)
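+
+
+if __name__ == "__main__":
+    # Editor's illustrative sketch, not part of the original GLIGEN file: both
+    # filters use 'valid' convolutions, so callers pad first if they want to keep
+    # the spatial size (loss.py pads attention maps with mode='reflect' before
+    # smoothing). Shapes and hyper-parameters below are placeholders.
+    smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)
+    attn = torch.rand(1, 1, 16, 16)
+    attn = F.pad(attn, (1, 1, 1, 1), mode='reflect')   # 16x16 -> 18x18
+    smoothed = smoothing(attn)                         # back to (1, 1, 16, 16)
+    print(smoothed.shape)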
diff --git a/gligen/ldm/models/diffusion/ldm.py b/gligen/ldm/models/diffusion/ldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..78fa65862d848a3fa49ff8c2b7bc475067175891
--- /dev/null
+++ b/gligen/ldm/models/diffusion/ldm.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from tqdm import tqdm
+from ldm.util import default
+from ldm.modules.diffusionmodules.util import extract_into_tensor
+from .ddpm import DDPM
+
+
+
+class LatentDiffusion(DDPM):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # hardcoded
+ self.clip_denoised = False
+
+
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+
+ "Does not support DDPM sampling anymore. Only do DDIM or PLMS"
+
+ # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = #
+
+ # def predict_start_from_noise(self, x_t, t, noise):
+ # return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ # extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise )
+
+ # def q_posterior(self, x_start, x_t, t):
+ # posterior_mean = (
+ # extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ # extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ # )
+ # posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ # posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ # return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+
+ # def p_mean_variance(self, model, x, c, t):
+
+ # model_out = model(x, t, c)
+ # x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+
+ # if self.clip_denoised:
+ # x_recon.clamp_(-1., 1.)
+
+ # model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ # return model_mean, posterior_variance, posterior_log_variance, x_recon
+
+
+ # @torch.no_grad()
+ # def p_sample(self, model, x, c, t):
+ # b, *_, device = *x.shape, x.device
+ # model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, )
+ # noise = torch.randn_like(x)
+
+ # # no noise when t == 0
+ # nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+
+
+ # @torch.no_grad()
+ # def p_sample_loop(self, model, shape, c):
+ # device = self.betas.device
+ # b = shape[0]
+ # img = torch.randn(shape, device=device)
+
+ # iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps)
+ # for i in iterator:
+ # ts = torch.full((b,), i, device=device, dtype=torch.long)
+ # img, x0 = self.p_sample(model, img, c, ts)
+
+ # return img
+
+
+ # @torch.no_grad()
+ # def sample(self, model, shape, c, uc=None, guidance_scale=None):
+ # return self.p_sample_loop(model, shape, c)
+
+
+
+
+
diff --git a/gligen/ldm/models/diffusion/loss.py b/gligen/ldm/models/diffusion/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e199219de4b6c3792d56a03e8d71450416151c
--- /dev/null
+++ b/gligen/ldm/models/diffusion/loss.py
@@ -0,0 +1,170 @@
+import math
+import torch
+from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing
+from torch.nn import functional as F
+from torchvision.utils import save_image
+
+
+
+
+
+
+def loss_one_att_outside(attn_map, bboxes, object_positions, t):
+ # loss = torch.tensor(0).to('cuda')
+ loss = 0
+ object_number = len(bboxes)
+ b, i, j = attn_map.shape
+ H = W = int(math.sqrt(i))
+
+
+ # if t== 20: import pdb; pdb.set_trace()
+
+ for obj_idx in range(object_number):
+
+ for obj_box in bboxes[obj_idx]:
+ mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W))
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
+ mask[y_min: y_max, x_min: x_max] = 1.
+ mask_out = 1. - mask
+ index = (mask == 1.).nonzero(as_tuple=False)
+ index_in_key = index[:,0]* H + index[:, 1]
+ att_box = torch.zeros_like(attn_map)
+ att_box[:,index_in_key,:] = attn_map[:,index_in_key,:]
+
+ att_box = att_box.sum(axis=1) / index_in_key.shape[0]
+ att_box = att_box.reshape(-1, H, H)
+ activation_value = (att_box* mask_out).reshape(b, -1).sum(dim=-1) #/ att_box.reshape(b, -1).sum(dim=-1)
+ loss += torch.mean(activation_value)
+
+ return loss / object_number
+
+def caculate_loss_self_att(self_first, self_second, self_third, bboxes, object_positions, t, list_res=[256], smooth_att=True, sigma=0.5, kernel_size=3):
+ all_attn = get_all_self_att(self_first, self_second, self_third)
+ cnt = 0
+ total_loss = 0
+ for res in list_res:
+ attn_maps = all_attn[res]
+ for attn in attn_maps:
+ total_loss += loss_one_att_outside(attn, bboxes, object_positions,t)
+ cnt += 1
+
+ return total_loss /cnt
+
+
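+# Bucket the collected self-attention maps by their token count: 64/256/1024/4096
+# are the 8x8 .. 64x64 UNet levels, while the keys offset by 30 (94/286/1054/4126)
+# appear to come from layers whose token sequence also includes the appended
+# grounding tokens.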
+def get_all_self_att(self_first, self_second, self_third):
+ result = {64: [], 256: [], 1024: [], 4096: [], 94: [], 286: [], 1054: [], 4126: []}
+ # import pdb; pdb.set_trace()
+ all_att = [self_first, self_second, self_third]
+ for self_att in all_att:
+ for att in self_att:
+ if att != []:
+ temp = att[0]
+ for attn_map in temp:
+ current_res = attn_map.shape[1]
+ # print(current_res)
+ result[current_res].append(attn_map)
+ return result
+
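+# Collect every cross-attention map at spatial resolution `res` from the up,
+# mid and down blocks and average them into a single (res, res, num_tokens) map.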
+def get_all_attention(attn_maps_mid, attn_maps_up , attn_maps_down, res):
+ result = []
+
+ for attn_map_integrated in attn_maps_up:
+ if attn_map_integrated == []: continue
+ attn_map = attn_map_integrated[0][0]
+ b, i, j = attn_map.shape
+ H = W = int(math.sqrt(i))
+ # print(H)
+ if H == res:
+ result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
+ for attn_map_integrated in attn_maps_mid:
+
+ # for attn_map_integrated in attn_maps_mid:
+ attn_map = attn_map_integrated[0]
+ b, i, j = attn_map.shape
+ H = W = int(math.sqrt(i))
+ # print(H)
+ if (H==res):
+ result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
+ # import pdb; pdb.set_trace()
+ for attn_map_integrated in attn_maps_down:
+ if attn_map_integrated == []: continue
+ attn_map = attn_map_integrated[0][0]
+ if attn_map == []: continue
+ b, i, j = attn_map.shape
+ H = W = int(math.sqrt(i))
+ # print(H)
+ if (H==res):
+ result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] ))
+
+ result = torch.cat(result, dim=0)
+ result = result.sum(0) / result.shape[0]
+ return result
+
+
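+# Cross-attention loss: after re-normalising the text-token attention with a
+# softmax, each object's token should reach a high peak inside its own boxes
+# (penalised as 1 - max_inside), while its peaks inside other objects' boxes
+# and in the remaining background are penalised directly.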
+def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att=True, sigma=0.5, kernel_size=3):
+ attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
+ # attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
+ # attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
+ # attn8 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 8)
+ all_attn = [attn16]
+ obj_number = len(bboxes)
+ total_loss = 0
+ # import pdb; pdb.set_trace()
+ for attn in all_attn[0:1]:
+ attn_text = attn[:, :, 1:-1]
+ attn_text *= 100
+ attn_text = torch.nn.functional.softmax(attn_text, dim=-1)
+ current_res = attn.shape[0]
+ H = W = current_res
+
+ # if t == 49: import pdb; pdb.set_trace()
+ for obj_idx in range(obj_number):
+ num_boxes= 0
+
+ for obj_position in object_positions[obj_idx]:
+ true_obj_position = obj_position - 1
+ att_map_obj = attn_text[:,:, true_obj_position]
+ if smooth_att:
+ smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(att_map_obj.device)
+ input = F.pad(att_map_obj.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
+ att_map_obj = smoothing(input).squeeze(0).squeeze(0)
+ other_att_map_obj = att_map_obj.clone()
+ att_copy = att_map_obj.clone()
+
+ for obj_box in bboxes[obj_idx]:
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
+ int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
+
+
+ if att_map_obj[y_min: y_max, x_min: x_max].numel() == 0:
+ max_inside=1.
+
+ else:
+ max_inside = att_map_obj[y_min: y_max, x_min: x_max].max()
+ total_loss += 1. - max_inside
+
+ # find max outside the box, find in the other boxes
+
+ att_copy[y_min: y_max, x_min: x_max] = 0.
+ other_att_map_obj[y_min: y_max, x_min: x_max] = 0.
+
+ for obj_outside in range(obj_number):
+ if obj_outside != obj_idx:
+ for obj_out_box in bboxes[obj_outside]:
+ x_min_out, y_min_out, x_max_out, y_max_out = int(obj_out_box[0] * W), \
+ int(obj_out_box[1] * H), int(obj_out_box[2] * W), int(obj_out_box[3] * H)
+
+ # att_copy[y_min: y_max, x_min: x_max] = 0.
+ if other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].numel() == 0:
+ max_outside_one= 0
+ else:
+ max_outside_one = other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].max()
+ # max_outside = max(max_outside,max_outside_one )
+ att_copy[y_min_out: y_max_out, x_min_out: x_max_out] = 0.
+ total_loss += max_outside_one
+ max_background = att_copy.max()
+ total_loss += len(bboxes[obj_idx]) *max_background /2.
+
+ return total_loss/obj_number
+
diff --git a/gligen/ldm/models/diffusion/plms.py b/gligen/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d620128f7585cc8039102a80dd48a516d1415533
--- /dev/null
+++ b/gligen/ldm/models/diffusion/plms.py
@@ -0,0 +1,297 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+from copy import deepcopy
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+import math
+from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att
+class PLMSSampler(object):
+ def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
+ super().__init__()
+ self.diffusion = diffusion
+ self.model = model
+ self.device = diffusion.betas.device
+ self.ddpm_num_timesteps = diffusion.num_timesteps
+ self.schedule = schedule
+ self.alpha_generator_func = alpha_generator_func
+ self.set_alpha_scale = set_alpha_scale
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ attr = attr.to(self.device)
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.diffusion.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
+
+ self.register_buffer('betas', to_torch(self.diffusion.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+
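+ # Note: @torch.no_grad() stays commented out on sample()/plms_sampling(),
+ # because the loss-guided updates below need gradients w.r.t. the latent x.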
+ # @torch.no_grad()
+ def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
+ self.make_schedule(ddim_num_steps=S)
+ # import pdb; pdb.set_trace()
+ return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
+
+
+ # @torch.no_grad()
+ def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
+
+ b = shape[0]
+
+ img = input["x"]
+ if img is None:
+ img = torch.randn(shape, device=self.device)
+ input["x"] = img
+
+ time_range = np.flip(self.ddim_timesteps)
+ total_steps = self.ddim_timesteps.shape[0]
+
+ old_eps = []
+
+ if self.alpha_generator_func is not None:
+ alphas = self.alpha_generator_func(len(time_range))
+
+ for i, step in enumerate(time_range):
+
+ # set alpha and restore first conv layer
+ if self.alpha_generator_func is not None:
+ self.set_alpha_scale(self.model, alphas[i])
+ if alphas[i] == 0:
+ self.model.restore_first_conv_from_SD()
+
+ # run
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.diffusion.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+ input["x"] = img
+ # three loss types
+ if loss_type is not None and loss_type != 'standard':
+ if input['object_position'] != []:
+ if loss_type=='SAR_CAR':
+ x = self.update_loss_self_cross( input,i, index, ts )
+ elif loss_type=='SAR':
+ x = self.update_only_self( input,i, index, ts )
+ elif loss_type=='CAR':
+ x = self.update_loss_only_cross( input,i, index, ts )
+ input["x"] = x
+ img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
+ input["x"] = img
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+
+ return img
+
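+ # Loss-guided refinement using both losses (SAR: self-attention loss, CAR:
+ # cross-attention loss). The latent x is updated by plain gradient descent,
+ # x <- x - grad(loss), with a loss scale that decays over the first ~20
+ # denoising steps; the refinement only runs while index1 < max_index.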
+ def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
+ if index1 < 10:
+ loss_scale = 4
+ max_iter = 1
+ elif index1 < 20:
+ loss_scale = 3
+ max_iter = 1
+ else:
+ loss_scale = 1
+ max_iter = 1
+
+ loss_threshold = 0.1
+ max_index = 30
+ x = deepcopy(input["x"])
+ iteration = 0
+ loss = torch.tensor(10000)
+ input["timesteps"] = ts
+
+ print("optimize", index1)
+ self.model.train()
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
+ print('iter', iteration)
+ # import pdb; pdb.set_trace()
+ x = x.requires_grad_(True)
+ input['x'] = x
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
+ bboxes = input['boxes_att']
+ object_positions = input['object_position']
+ loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
+ object_positions=object_positions, t = index1)*loss_scale
+ loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
+ object_positions=object_positions, t = index1)*loss_scale
+ loss = loss1 + loss2
+ print('loss', loss, loss1, loss2)
+ # hh = torch.autograd.backward(loss, retain_graph=True)
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
+ # grad_cond = x.grad
+ x = x - grad_cond
+ x = x.detach()
+ iteration += 1
+
+
+ return x
+
+ def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
+
+ if index1 < 10:
+ loss_scale = 3
+ max_iter = 5
+ elif index1 < 20:
+ loss_scale = 2
+ max_iter = 5
+ else:
+ loss_scale = 1
+ max_iter = 1
+ loss_threshold = 0.1
+
+ max_index = 30
+ x = deepcopy(input["x"])
+ iteration = 0
+ loss = torch.tensor(10000)
+ input["timesteps"] = ts
+
+ print("optimize", index1)
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
+ print('iter', iteration)
+ x = x.requires_grad_(True)
+ input['x'] = x
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
+
+ bboxes = input['boxes']
+ object_positions = input['object_position']
+ loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
+ object_positions=object_positions, t = index1)*loss_scale
+ loss = loss2
+ print('loss', loss)
+ hh = torch.autograd.backward(loss)
+ grad_cond = x.grad
+ x = x - grad_cond
+ x = x.detach()
+ iteration += 1
+ torch.cuda.empty_cache()
+ return x
+
+ def update_only_self(self, input,index1, index, ts,type_loss='self_accross' ):
+ if index1 < 10:
+ loss_scale = 4
+ max_iter = 5
+ elif index1 < 20:
+ loss_scale = 3
+ max_iter = 5
+ else:
+ loss_scale = 1
+ max_iter = 1
+ loss_threshold = 0.1
+
+ max_index = 30
+ x = deepcopy(input["x"])
+ iteration = 0
+ loss = torch.tensor(10000)
+ input["timesteps"] = ts
+
+ print("optimize", index1)
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
+ print('iter', iteration)
+ x = x.requires_grad_(True)
+ input['x'] = x
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
+
+ bboxes = input['boxes']
+ object_positions = input['object_position']
+ loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
+ object_positions=object_positions, t = index1)*loss_scale
+ print('loss', loss)
+ hh = torch.autograd.backward(loss)
+ grad_cond = x.grad
+
+ x = x - grad_cond
+ x = x.detach()
+ iteration += 1
+ torch.cuda.empty_cache()
+ return x
+
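+ # Standard PLMS step: get the (classifier-free guided) noise prediction,
+ # combine it with up to three previous predictions via the Adams-Bashforth
+ # coefficients below, then take a DDIM-style update to the previous latent.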
+ @torch.no_grad()
+ def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
+ x = deepcopy(input["x"])
+ b = x.shape[0]
+ self.model.eval()
+ def get_model_output(input):
+ e_t, first, second, third,_,_,_ = self.model(input)
+ if uc is not None and guidance_scale != 1:
+ unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=None, grounding_extra_input=None)
+ # unconditional_input=input
+ e_t_uncond, _, _, _, _, _, _ = self.model( unconditional_input)
+ e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
+ return e_t
+
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device)
+ a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device)
+ sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * torch.randn_like(x)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ input["timesteps"] = t
+ e_t = get_model_output(input)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ input["x"] = x_prev
+ input["timesteps"] = t_next
+ e_t_next = get_model_output(input)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
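+# A minimal usage sketch (assumptions: `diffusion`, `model`, `input_dict`,
+# `uncond_context` and the alpha helpers come from the rest of this repo;
+# the names and shapes here are illustrative only):
+#
+#   sampler = PLMSSampler(diffusion, model,
+#                         alpha_generator_func=alpha_generator_func,
+#                         set_alpha_scale=set_alpha_scale)
+#   latents = sampler.sample(S=50, shape=(batch, 4, 64, 64), input=input_dict,
+#                            uc=uncond_context, guidance_scale=7.5,
+#                            loss_type='SAR_CAR')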
+
diff --git a/gligen/ldm/modules/attention.py b/gligen/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..2147b3d23b1a1ecd539e741cff42b61c29476a97
--- /dev/null
+++ b/gligen/ldm/modules/attention.py
@@ -0,0 +1,431 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+# from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+from torch.utils import checkpoint
+import os
+from torchvision.utils import save_image
+
+iter_att = 0
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+
+
+
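+# CrossAttention and SelfAttention below follow the usual LDM-style attention
+# layers, except that forward() additionally returns the attention
+# probabilities so the sampler can compute its attention-refocusing losses.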
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, key_dim, value_dim, heads=8, dim_head=64, dropout=0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(key_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(value_dim, inner_dim, bias=False)
+
+
+ self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
+
+
+ def fill_inf_from_mask(self, sim, mask):
+ if mask is not None:
+ B,M = mask.shape
+ mask = mask.unsqueeze(1).repeat(1,self.heads,1).reshape(B*self.heads,1,-1)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim.masked_fill_(~mask, max_neg_value)
+ return sim
+ # def scaled_dot_product(q, k, v, mask=None):
+ # d_k = q.size()[-1]
+ # attn_logits = torch.matmul(q, k.transpose(-2, -1))
+ # attn_logits = attn_logits / math.sqrt(d_k)
+ # if mask is not None:
+ # attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
+ # attention = F.softmax(attn_logits, dim=-1)
+ # values = torch.matmul(attention, v)
+ # return values, attention
+
+ def forward(self, x, key, value, mask=None):
+ # import pdb; pdb.set_trace()
+ q = self.to_q(x) # B*N*(H*C)
+ k = self.to_k(key) # B*M*(H*C)
+ v = self.to_v(value) # B*M*(H*C)
+
+ B, N, HC = q.shape
+ _, M, _ = key.shape
+ H = self.heads
+ C = HC // H
+
+ q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
+ k = k.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C
+ v = v.view(B,M,H,C).permute(0,2,1,3).reshape(B*H,M,C) # (B*H)*M*C
+
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale # (B*H)*N*M
+ self.fill_inf_from_mask(sim, mask)
+ attn = sim.softmax(dim=-1) # (B*H)*N*M
+ # import pdb; pdb.set_trace()
+ # if attn.shape[1] == 4096:
+ # self.visual_att(attn)
+ out = torch.einsum('b i j, b j d -> b i d', attn, v) # (B*H)*N*C
+ out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
+
+ return self.to_out(out), attn
+ def visual_att(self, att):
+ global iter_att
+ ll = [0,2,7]
+ for i in range(12):
+ kk = torch.sum(att[:,:,i], axis=0)
+ kk = kk.reshape(64,64)
+ save_image( (kk-kk.min()) / (kk.max() - kk.min()) , os.path.join('att', str(iter_att) + '_' +str(i) + '.png'))
+ iter_att += 1
+
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
+
+ def forward(self, x, gated=False):
+ q = self.to_q(x) # B*N*(H*C)
+ k = self.to_k(x) # B*N*(H*C)
+ v = self.to_v(x) # B*N*(H*C)
+
+ B, N, HC = q.shape
+ H = self.heads
+ C = HC // H
+ # if gated: import pdb; pdb.set_trace()
+ # import pdb; pdb.set_trace()
+ q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
+ k = k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
+ v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
+
+ sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N
+ attn = sim.softmax(dim=-1) # (B*H)*N*N
+ # if gated and attn.shape[1] == 4126:
+ # self.visual_att(attn)
+ out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C
+ out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
+
+ return self.to_out(out), attn
+
+ def visual_att(self, att):
+ global iter_att
+ ll = [0,2,7]
+ for i in ll:
+ kk = torch.sum(att[i],axis=0)
+ kk = kk[:4096].reshape(64,64)
+ save_image( (kk-kk.min()) / (kk.max() - kk.min()) , os.path.join('att', str(iter_att) + '_' +str(i) + '.png'))
+ iter_att += 1
+
+
+class GatedCrossAttentionDense(nn.Module):
+ def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head):
+ super().__init__()
+
+ self.attn = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, glu=True)
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
+
+ # this can be useful: we can externally change magnitude of tanh(alpha)
+ # for example, when it is set to 0, then the entire model is same as original one
+ self.scale = 1
+
+ def forward(self, x, objs):
+
+ x = x + self.scale*torch.tanh(self.alpha_attn) * self.attn( self.norm1(x), objs, objs)
+ x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
+
+ return x
+
+
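+# Gated self-attention fuser: visual tokens and (projected) grounding tokens
+# are concatenated, passed through self-attention, and only the visual part is
+# kept as a residual. alpha_attn/alpha_dense start at 0, so the layer is an
+# identity mapping at initialisation; forward() also returns the grounding
+# attention map for the refocusing losses.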
+class GatedSelfAttentionDense(nn.Module):
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, glu=True)
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
+
+ # this can be useful: we can externally change magnitude of tanh(alpha)
+ # for example, when it is set to 0, then the entire model is same as original one
+ self.scale = 1
+
+
+ def forward(self, x, objs,t):
+ # if t >300:
+ # self.scale = 1
+ # elif t > 200:
+ # self.scale = 0.9
+ # else:
+ # self.scale = 0.6
+ # if t >700:
+ # self.scale = 1
+ # elif t > 300:
+ # self.scale = 0.7
+ # else:
+ # self.scale = 0.4
+ # self.scale = 0
+
+ N_visual = x.shape[1]
+ objs = self.linear(objs)
+ out, grounding_att = self.attn( self.norm1(torch.cat([x,objs],dim=1)), True )
+ out = out[:,0:N_visual,:]
+ x = x + self.scale*torch.tanh(self.alpha_attn) * out
+ x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
+
+ return x , grounding_att
+
+
+
+
+
+
+class GatedSelfAttentionDense2(nn.Module):
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, glu=True)
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)) )
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)) )
+
+ # this can be useful: we can externally change magnitude of tanh(alpha)
+ # for example, when it is set to 0, then the entire model is same as original one
+ self.scale = 1
+
+
+ def forward(self, x, objs):
+
+ B, N_visual, _ = x.shape
+ B, N_ground, _ = objs.shape
+
+ objs = self.linear(objs)
+
+ # sanity check
+ size_v = math.sqrt(N_visual)
+ size_g = math.sqrt(N_ground)
+ assert int(size_v) == size_v, "Visual tokens must be square rootable"
+ assert int(size_g) == size_g, "Grounding tokens must be square rootable"
+ size_v = int(size_v)
+ size_g = int(size_g)
+
+ # select grounding token and resize it to visual token size as residual
+ out = self.attn( self.norm1(torch.cat([x,objs],dim=1)) )[:,N_visual:,:]
+ out = out.permute(0,2,1).reshape( B,-1,size_g,size_g )
+ out = torch.nn.functional.interpolate(out, (size_v,size_v), mode='bicubic')
+ residual = out.reshape(B,-1,N_visual).permute(0,2,1)
+
+ # add residual to visual feature
+ x = x + self.scale*torch.tanh(self.alpha_attn) * residual
+ x = x + self.scale*torch.tanh(self.alpha_dense) * self.ff( self.norm2(x) )
+
+ return x
+
+
+
+
+
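+# Transformer block: self-attention -> gated grounding fuser -> cross-attention
+# with the text context -> feed-forward. The cross- and self-attention
+# probabilities are returned so SpatialTransformer can expose them to the sampler.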
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=True):
+ super().__init__()
+ self.attn1 = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, glu=True)
+ self.attn2 = CrossAttention(query_dim=query_dim, key_dim=key_dim, value_dim=value_dim, heads=n_heads, dim_head=d_head)
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+ self.norm3 = nn.LayerNorm(query_dim)
+ self.use_checkpoint = use_checkpoint
+
+ if fuser_type == "gatedSA":
+ # note key_dim here actually is context_dim
+ self.fuser = GatedSelfAttentionDense(query_dim, key_dim, n_heads, d_head)
+ elif fuser_type == "gatedSA2":
+ # note key_dim here actually is context_dim
+ self.fuser = GatedSelfAttentionDense2(query_dim, key_dim, n_heads, d_head)
+ elif fuser_type == "gatedCA":
+ self.fuser = GatedCrossAttentionDense(query_dim, key_dim, value_dim, n_heads, d_head)
+ else:
+ assert False
+
+
+ def forward(self, x, context, objs,t):
+# return checkpoint(self._forward, (x, context, objs), self.parameters(), self.use_checkpoint)
+ # import pdb; pdb.set_trace()
+ # if self.use_checkpoint and x.requires_grad:
+ # return checkpoint.checkpoint(self._forward, x, context, objs,t)
+ # else:
+ return self._forward(x, context, objs,t)
+
+ def _forward(self, x, context, objs,t):
+ # self_att_grounding = []
+ out, self_prob = self.attn1( self.norm1(x) )
+ x = x + out
+ x, self_prob_grounding = self.fuser(x, objs,t) # identity mapping in the beginning
+ x_1, prob = self.attn2(self.norm2(x), context, context)
+ x = x + x_1
+ x = self.ff(self.norm3(x)) + x
+ # self_att_grounding.append(self_prob)
+ # self_att_grounding.append(self_prob_grounding)
+ return x, prob, self_prob
+
+
+class SpatialTransformer(nn.Module):
+ def __init__(self, in_channels, key_dim, value_dim, n_heads, d_head, depth=1, fuser_type=None, use_checkpoint=True):
+ super().__init__()
+ self.in_channels = in_channels
+ query_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+
+ self.proj_in = nn.Conv2d(in_channels,
+ query_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(query_dim, key_dim, value_dim, n_heads, d_head, fuser_type, use_checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(query_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context, objs,t):
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ probs = []
+ self_prob_list = []
+ for block in self.transformer_blocks:
+ x, prob, self_prob = block(x, context, objs,t)
+ probs.append(prob)
+ self_prob_list.append(self_prob)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in, probs, self_prob_list
\ No newline at end of file
diff --git a/gligen/ldm/modules/diffusionmodules/__init__.py b/gligen/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/gligen/ldm/modules/diffusionmodules/canny_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/canny_grounding_downsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6331d15c76e0418a1e4a050d199727b53006ecfd
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/canny_grounding_downsampler.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
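+# Downsamples a single-channel edge map with two stride-2 convs (resize_input
+# -> resize_input/4, i.e. 64x64 with the default 256) into out_dim feature
+# channels; these are presumably concatenated to the UNet input as
+# grounding_extra_input.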
+class GroundingDownsampler(nn.Module):
+ def __init__(self, resize_input=256, out_dim=8):
+ super().__init__()
+ self.resize_input = resize_input
+ self.out_dim = out_dim
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(1,4,4,2,1),
+ nn.SiLU(),
+ nn.Conv2d(4,self.out_dim,4,2,1)
+ )
+
+ def forward(self, grounding_extra_input):
+ # this is actually gray scale, but it was converted to RGB in the dataset, so the channels are redundant
+ grounding_extra_input = grounding_extra_input[:,0].unsqueeze(1)
+
+ out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='bicubic')
+ out = self.layers(out)
+
+ assert out.shape[1] == self.out_dim
+ return out
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/canny_grounding_net.py b/gligen/ldm/modules/diffusionmodules/canny_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8fcf7c64b387d99f067466a8a265082f805a88
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/canny_grounding_net.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+from ..attention import SelfAttention, FeedForward
+from .convnext import convnext_tiny
+
+
+
+
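+# Tokenizer for canny-edge grounding: resize the edge map, run it through a
+# ConvNeXt-Tiny backbone (stride 32 -> (resize_input/32)^2 tokens of dim 768),
+# swap in a learned null feature for samples with mask == 0, add a learned
+# positional embedding, and project to the grounding-token dimension.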
+class PositionNet(nn.Module):
+ def __init__(self, resize_input=448, out_dim=768):
+ super().__init__()
+ self.resize_input = resize_input
+ self.down_factor = 32 # determined by the convnext backbone
+ self.out_dim = out_dim
+ assert self.resize_input % self.down_factor == 0
+
+ self.convnext_tiny_backbone = convnext_tiny(pretrained=True)
+
+ self.num_tokens = (self.resize_input // self.down_factor) ** 2
+
+ convnext_feature_dim = 768
+ self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT
+
+ self.linears = nn.Sequential(
+ nn.Linear( convnext_feature_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim]))
+
+
+ def forward(self, canny_edge, mask):
+ B = canny_edge.shape[0]
+
+ # token from edge map
+ canny_edge = torch.nn.functional.interpolate(canny_edge, self.resize_input)
+ canny_edge_feature = self.convnext_tiny_backbone(canny_edge)
+ objs = canny_edge_feature.reshape(B, -1, self.num_tokens)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+
+ # expand null token
+ null_objs = self.null_feature.view(1,1,-1)
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
+
+ # mask replacing
+ mask = mask.view(-1,1,1)
+ objs = objs*mask + null_objs*(1-mask)
+
+ # add pos
+ objs = objs + self.pos_embedding
+
+ # fuse them
+ objs = self.linears(objs)
+
+ assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim])
+ return objs
+
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/convnext.py b/gligen/ldm/modules/diffusionmodules/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..71956848b6631ecb7ae12b9d684e69e142a3ef45
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/convnext.py
@@ -0,0 +1,203 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.registry import register_model
+
+class Block(nn.Module):
+ r""" ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+class ConvNeXt(nn.Module):
+ r""" ConvNeXt
+ A PyTorch impl of : `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+ def __init__(self, in_chans=3, num_classes=1000,
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+ layer_scale_init_value=1e-6, head_init_scale=1.,
+ ):
+ super().__init__()
+
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ # self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+ # self.head = nn.Linear(dims[-1], num_classes)
+
+ # self.apply(self._init_weights)
+ # self.head.weight.data.mul_(head_init_scale)
+ # self.head.bias.data.mul_(head_init_scale)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x):
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ return x
+ # return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ # x = self.head(x)
+ return x
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+model_urls = {
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+}
+
+@register_model
+def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+ model.load_state_dict(checkpoint["model"], strict=False) # strict=False because we drop the classifier head
+ return model
+
+@register_model
+def convnext_small(pretrained=False,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+@register_model
+def convnext_base(pretrained=False, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+ if pretrained:
+ url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+@register_model
+def convnext_large(pretrained=False, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+ if pretrained:
+ url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+@register_model
+def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+ if pretrained:
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
+ url = model_urls['convnext_xlarge_22k']
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ return model
\ No newline at end of file
diff --git a/gligen/ldm/modules/diffusionmodules/depth_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/depth_grounding_downsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..80826ae2a96b615e24474f43b4ecdc9267049261
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/depth_grounding_downsampler.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
+class GroundingDownsampler(nn.Module):
+ def __init__(self, resize_input=256, out_dim=8):
+ super().__init__()
+ self.resize_input = resize_input
+ self.out_dim = out_dim
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(1,4,4,2,1),
+ nn.SiLU(),
+ nn.Conv2d(4,self.out_dim,4,2,1)
+ )
+
+ def forward(self, grounding_extra_input):
+ # this is actually gray scale, but it was converted to RGB in the dataset, so the channels are redundant
+
+ grounding_extra_input = grounding_extra_input[:,0].unsqueeze(1)
+
+ out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='bicubic')
+ out = self.layers(out)
+
+ assert out.shape[1] == self.out_dim
+ return out
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/depth_grounding_net.py b/gligen/ldm/modules/diffusionmodules/depth_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..637816e79a97e38cf987e6311fa91d9792dc0fce
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/depth_grounding_net.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+from ..attention import SelfAttention, FeedForward
+from .convnext import convnext_tiny
+
+
+
+
+class PositionNet(nn.Module):
+ def __init__(self, resize_input=448, out_dim=768):
+ super().__init__()
+ self.resize_input = resize_input
+ self.down_factor = 32 # determined by the convnext backbone
+ self.out_dim = out_dim
+ assert self.resize_input % self.down_factor == 0
+
+ self.convnext_tiny_backbone = convnext_tiny(pretrained=True)
+
+ self.num_tokens = (self.resize_input // self.down_factor) ** 2
+
+ convnext_feature_dim = 768
+ self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT
+
+ self.linears = nn.Sequential(
+ nn.Linear( convnext_feature_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim]))
+
+
+ def forward(self, depth, mask):
+ B = depth.shape[0]
+
+ # tokens from the depth map
+ depth = torch.nn.functional.interpolate(depth, self.resize_input)
+ depth_feature = self.convnext_tiny_backbone(depth)
+ objs = depth_feature.reshape(B, -1, self.num_tokens)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+
+ # expand null token
+ null_objs = self.null_feature.view(1,1,-1)
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
+
+ # mask replacing
+ mask = mask.view(-1,1,1)
+ objs = objs*mask + null_objs*(1-mask)
+
+ # add pos
+ objs = objs + self.pos_embedding
+
+ # fuse them
+ objs = self.linears(objs)
+
+ assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim])
+ return objs
+
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/grounding_net_example.py b/gligen/ldm/modules/diffusionmodules/grounding_net_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a09caf5e48bb11f789236a4c34bdbd9ee6cabee
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/grounding_net_example.py
@@ -0,0 +1,22 @@
+"""
+This is a high-level pseudo code for grounding net.
+
+This class needs to tokenize the grounding input into grounding tokens, which
+will be used in the GatedAttention layers.
+
+
+class PositionNet(nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ kwargs should be defined by model.grounding_tokenizer in config yaml file.
+
+ def forward(self, **kwargs):
+
+ kwargs should be the output of grounding_tokenizer_input network
+
+ return grounding_tokens # with shape: Batch * Num_Of_Token* Token_Channel_Dimension
+
+
+
+"""
\ No newline at end of file
diff --git a/gligen/ldm/modules/diffusionmodules/hed_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/hed_grounding_downsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..99d1e7def372b74db331b6f50d3dc4574ace47a2
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/hed_grounding_downsampler.py
@@ -0,0 +1,23 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
+class GroundingDownsampler(nn.Module):
+ def __init__(self, out_dim=1):
+ super().__init__()
+ self.out_dim = out_dim
+ # No learnable params for hed edge map, just downsample it with bicubic
+
+ def forward(self, grounding_extra_input):
+ # this is actually gray scale, but it was converted to RGB in the dataset, so the channels are redundant
+ grounding_extra_input = grounding_extra_input[:,0].unsqueeze(1)
+
+ out = torch.nn.functional.interpolate(grounding_extra_input, (64,64), mode='bicubic')
+ assert out.shape[1] == self.out_dim
+ return out
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/hed_grounding_net.py b/gligen/ldm/modules/diffusionmodules/hed_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..e566bb35c914abd19e51c8661d54a8702c3d55df
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/hed_grounding_net.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+from ..attention import SelfAttention, FeedForward
+from .convnext import convnext_tiny
+
+
+
+
+class PositionNet(nn.Module):
+ def __init__(self, resize_input=448, out_dim=768):
+ super().__init__()
+ self.resize_input = resize_input
+ self.down_factor = 32 # determined by the convnext backbone
+ self.out_dim = out_dim
+ assert self.resize_input % self.down_factor == 0
+
+ self.convnext_tiny_backbone = convnext_tiny(pretrained=True)
+
+ self.num_tokens = (self.resize_input // self.down_factor) ** 2
+
+ convnext_feature_dim = 768
+ self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT
+
+ self.linears = nn.Sequential(
+ nn.Linear( convnext_feature_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim]))
+
+
+ def forward(self, hed_edge, mask):
+ B = hed_edge.shape[0]
+
+ # token from edge map
+ hed_edge = torch.nn.functional.interpolate(hed_edge, self.resize_input)
+ hed_edge_feature = self.convnext_tiny_backbone(hed_edge)
+ objs = hed_edge_feature.reshape(B, -1, self.num_tokens)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+
+ # expand null token
+ null_objs = self.null_feature.view(1,1,-1)
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
+
+ # mask replacing
+ mask = mask.view(-1,1,1)
+ objs = objs*mask + null_objs*(1-mask)
+
+ # add pos
+ objs = objs + self.pos_embedding
+
+ # fuse them
+ objs = self.linears(objs)
+
+ assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim])
+ return objs
+
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/keypoint_grounding_net.py b/gligen/ldm/modules/diffusionmodules/keypoint_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9da67a713c917ac8aedf09ba2803421550021e3
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/keypoint_grounding_net.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
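+# Keypoint grounding tokenizer: one learned embedding per person plus one per
+# keypoint type (17, the COCO convention), added together and concatenated with
+# a Fourier embedding of the (x, y) coordinates; padded entries are replaced by
+# learned null features before the MLP projects everything to out_dim.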
+class PositionNet(nn.Module):
+ def __init__(self, max_persons_per_image, out_dim, fourier_freqs=8):
+ super().__init__()
+ self.max_persons_per_image = max_persons_per_image
+ self.out_dim = out_dim
+
+ self.person_embeddings = torch.nn.Parameter(torch.zeros([max_persons_per_image,out_dim]))
+ self.keypoint_embeddings = torch.nn.Parameter(torch.zeros([17,out_dim]))
+
+
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+ self.position_dim = fourier_freqs*2*2 # 2 is sin&cos, 2 is xy
+
+ self.linears = nn.Sequential(
+ nn.Linear( self.out_dim + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_person_feature = torch.nn.Parameter(torch.zeros([self.out_dim]))
+ self.null_xy_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+
+ def forward(self, points, masks):
+
+ masks = masks.unsqueeze(-1)
+ N = points.shape[0]
+
+ person_embeddings = self.person_embeddings.unsqueeze(1).repeat(1,17,1).reshape(self.max_persons_per_image*17, self.out_dim)
+ keypoint_embeddings = torch.cat([self.keypoint_embeddings]*self.max_persons_per_image, dim=0)
+ person_embeddings = person_embeddings + keypoint_embeddings # (num_person*17) * C
+ person_embeddings = person_embeddings.unsqueeze(0).repeat(N,1,1)
+
+ # embed the (x, y) positions (the input may include padding as placeholders)
+ xy_embedding = self.fourier_embedder(points) # B*N*2 --> B*N*C
+
+
+ # learnable null embedding
+ person_null = self.null_person_feature.view(1,1,-1)
+ xy_null = self.null_xy_feature.view(1,1,-1)
+
+ # replace padding with learnable null embedding
+ person_embeddings = person_embeddings*masks + (1-masks)*person_null
+ xy_embedding = xy_embedding*masks + (1-masks)*xy_null
+
+ objs = self.linears( torch.cat([person_embeddings, xy_embedding], dim=-1) )
+
+ return objs
+
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/model.py b/gligen/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..533e589a2024f1d7c52093d8c472c3b1b6617e26
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,835 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from ldm.util import instantiate_from_config
+from ldm.modules.attention import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings.
+ This matches the implementation in Denoising Diffusion Probabilistic Models
+ (ported from Fairseq / tensor2tensor) and differs slightly from the
+ description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
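# Illustrative, standalone sketch (not code from this repo): the sinusoidal timestep embedding
# computed by get_timestep_embedding above, reproduced so the sin/cos split and output shape
# are easy to inspect. It follows the formula in that function.
import math
import torch

def sinusoidal_embedding(timesteps, embedding_dim):
    half = embedding_dim // 2
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * -(math.log(10000.0) / (half - 1)))
    args = timesteps.float()[:, None] * freqs[None, :]        # N x half
    emb = torch.cat([args.sin(), args.cos()], dim=1)          # N x embedding_dim
    if embedding_dim % 2 == 1:                                # zero-pad odd dims
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

t = torch.tensor([0, 10, 999])
e = sinusoidal_embedding(t, 128)
print(e.shape)    # torch.Size([3, 128]); columns 0-63 are sines, 64-127 are cosines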
diff --git a/gligen/ldm/modules/diffusionmodules/normal_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/normal_grounding_downsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b663401b253e09bd3f1cd78e725373c1f537b4f8
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/normal_grounding_downsampler.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
+class GroundingDownsampler(nn.Module):
+ def __init__(self, resize_input=256, out_dim=8):
+ super().__init__()
+ self.resize_input = resize_input
+ self.out_dim = out_dim
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(3,4,4,2,1),
+ nn.SiLU(),
+ nn.Conv2d(4,self.out_dim,4,2,1)
+ )
+
+ def forward(self, grounding_extra_input):
+
+ out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='bicubic')
+ out = self.layers(out)
+
+ assert out.shape[1] == self.out_dim
+ return out
+
+
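# Illustrative, standalone sketch (not code from this repo): shape walk-through of the
# GroundingDownsampler above with its default out_dim=8. The map is resized to 256x256 and
# two stride-2 convolutions bring it to 64x64, i.e. the spatial size of the SD latent that
# the UNet's first conv consumes.
import torch
import torch.nn as nn

layers = nn.Sequential(nn.Conv2d(3, 4, 4, 2, 1), nn.SiLU(), nn.Conv2d(4, 8, 4, 2, 1))
x = torch.randn(1, 3, 512, 512)                                  # any input resolution
x = nn.functional.interpolate(x, (256, 256), mode='bicubic')     # resize_input = 256
print(layers(x).shape)                                           # torch.Size([1, 8, 64, 64])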
diff --git a/gligen/ldm/modules/diffusionmodules/normal_grounding_net.py b/gligen/ldm/modules/diffusionmodules/normal_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..38cadb7c9321f2d4aacabf3a4e31ac2207bebb32
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/normal_grounding_net.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+from ..attention import SelfAttention, FeedForward
+from .convnext import convnext_tiny
+
+
+
+
+class PositionNet(nn.Module):
+ def __init__(self, resize_input=448, out_dim=768):
+ super().__init__()
+ self.resize_input = resize_input
+ self.down_factor = 32 # determined by the convnext backbone
+ self.out_dim = out_dim
+ assert self.resize_input % self.down_factor == 0
+
+ self.convnext_tiny_backbone = convnext_tiny(pretrained=True)
+
+ self.num_tokens = (self.resize_input // self.down_factor) ** 2
+
+ convnext_feature_dim = 768
+ self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT
+
+ self.linears = nn.Sequential(
+ nn.Linear( convnext_feature_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim]))
+
+
+ def forward(self, normal, mask):
+ B = normal.shape[0]
+
+ # tokens from the normal map
+ normal = torch.nn.functional.interpolate(normal, self.resize_input)
+ normal_feature = self.convnext_tiny_backbone(normal)
+ objs = normal_feature.reshape(B, -1, self.num_tokens)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+
+ # expand null token
+ null_objs = self.null_feature.view(1,1,-1)
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
+
+ # mask replacing
+ mask = mask.view(-1,1,1)
+ objs = objs*mask + null_objs*(1-mask)
+
+ # add pos
+ objs = objs + self.pos_embedding
+
+ # fuse them
+ objs = self.linears(objs)
+
+ assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim])
+ return objs
+
+
+
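# Illustrative, standalone sketch (not code from this repo): token bookkeeping of the
# map-based PositionNet above. A 448x448 map through a /32 backbone gives a 14x14 grid,
# i.e. 196 grounding tokens; samples with mask == 0 are replaced wholesale by the learned
# null token before the positional embedding and the MLP are applied.
import torch

resize_input, down_factor, feat_dim = 448, 32, 768
num_tokens = (resize_input // down_factor) ** 2                  # 196
B = 2
objs = torch.randn(B, num_tokens, feat_dim)                      # stand-in for backbone features
null = torch.zeros(1, 1, feat_dim)                               # a learned Parameter in the real net
mask = torch.tensor([1.0, 0.0]).view(-1, 1, 1)                   # sample 0 keeps its map, sample 1 is dropped
objs = objs * mask + null.repeat(B, num_tokens, 1) * (1 - mask)
print(num_tokens, objs.shape)                                    # 196 torch.Size([2, 196, 768])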
diff --git a/gligen/ldm/modules/diffusionmodules/openaimodel.py b/gligen/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e39ea3f9d8ab58055beb26783d14d047878a5a
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,562 @@
+from abc import abstractmethod
+from functools import partial
+import math
+
+import numpy as np
+import random
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+# from .positionnet import PositionNet
+from torch.utils import checkpoint
+from ldm.util import instantiate_from_config
+from copy import deepcopy
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context, objs,t):
+ probs = []
+ self_prob_list = []
+
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x, prob, self_prob = layer(x, context, objs,t)
+ probs.append(prob)
+ self_prob_list.append(self_prob)
+ else:
+ x = layer(x)
+ return x, probs, self_prob_list
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ # return checkpoint(
+ # self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ # )
+ # if self.use_checkpoint and x.requires_grad:
+ # return checkpoint.checkpoint(self._forward, x, emb )
+ # else:
+ return self._forward(x, emb)
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+
+
+class UNetModel(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ num_heads=8,
+ use_scale_shift_norm=False,
+ transformer_depth=1,
+ positive_len = 768,
+ context_dim=None,
+ fuser_type = None,
+ is_inpaint = False,
+ is_style = False,
+ grounding_downsampler = None,
+
+
+ ):
+ super().__init__()
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.context_dim = context_dim
+ self.fuser_type = fuser_type
+ self.is_inpaint = is_inpaint
+ self.positive_len = positive_len
+ assert fuser_type in ["gatedSA","gatedSA2","gatedCA"]
+
+ self.grounding_tokenizer_input = None # set externally
+
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+
+
+ self.downsample_net = None
+ self.additional_channel_from_downsampler = 0
+ self.first_conv_type = "SD"
+ self.first_conv_restorable = True
+ if grounding_downsampler is not None:
+ self.downsample_net = instantiate_from_config(grounding_downsampler)
+ self.additional_channel_from_downsampler = self.downsample_net.out_dim
+ self.first_conv_type = "GLIGEN"
+
+ if is_inpaint:
+ # The newly added channels are the masked (encoded) image and the mask, i.e. 4 + 1
+ in_c = in_channels+self.additional_channel_from_downsampler+in_channels+1
+ self.first_conv_restorable = False # inpainting models must keep the extra channels that take in the masked real image
+ else:
+ in_c = in_channels+self.additional_channel_from_downsampler
+ self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_c, model_channels, 3, padding=1))])
+
+
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+
+ # = = = = = = = = = = = = = = = = = = = = Down Branch = = = = = = = = = = = = = = = = = = = = #
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,) ]
+
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ dim_head = ch // num_heads
+ layers.append(SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint))
+
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ input_block_chans.append(ch)
+
+ if level != len(channel_mult) - 1: # no downsampling after the last feature level
+ out_ch = ch
+ self.input_blocks.append( TimestepEmbedSequential( Downsample(ch, conv_resample, dims=dims, out_channels=out_ch ) ) )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ dim_head = ch // num_heads
+
+ # self.input_blocks = [ C | RT RT D | RT RT D | RT RT D | R R ]
+
+
+ # = = = = = = = = = = = = = = = = = = = = BottleNeck = = = = = = = = = = = = = = = = = = = = #
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm),
+ SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint),
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm))
+
+
+
+ # = = = = = = = = = = = = = = = = = = = = Up Branch = = = = = = = = = = = = = = = = = = = = #
+
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [ ResBlock(ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm) ]
+ ch = model_channels * mult
+
+ if ds in attention_resolutions:
+ dim_head = ch // num_heads
+ layers.append( SpatialTransformer(ch, key_dim=context_dim, value_dim=context_dim, n_heads=num_heads, d_head=dim_head, depth=transformer_depth, fuser_type=fuser_type, use_checkpoint=use_checkpoint) )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append( Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) )
+ ds //= 2
+
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+
+ # self.output_blocks = [ R R RU | RT RT RTU | RT RT RTU | RT RT RT ]
+
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ # self.position_net = instantiate_from_config(grounding_tokenizer)
+ from .text_grounding_net import PositionNet
+ self.position_net = PositionNet(in_dim=positive_len, out_dim=context_dim)
+
+
+
+
+
+ def restore_first_conv_from_SD(self):
+ if self.first_conv_restorable:
+ device = self.input_blocks[0][0].weight.device
+
+ SD_weights = th.load("gligen/SD_input_conv_weight_bias.pth")
+ self.GLIGEN_first_conv_state_dict = deepcopy(self.input_blocks[0][0].state_dict())
+
+ self.input_blocks[0][0] = conv_nd(2, 4, 320, 3, padding=1)
+ self.input_blocks[0][0].load_state_dict(SD_weights)
+ self.input_blocks[0][0].to(device)
+
+ self.first_conv_type = "SD"
+ else:
+ print("First conv layer is not restorable and skipped this process, probably because this is an inpainting model?")
+
+
+ def restore_first_conv_from_GLIGEN(self):
+ breakpoint() # TODO
+
+
+ def forward_position_net(self,input):
+ # import pdb; pdb.set_trace()
+ if ("boxes" in input):
+ boxes, masks, text_embeddings = input["boxes"], input["masks"], input["text_embeddings"]
+ _ , self.max_box, _ = text_embeddings.shape
+ else:
+ dtype = input["x"].dtype
+ batch = input["x"].shape[0]
+ device = input["x"].device
+ boxes = th.zeros(batch, self.max_box, 4,).type(dtype).to(device)
+ masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ text_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device)
+ if self.training and random.random() < 0.1: # random drop for guidance
+ boxes, masks, text_embeddings = boxes*0, masks*0, text_embeddings*0
+
+ objs = self.position_net( boxes, masks, text_embeddings ) # B*N*C
+
+ return objs
+
+ def forward_position_net_with_image(self,input):
+
+ if ("boxes" in input):
+ boxes = input["boxes"]
+ masks = input["masks"]
+ text_masks = input["text_masks"]
+ image_masks = input["image_masks"]
+ text_embeddings = input["text_embeddings"]
+ image_embeddings = input["image_embeddings"]
+ _ , self.max_box, _ = text_embeddings.shape
+ else:
+ dtype = input["x"].dtype
+ batch = input["x"].shape[0]
+ device = input["x"].device
+ boxes = th.zeros(batch, self.max_box, 4,).type(dtype).to(device)
+ masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ text_masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ image_masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ text_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device)
+ image_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device)
+
+ if self.training and random.random() < 0.1: # random drop for guidance
+ boxes = boxes*0
+ masks = masks*0
+ text_masks = text_masks*0
+ image_masks = image_masks*0
+ text_embeddings = text_embeddings*0
+ image_embeddings = image_embeddings*0
+
+ objs = self.position_net( boxes, masks, text_masks, image_masks, text_embeddings, image_embeddings ) # B*N*C
+
+ return objs
+
+
+ def forward(self, input,unc=False):
+
+ if ("boxes" in input):
+ # grounding_input = input["grounding_input"]
+ boxes, masks, text_embeddings = input["boxes"], input["masks"], input["text_embeddings"]
+ _ , self.max_box, _ = text_embeddings.shape
+ else:
+ # Guidance null case
+ # grounding_input = self.grounding_tokenizer_input.get_null_input()
+ # boxes, masks, text_embeddings = input["boxes"]*0, input["masks"]*0, input["text_embeddings"]*0
+ dtype = input["x"].dtype
+ batch = input["x"].shape[0]
+ device = input["x"].device
+ boxes = th.zeros(batch, self.max_box, 4,).type(dtype).to(device)
+ masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ text_masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ image_masks = th.zeros(batch, self.max_box).type(dtype).to(device)
+ text_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device)
+ image_embeddings = th.zeros(batch, self.max_box, self.positive_len).type(dtype).to(device)
+
+ if self.training and random.random() < 0.1 : # random drop for guidance
+ boxes, masks, text_embeddings = boxes*0, masks*0, text_embeddings*0
+
+ objs = self.position_net( boxes, masks, text_embeddings )
+
+ # Time embedding
+
+ t_emb = timestep_embedding(input["timesteps"], self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ # input tensor
+ h = input["x"]
+ t = input["timesteps"]
+ if self.downsample_net != None and self.first_conv_type=="GLIGEN":
+ temp = self.downsample_net(input["grounding_extra_input"])
+ h = th.cat( [h,temp], dim=1 )
+ if self.is_inpaint:#self.inpaint_mode:
+ if self.downsample_net != None:
+ breakpoint() # TODO: think about this case
+ h = th.cat( [h, input["inpainting_extra_input"]], dim=1 )
+
+ # Text input
+ context = input["context"]
+
+ # Start forwarding
+ hs = []
+ probs_first = []
+ self_prob_list_first = []
+
+ for module in self.input_blocks:
+ h,prob, self_prob = module(h, emb, context, objs,t)
+ hs.append(h)
+ probs_first.append(prob)
+ self_prob_list_first.append(self_prob)
+
+ h,mid_prob, self_prob_list_second = self.middle_block(h, emb, context, objs,t)
+
+ probs_third = []
+ self_prob_list_third = []
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h, prob, self_prob = module(h, emb, context, objs,t)
+ probs_third.append(prob)
+ self_prob_list_third.append(self_prob)
+
+ return self.out(h),probs_third , mid_prob, probs_first, self_prob_list_first, [self_prob_list_second], self_prob_list_third
+
+
+
+
+
+
+
+
+
+
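# Illustrative sketch (not code from this repo): the input dictionary UNetModel.forward above
# expects. The key names come from the code; the shapes are assumptions for a 64x64 SD latent,
# 30 box slots and 768-d CLIP features (positive_len = context_dim = 768). Actually running
# the model still requires the repo's SpatialTransformer / PositionNet stack.
import torch

B, max_box = 2, 30
unet_input = {
    "x":               torch.randn(B, 4, 64, 64),       # SD latent
    "timesteps":       torch.randint(0, 1000, (B,)),    # diffusion step per sample
    "context":         torch.randn(B, 77, 768),         # caption tokens from the text encoder
    "boxes":           torch.rand(B, max_box, 4),       # normalized grounding boxes
    "masks":           torch.ones(B, max_box),          # 1 = box slot is used
    "text_embeddings": torch.randn(B, max_box, 768),    # per-box phrase features
    # optional, depending on the variant:
    #   "grounding_extra_input":  dense map consumed by the GroundingDownsampler
    #   "inpainting_extra_input": masked latent + mask (4 + 1 channels) for inpainting models
}
# out, *attention_probs = unet(unet_input)   # hypothetical call once the model is instantiated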
diff --git a/gligen/ldm/modules/diffusionmodules/pseudo_example.py b/gligen/ldm/modules/diffusionmodules/pseudo_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ba5014e9e8a6e71538232d86cba46e098110c0e
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/pseudo_example.py
@@ -0,0 +1,52 @@
+"""
+This is a high-level pseudo code for grounding net.
+
+This class needs to tokenize the grounding input into grounding tokens, which
+will be used in the GatedAttention layers.
+
+
+class PositionNet(nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ kwargs should be defined by model.grounding_tokenizer in config yaml file.
+
+ def forward(self, **kwargs):
+
+ kwargs should be the output of grounding_tokenizer_input network
+
+ return grounding_tokens # with shape: Batch * Num_Of_Tokens * Token_Channel_Dimension
+
+
+
+"""
+
+
+# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
+
+
+"""
+This is a high-level pseudo code for downsampler.
+
+This class needs to process input and output a spatial feature such that it will be
+fed into the first conv layer.
+
+
+class GroundingDownsampler(nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ kwargs should be defined by model.grounding_downsampler in config yaml file.
+
+ you MUST define self.out_dim so that the UNet knows how many extra input channels to add
+
+
+ def forward(self, **kwargs):
+
+ kwargs should be the output of grounding_downsampler_input network
+
+ return spatial_feature # with shape: Batch * self.out_dim * H * W (64*64 for SD)
+
+
+
+"""
\ No newline at end of file
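# Illustrative, standalone sketch (not the tokenizers shipped in this repo): minimal concrete
# versions of the two interfaces described above, with made-up layer sizes, just to show the
# expected signatures and the role of self.out_dim.
import torch
import torch.nn as nn

class ToyPositionNet(nn.Module):
    def __init__(self, in_dim=768, out_dim=768):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, embeddings, masks):
        # embeddings: B x N x in_dim, masks: B x N  ->  B x N x out_dim grounding tokens
        return self.proj(embeddings) * masks.unsqueeze(-1)

class ToyGroundingDownsampler(nn.Module):
    def __init__(self, out_dim=8):
        super().__init__()
        self.out_dim = out_dim                   # the UNet reads this to widen its first conv
        self.net = nn.Conv2d(3, out_dim, 8, 8)   # e.g. 512x512 -> 64x64 spatial feature

    def forward(self, x):
        return self.net(x)

tokens = ToyPositionNet()(torch.randn(1, 30, 768), torch.ones(1, 30))
feature = ToyGroundingDownsampler()(torch.randn(1, 3, 512, 512))
print(tokens.shape, feature.shape)   # torch.Size([1, 30, 768]) torch.Size([1, 8, 64, 64])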
diff --git a/gligen/ldm/modules/diffusionmodules/resnet.py b/gligen/ldm/modules/diffusionmodules/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8ce07516fef99554c51c58fb3379448cf89154f
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/resnet.py
@@ -0,0 +1,337 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+from typing import Type, Any, Callable, Union, List, Optional
+
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResNet, self).__init__()
+ print("Please manually decide which layer as output")
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x):
+ # The resolutions in the comments below assume a 224*224 input
+ out = {}
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ out['f0'] = x # N*64*56*56
+
+ x = self.layer1(x)
+ out['f1'] = x # N*64*56*56
+
+ x = self.layer2(x)
+ out['f2'] = x # N*128*28*28
+
+ x = self.layer3(x)
+ out['f3'] = x # N*256*14*14
+
+ x = self.layer4(x)
+ out['f4'] = x # N*512*7*7
+ return x
+
+
+ # x = self.avgpool(x)
+ # x = torch.flatten(x, 1)
+ # out['penultimate'] = x # N*512
+
+ # x = self.fc(x)
+ # out['logits'] = x # N*1000
+
+ # return out
+
+ def forward(self, x):
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+ model.load_state_dict(state_dict, strict=False) # we remove fc, and only keep backbone
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" `_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
+
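# Illustrative sketch: using the fc-free ResNet defined in this file as a feature extractor.
# With pretrained=False nothing is downloaded; the output is the layer4 feature map, matching
# the "f4" comment in _forward_impl for the BasicBlock variants.
import torch

backbone = resnet18(pretrained=False)            # assumes this module is on the import path
feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)                               # torch.Size([1, 512, 7, 7])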
diff --git a/gligen/ldm/modules/diffusionmodules/sem_grounding_downsampler.py b/gligen/ldm/modules/diffusionmodules/sem_grounding_downsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..10e4fc09a40bb912860cd3186b4970cabd2a0938
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/sem_grounding_downsampler.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
+class GroundingDownsampler(nn.Module):
+ def __init__(self, resize_input=256, in_dim=152, out_dim=8):
+ super().__init__()
+ self.resize_input = resize_input
+ self.out_dim = out_dim
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(in_dim,16,4,2,1),
+ nn.SiLU(),
+ nn.Conv2d(16,self.out_dim,4,2,1)
+ )
+
+ def forward(self, grounding_extra_input):
+
+ out = torch.nn.functional.interpolate(grounding_extra_input, (self.resize_input,self.resize_input), mode='nearest')
+ out = self.layers(out)
+
+ assert out.shape[1] == self.out_dim
+ return out
+
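+# Shape sketch (descriptive note, not part of the original file): a 152-channel semantic layout
+# map of shape (B, 152, H, W) is first resized to 256x256, then the two stride-2 convolutions
+# reduce it to (B, 8, 64, 64), i.e. out_dim extra channels at a 64x64 spatial resolution.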
+
diff --git a/gligen/ldm/modules/diffusionmodules/sem_grounding_net.py b/gligen/ldm/modules/diffusionmodules/sem_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ef6dd49dc5bcc58205913276c39982b07320b9
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/sem_grounding_net.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+from ..attention import SelfAttention, FeedForward
+from .convnext import convnext_tiny
+
+
+
+
+class PositionNet(nn.Module):
+ def __init__(self, resize_input=448, in_dim=152, out_dim=768):
+ super().__init__()
+
+ self.resize_input = resize_input
+ self.down_factor = 32 # determined by the convnext backbone
+ self.out_dim = out_dim
+ assert self.resize_input % self.down_factor == 0
+
+ self.in_conv = nn.Conv2d(in_dim,3,3,1,1) # from num_sem to 3 channels
+ self.convnext_tiny_backbone = convnext_tiny(pretrained=True)
+
+ self.num_tokens = (self.resize_input // self.down_factor) ** 2
+
+ convnext_feature_dim = 768
+ self.pos_embedding = nn.Parameter(torch.empty(1, self.num_tokens, convnext_feature_dim).normal_(std=0.02)) # from BERT
+
+ self.linears = nn.Sequential(
+ nn.Linear( convnext_feature_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_feature = torch.nn.Parameter(torch.zeros([convnext_feature_dim]))
+
+
+ def forward(self, sem, mask):
+ B = sem.shape[0]
+
+ # tokens from the semantic map
+ sem = torch.nn.functional.interpolate(sem, self.resize_input, mode="nearest")
+ sem = self.in_conv(sem)
+ sem_feature = self.convnext_tiny_backbone(sem)
+ objs = sem_feature.reshape(B, -1, self.num_tokens)
+ objs = objs.permute(0, 2, 1) # N*Num_tokens*dim
+
+ # expand null token
+ null_objs = self.null_feature.view(1,1,-1)
+ null_objs = null_objs.repeat(B,self.num_tokens,1)
+
+ # mask replacing
+ mask = mask.view(-1,1,1)
+ objs = objs*mask + null_objs*(1-mask)
+
+ # add pos
+ objs = objs + self.pos_embedding
+
+ # fuse them
+ objs = self.linears(objs)
+
+ assert objs.shape == torch.Size([B,self.num_tokens,self.out_dim])
+ return objs
+
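+# Shape sketch (descriptive note, not part of the original file): with the defaults above, a
+# (B, 152, H, W) semantic map is resized to 448x448, projected to 3 channels, and encoded by the
+# ConvNeXt-Tiny backbone (downsampling factor 32) into a (B, 768, 14, 14) feature map, i.e.
+# 14*14 = 196 tokens of dimension 768; after adding the positional embedding, the MLP maps them
+# to (B, 196, out_dim) grounding tokens.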
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/text_grounding_net.py b/gligen/ldm/modules/diffusionmodules/text_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..288bb99290ebe828d0a191ab1a48b640d9f450cc
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/text_grounding_net.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
+class PositionNet(nn.Module):
+ def __init__(self, in_dim, out_dim, fourier_freqs=8):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+ self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy
+
+ self.linears = nn.Sequential(
+ nn.Linear( self.in_dim + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+
+ def forward(self, boxes, masks, positive_embeddings):
+ B, N, _ = boxes.shape
+ masks = masks.unsqueeze(-1)
+
+ # embed box positions (the input may include padding entries used as placeholders)
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
+
+ # learnable null embedding
+ positive_null = self.null_positive_feature.view(1,1,-1)
+ xyxy_null = self.null_position_feature.view(1,1,-1)
+
+ # replace padding with learnable null embedding
+ positive_embeddings = positive_embeddings*masks + (1-masks)*positive_null
+ xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null
+
+ objs = self.linears( torch.cat([positive_embeddings, xyxy_embedding], dim=-1) )
+ assert objs.shape == torch.Size([B,N,self.out_dim])
+ return objs
+
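+# Example shapes (descriptive note, not part of the original file): with fourier_freqs=8, boxes of
+# shape (B, N, 4) yield a (B, N, 64) Fourier position embedding (8 freqs * sin/cos * 4 coords),
+# which is concatenated with the (B, N, in_dim) phrase embeddings and projected to (B, N, out_dim).
+# Entries with masks == 0 (padding) are replaced by the learnable null features beforehand.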
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/text_image_grounding_net.py b/gligen/ldm/modules/diffusionmodules/text_image_grounding_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..d712c70c6ed1318d5977619c825905a3f722a857
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/text_image_grounding_net.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+from ldm.modules.attention import BasicTransformerBlock
+from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder
+import torch.nn.functional as F
+
+
+
+class PositionNet(nn.Module):
+ def __init__(self, in_dim, out_dim, fourier_freqs=8):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+ self.position_dim = fourier_freqs*2*4 # 2 is sin&cos, 4 is xyxy
+
+ # -------------------------------------------------------------- #
+ self.linears_text = nn.Sequential(
+ nn.Linear( self.in_dim + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ self.linears_image = nn.Sequential(
+ nn.Linear( self.in_dim + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear( 512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+
+ # -------------------------------------------------------------- #
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+
+ def forward(self, boxes, masks, text_masks, image_masks, text_embeddings, image_embeddings):
+ B, N, _ = boxes.shape
+ masks = masks.unsqueeze(-1) # B*N*1
+ text_masks = text_masks.unsqueeze(-1) # B*N*1
+ image_masks = image_masks.unsqueeze(-1) # B*N*1
+
+ # embed box positions (the input may include padding entries used as placeholders)
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
+
+ # learnable null embedding
+ text_null = self.null_text_feature.view(1,1,-1) # 1*1*C
+ image_null = self.null_image_feature.view(1,1,-1) # 1*1*C
+ xyxy_null = self.null_position_feature.view(1,1,-1) # 1*1*C
+
+ # replace padding with learnable null embedding
+ text_embeddings = text_embeddings*text_masks + (1-text_masks)*text_null
+ image_embeddings = image_embeddings*image_masks + (1-image_masks)*image_null
+ xyxy_embedding = xyxy_embedding*masks + (1-masks)*xyxy_null
+
+ objs_text = self.linears_text( torch.cat([text_embeddings, xyxy_embedding], dim=-1) )
+ objs_image = self.linears_image( torch.cat([image_embeddings,xyxy_embedding], dim=-1) )
+ objs = torch.cat( [objs_text,objs_image], dim=1 )
+
+ assert objs.shape == torch.Size([B,N*2,self.out_dim])
+ return objs
+
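+# Note (descriptive, not part of the original file): text and image grounding tokens are produced
+# by separate MLPs over the same Fourier box embedding and then concatenated along the token axis,
+# so each of the N boxes contributes two tokens and the output has shape (B, 2*N, out_dim).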
+
+
diff --git a/gligen/ldm/modules/diffusionmodules/util.py b/gligen/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..753ddfbdd20fdfbf9ce72d960fadf76abfbca6d7
--- /dev/null
+++ b/gligen/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,277 @@
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+
+class FourierEmbedder():
+ def __init__(self, num_freqs=64, temperature=100):
+
+ self.num_freqs = num_freqs
+ self.temperature = temperature
+ self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs )
+
+ @ torch.no_grad()
+ def __call__(self, x, cat_dim=-1):
+ "x: arbitrary shape of tensor. dim: cat dim"
+ out = []
+ for freq in self.freq_bands:
+ out.append( torch.sin( freq*x ) )
+ out.append( torch.cos( freq*x ) )
+ return torch.cat(out, cat_dim)
+
+
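+# Example (descriptive note, not part of the original file): with num_freqs=8, calling the embedder
+# on a (B, N, 4) box tensor returns (B, N, 4*2*8) = (B, N, 64) features, since every coordinate is
+# expanded into sin/cos pairs at frequencies temperature ** (k / num_freqs), k = 0..num_freqs-1.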
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
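+# Example (descriptive note, not part of the original file): make_beta_schedule("linear", 1000)
+# returns 1000 betas increasing from linear_start=1e-4 to linear_end=2e-2, interpolated in sqrt
+# space (the parameterization commonly referred to as the "scaled linear" schedule).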
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
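+# Example (descriptive note, not part of the original file): with num_ddpm_timesteps=1000 and
+# num_ddim_timesteps=50 under 'uniform' discretization, c = 20 and the returned steps are
+# [1, 21, 41, ..., 981] after the +1 offset.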
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
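+# Example (descriptive note, not part of the original file): the cosine schedule of Nichol &
+# Dhariwal can be obtained as
+# betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).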
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
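+# Usage sketch (descriptive note, not part of the original file): inside a module's forward one
+# might write
+#   return checkpoint(self._forward, (x, context), self.parameters(), self.use_checkpoint)
+# so that, when the flag is True, activations of _forward are recomputed during backward instead
+# of being stored, trading extra compute for lower memory use.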
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
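+# Example (descriptive note, not part of the original file): timestep_embedding(torch.tensor([0, 500, 999]), dim=320)
+# returns a (3, 320) tensor made of 160 cos and 160 sin channels whose frequencies decay
+# geometrically from 1 towards roughly 1/max_period.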
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+ #return super().forward(x).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/gligen/ldm/modules/distributions/__init__.py b/gligen/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9e4a6a210e228373cf9c9c6f3f9455029c4d145
Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ
diff --git a/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..663f17e84df6d157242a63297c17dc0f4aa7b926
Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ
diff --git a/gligen/ldm/modules/distributions/distributions.py b/gligen/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/gligen/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
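+# Note (descriptive, not part of the original file): `parameters` is expected to stack mean and
+# log-variance along dim=1 (e.g. a (B, 2*C, H, W) encoder output); torch.chunk splits it into a
+# (B, C, H, W) mean and log-variance, and kl()/nll() reduce over the non-batch dims [1, 2, 3].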
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/gligen/ldm/modules/ema.py b/gligen/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5
--- /dev/null
+++ b/gligen/ldm/modules/ema.py
@@ -0,0 +1,76 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
+ else torch.tensor(-1,dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ #remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.','')
+ self.m_name2s_name.update({name:s_name})
+ self.register_buffer(s_name,p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self,model):
+ decay = self.decay
+
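+ # EMA warm-up (descriptive comment, not in the original file): while num_updates is tracked,
+ # the effective decay follows (1 + n) / (10 + n), so early updates move the shadow weights
+ # faster before the rate settles at the configured `decay`.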
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/gligen/ldm/modules/encoders/__init__.py b/gligen/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..732165b4f79c8221b53aaf08739ccd2134a0adff
Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ
diff --git a/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d872efd7c5506a0bb150b9132246e35d4a5c3369
Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ
diff --git a/gligen/ldm/modules/encoders/modules.py b/gligen/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..63eb8244924c71e101e6908f913e1ee51815525e
--- /dev/null
+++ b/gligen/ldm/modules/encoders/modules.py
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import clip
+from einops import rearrange, repeat
+from transformers import CLIPTokenizer, CLIPTextModel
+import kornia
+
+from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class TransformerEmbedder(AbstractEncoder):
+ """Some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+ super().__init__()
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
+
+ def forward(self, tokens):
+ tokens = tokens.to(self.device) # meh
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, x):
+ return self(x)
+
+
+class BERTTokenizer(AbstractEncoder):
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
+ super().__init__()
+ from transformers import BertTokenizerFast  # TODO: add to requirements
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ self.device = device
+ self.vq_interface = vq_interface
+ self.max_length = max_length
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt",
+ return_offsets_mapping=True)
+ tokens = batch_encoding["input_ids"].to(self.device)
+ offset_mapping = batch_encoding["offset_mapping"]
+ return tokens, offset_mapping
+
+ @torch.no_grad()
+ def encode(self, text):
+ tokens = self(text)
+ if not self.vq_interface:
+ return tokens
+ return None, None, [None, None, tokens]
+
+ def decode(self, text):
+ return text
+
+
+class BERTEmbedder(AbstractEncoder):
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ super().__init__()
+ self.use_tknz_fn = use_tokenizer
+ if self.use_tknz_fn:
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout)
+
+ def forward(self, text, return_offset_mapping=False):
+ if self.use_tknz_fn:
+ tokens, offset_mapping = self.tknz_fn(text)#.to(self.device)
+ else:
+ assert False
+ tokens = text
+ z = self.transformer(tokens, return_embeddings=True)
+
+ if return_offset_mapping:
+ return z, offset_mapping
+ else:
+ return z
+
+ def encode(self, text, return_offset_mapping=False):
+ # output of length 77
+ return self(text, return_offset_mapping)
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(self,
+ n_stages=1,
+ method='bilinear',
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None
+ if self.remap_output:
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+
+ def forward(self,x):
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text, return_pooler_output=False):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+
+ if not return_pooler_output:
+ return z
+ else:
+ return z, outputs.pooler_output
+
+ def encode(self, text, return_pooler_output=False):
+ return self(text, return_pooler_output)
+
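+# Usage sketch (descriptive note, not part of the original file; assumes a CUDA device and the
+# openai/clip-vit-large-patch14 weights):
+#   embedder = FrozenCLIPEmbedder().to("cuda")
+#   z = embedder.encode(["a teddy bear on a skateboard"])   # (1, 77, 768)
+# The pooled sentence embedding is also available via encode(text, return_pooler_output=True).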
+
+class FrozenCLIPTextEmbedder(nn.Module):
+ """
+ Uses the CLIP transformer encoder for text.
+ """
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+ super().__init__()
+ self.model, _ = clip.load(version, jit=False, device="cpu")
+ self.device = device
+ self.max_length = max_length
+ self.n_repeat = n_repeat
+ self.normalize = normalize
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = clip.tokenize(text).to(self.device)
+ z = self.model.encode_text(tokens)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ return z
+
+ def encode(self, text):
+ z = self(text)
+ if z.ndim==2:
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ return z
+
+
+class FrozenClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ return self.model.encode_image(self.preprocess(x))
+
+
+if __name__ == "__main__":
+ from ldm.util import count_params
+ model = FrozenCLIPEmbedder()
+ count_params(model, verbose=True)
\ No newline at end of file
diff --git a/gligen/ldm/modules/encoders/modules_backup.py b/gligen/ldm/modules/encoders/modules_backup.py
new file mode 100644
index 0000000000000000000000000000000000000000..ededbe43e9e0466b9979079060692e38f561d4d3
--- /dev/null
+++ b/gligen/ldm/modules/encoders/modules_backup.py
@@ -0,0 +1,234 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import clip
+from einops import rearrange, repeat
+from transformers import CLIPTokenizer, CLIPTextModel
+import kornia
+
+from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class TransformerEmbedder(AbstractEncoder):
+ """Some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+ super().__init__()
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
+
+ def forward(self, tokens):
+ tokens = tokens.to(self.device) # meh
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, x):
+ return self(x)
+
+
+class BERTTokenizer(AbstractEncoder):
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
+ super().__init__()
+ from transformers import BertTokenizerFast  # TODO: add to requirements
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ self.device = device
+ self.vq_interface = vq_interface
+ self.max_length = max_length
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ return tokens
+
+ @torch.no_grad()
+ def encode(self, text):
+ tokens = self(text)
+ if not self.vq_interface:
+ return tokens
+ return None, None, [None, None, tokens]
+
+ def decode(self, text):
+ return text
+
+
+class BERTEmbedder(AbstractEncoder):
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ super().__init__()
+ self.use_tknz_fn = use_tokenizer
+ if self.use_tknz_fn:
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout)
+
+ def forward(self, text):
+ if self.use_tknz_fn:
+ tokens = self.tknz_fn(text)#.to(self.device)
+ else:
+ tokens = text
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, text):
+ # output of length 77
+ return self(text)
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(self,
+ n_stages=1,
+ method='bilinear',
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None
+ if self.remap_output:
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+
+ def forward(self,x):
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPTextEmbedder(nn.Module):
+ """
+ Uses the CLIP transformer encoder for text.
+ """
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+ super().__init__()
+ self.model, _ = clip.load(version, jit=False, device="cpu")
+ self.device = device
+ self.max_length = max_length
+ self.n_repeat = n_repeat
+ self.normalize = normalize
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = clip.tokenize(text).to(self.device)
+ z = self.model.encode_text(tokens)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ return z
+
+ def encode(self, text):
+ z = self(text)
+ if z.ndim==2:
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ return z
+
+
+class FrozenClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ return self.model.encode_image(self.preprocess(x))
+
+
+if __name__ == "__main__":
+ from ldm.util import count_params
+ model = FrozenCLIPEmbedder()
+ count_params(model, verbose=True)
\ No newline at end of file
diff --git a/gligen/ldm/modules/image_degradation/__init__.py b/gligen/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/gligen/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/gligen/ldm/modules/image_degradation/bsrgan.py b/gligen/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/gligen/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernel's size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate the Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 0.5.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int): Residual threshold (on a 0-255 scale) for the sharpening mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
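+# Pipeline summary (descriptive note, not part of the original file): blur, resize and noise stages
+# run in a random order (the final sf-downsample is always kept after the intermediate resize), a
+# JPEG compression pass is always appended at the end, and the function returns an aligned (lq, hq)
+# pair in which the hq patch is sf times larger than the lq patch.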
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ example: dict with key "image" holding the degraded low-quality image (uint8, HxWxC, range [0, 255])
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop (height first, then width)
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
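+# Usage sketch (placeholder path; the function expects a uint8 image and converts it internally):
+#
+#   img = util.imread_uint('path/to/image.png', 3)                 # uint8, HxWx3, RGB
+#   lq = degradation_bsrgan_variant(img, sf=4)["image"]            # uint8, roughly 1/sf of the input size
+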
+
+ # TODO: in case of a pickle error, replace "a += x" with "a = a + x" in add_speckle_noise etc.
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+ img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ shuffle_prob: probability of globally shuffling the degradation order
+ use_sharp: whether to sharpen the image first
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop (height first, then width)
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
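+# Usage sketch (same float [0, 1] input convention as degradation_bsrgan):
+#
+#   lq, hq = degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64)
+#   # lq: 64x64x3, hq: 256x256x3 (sharpened), both in [0, 1]
+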
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ # keep img as uint8 here: degradation_bsrgan_variant converts to single internally
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = util.uint2single(img)
+ img_lq = util.uint2single(deg_fn(img)["image"])
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/gligen/ldm/modules/image_degradation/bsrgan_light.py b/gligen/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1f823996bf559e9b015ea9aa2b3cd38dd13af1
--- /dev/null
+++ b/gligen/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,650 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : rotation angle, in [0, pi]
+ l1 : scaling of the first eigenvalue, in [0.1, 50]
+ l2 : scaling of the second eigenvalue, in [0.1, l1]
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
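+# Example sketch: a 15x15 kernel elongated along a 45-degree axis (l1 >> l2 gives strong anisotropy):
+#
+#   k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)
+#   assert abs(k.sum() - 1.0) < 1e-6   # kernels are normalized to sum to 1
+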
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
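+# Example sketch: blur a batch of 2 RGB images, one 15x15 Gaussian kernel per image (N x 1 x h x w):
+#
+#   x = torch.rand(2, 3, 64, 64)
+#   k = torch.from_numpy(fspecial('gaussian', 15, 2.0)).float()[None, None].repeat(2, 1, 1, 1)
+#   y = blur(x, k)   # y.shape == (2, 3, 64, 64)
+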
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate the Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
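+# Example sketch: sample one random 15x15 SR blur kernel for scale factor 4:
+#
+#   k = gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.)
+#   # k.shape == (15, 15) and k sums to 1 (up to floating point)
+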
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < np.finfo(float).eps * h.max()] = 0 # np.finfo: scipy.finfo was deprecated and removed in recent SciPy
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
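+# Example sketch: classical blur + downsample of a float image x in [0, 1] with an isotropic Gaussian:
+#
+#   k = fspecial('gaussian', 15, 1.6)
+#   x_lr = classical_degradation(x, k, sf=4)   # HxWxC -> ceil(H/4) x ceil(W/4) x C
+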
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) * 255 > threshold, else 0
+ 3. Blur the mask to get a soft mask.
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 0.5.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int): Mask threshold on the 0-255 residual scale. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
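+# Example sketch: USM-sharpen a float RGB image in [0, 1]; only pixels whose blur residual exceeds
+# `threshold` (on the 0-255 scale) are sharpened, blended through a soft mask:
+#
+#   img_sharp = add_sharpening(img, weight=0.5, radius=50, threshold=10)
+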
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop (height first, then width)
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ image: uint8 HxWxC image, range [0, 255]
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ example: dict with key "image": the degraded low-quality image, uint8, roughly 1/sf of the input size
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop (height first, then width)
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ # i == 1 (a second blur pass) is intentionally disabled in this light variant
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/gligen/ldm/modules/image_degradation/utils_image.py b/gligen/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/gligen/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ import matplotlib.pyplot as plt # lazy import, since the module-level import is commented out above
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ import matplotlib.pyplot as plt # lazy import, since the module-level import is commented out above
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image pathes
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int)) # np.int was removed in NumPy 1.24+
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
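+# Example sketch: a 1200x900 image is split into overlapping 512x512 patches; images that are not
+# larger than p_max in both dimensions are returned unchanged:
+#
+#   patches = patches_from_image(img, p_size=512, p_overlap=64, p_max=800)
+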
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+ split the large images from original_dataroot into small overlapping images of size (p_size)x(p_size),
+ and save them into target_dataroot; only images larger than (p_max)x(p_max) will be split.
+ Args:
+ original_dataroot: directory with the source images
+ target_dataroot: directory to save the patches to
+ p_size: size of the small images
+ p_overlap: overlap between patches; the patch size used in training is a good choice
+ p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+ # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img = img.astype(np.float32) # astype returns a copy, so the caller's array is not modified in place
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img = img.astype(np.float32) # astype returns a copy, so the caller's array is not modified in place
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img = img.astype(np.float32) # astype returns a copy, so the caller's array is not modified in place
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
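+# Example sketch: PSNR between a uint8-range image and a noisy copy (inputs are expected in [0, 255]):
+#
+#   img1 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
+#   img2 = np.clip(img1 + np.random.normal(0, 5, img1.shape), 0, 255)
+#   psnr = calculate_psnr(img1, img2)   # roughly 34 dB for sigma = 5
+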
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+ # Use a modified kernel to simultaneously interpolate and antialias (larger kernel width)
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
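+# Example sketch: MATLAB-compatible bicubic 4x downscale of a CHW tensor in [0, 1]:
+#
+#   x = torch.rand(3, 64, 64)
+#   y = imresize(x, 1 / 4)   # y.shape == (3, 16, 16)
+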
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/gligen/ldm/modules/losses/__init__.py b/gligen/ldm/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5
--- /dev/null
+++ b/gligen/ldm/modules/losses/__init__.py
@@ -0,0 +1 @@
+from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/gligen/ldm/modules/losses/contperceptual.py b/gligen/ldm/modules/losses/contperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e
--- /dev/null
+++ b/gligen/ldm/modules/losses/contperceptual.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+
+from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge"):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
+
diff --git a/gligen/ldm/modules/losses/vqperceptual.py b/gligen/ldm/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306
--- /dev/null
+++ b/gligen/ldm/modules/losses/vqperceptual.py
@@ -0,0 +1,167 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+from ldm.util import exists  # needed by forward(), which calls exists(codebook_loss)
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
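+# adopt_weight gates the discriminator term: it returns `value` (typically 0) until
+# global_step reaches `threshold`, and the configured `weight` afterwards.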
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+def l1(x, y):
+ return torch.abs(x-y)
+
+
+def l2(x, y):
+ return torch.pow((x-y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+ pixel_loss="l1"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ assert perceptual_loss in ["lpips", "clips", "dists"]
+ assert pixel_loss in ["l1", "l2"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ if perceptual_loss == "lpips":
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
+ self.perceptual_loss = LPIPS().eval()
+ else:
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+ self.perceptual_weight = perceptual_weight
+
+ if pixel_loss == "l1":
+ self.pixel_loss = l1
+ else:
+ self.pixel_loss = l2
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.n_classes = n_classes
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
+ if not exists(codebook_loss):
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if predicted_indices is not None:
+ assert self.n_classes is not None
+ with torch.no_grad():
+ perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
+ log[f"{split}/perplexity"] = perplexity
+ log[f"{split}/cluster_usage"] = cluster_usage
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/gligen/ldm/modules/x_transformer.py b/gligen/ldm/modules/x_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576
--- /dev/null
+++ b/gligen/ldm/modules/x_transformer.py
@@ -0,0 +1,641 @@
+"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat, reduce
+
+# constants
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple('Intermediates', [
+ 'pre_softmax_attn',
+ 'post_softmax_attn'
+])
+
+LayerIntermediates = namedtuple('Intermediates', [
+ 'hiddens',
+ 'attn_intermediates'
+])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+
+# classes
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, 'b n d -> (b n) d'),
+ rearrange(residual, 'b n d -> (b n) d')
+ )
+
+ return gated_output.reshape_as(x)
+
+
+# feedforward
+
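+# GEGLU doubles the projection width and uses one half as a GELU gate on the other half
+# (the gated-linear-unit feed-forward variant, enabled via FeedForward(glu=True) below).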
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# attention.
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ # talking heads
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ # explicit topk sparse attention
+ self.sparse_topk = sparse_topk
+
+ # entmax
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ self.attn_fn = F.softmax
+
+ # add memory key / values
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ # attention on attention
+ self.attn_on_attn = on_attn
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None
+ ):
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ # in shortformer, the query would start at a position offset depending on the past cached memory
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
+ self.rotary_pos_emb = always(None)
+
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ('a', 'c', 'f')
+ elif cross_attend and only_cross:
+ default_block = ('c', 'f')
+ else:
+ default_block = ('a', 'f')
+
+ if macaron:
+ default_block = ('f',) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+ default_block = tuple(filter(not_equals('f'), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+ par_block = default_block + ('f',) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == 'a':
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == 'c':
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == 'f':
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f'invalid layer type {layer_type}')
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([
+ norm_fn(),
+ layer,
+ residual_fn
+ ]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == 'a':
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == 'a':
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
+ prev_attn=prev_attn, mem=layer_mem)
+ elif layer_type == 'c':
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
+ elif layer_type == 'f':
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ('a', 'c'):
+ intermediates.append(inter)
+
+ if layer_type == 'a' and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == 'c' and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens,
+ attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
+ super().__init__(causal=False, **kwargs)
+
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
+ ):
+ super().__init__()
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+
+ # memory tokens (like [cls]) from Memory Transformers paper
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ # let funnel encoder know number of memory tokens, if specified
+ if hasattr(attn_layers, 'num_memory_tokens'):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+ return out, attn_maps
+
+ return out
+
diff --git a/gligen/ldm/util.py b/gligen/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..51839cb1478d9fecb293277dc83d2693e3d26de4
--- /dev/null
+++ b/gligen/ldm/util.py
@@ -0,0 +1,86 @@
+import importlib
+
+import torch
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
\ No newline at end of file
diff --git a/gligen/projection_matrix.pth b/gligen/projection_matrix.pth
new file mode 100644
index 0000000000000000000000000000000000000000..569755bac8eefc56a2f770cb7a3b53a28a14c14b
--- /dev/null
+++ b/gligen/projection_matrix.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45502f972194bf4e8f00d0ded08cad902587a7fc3ad2a0f5f005d058f79b3035
+size 132
diff --git a/gligen/task_grounded_generation.py b/gligen/task_grounded_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b7d7d276a15271047f18ed00599677f4de7b54e
--- /dev/null
+++ b/gligen/task_grounded_generation.py
@@ -0,0 +1,343 @@
+import argparse
+from PIL import Image, ImageDraw
+from evaluator import Evaluator
+from omegaconf import OmegaConf
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+import os
+from transformers import CLIPProcessor, CLIPModel
+from copy import deepcopy
+import torch
+from ldm.util import instantiate_from_config
+from trainer import read_official_ckpt, batch_to_device
+from evaluator import set_alpha_scale, save_images, draw_masks_from_boxes
+import numpy as np
+import clip
+from functools import partial
+import torchvision.transforms.functional as F
+import random
+
+
+device = "cuda"
+
+
+def alpha_generator(length, type=[1,0,0]):
+ """
+    length is the total number of timesteps needed for sampling.
+    type should be a list of three values that sum to 1.
+
+    The values give the fraction of steps spent in each of three stages:
+    the alpha=1 stage,
+    the linear decay stage,
+    and the alpha=0 stage.
+
+    For example, if length=100 and type=[0.8, 0.1, 0.1], then alpha is 1 for the
+    first 80 steps, linearly decays to 0 over the next 10 steps, and is 0 for the
+    last 10 steps.
+ """
+
+ assert len(type)==3
+ assert type[0] + type[1] + type[2] == 1
+
+ stage0_length = int(type[0]*length)
+ stage1_length = int(type[1]*length)
+ stage2_length = length - stage0_length - stage1_length
+
+ if stage1_length != 0:
+ decay_alphas = np.arange(start=0, stop=1, step=1/stage1_length)[::-1]
+ decay_alphas = list(decay_alphas)
+ else:
+ decay_alphas = []
+
+
+ alphas = [1]*stage0_length + decay_alphas + [0]*stage2_length
+
+ assert len(alphas) == length
+
+ return alphas
+
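+# Sanity check of the schedule above: alpha_generator(100, [0.8, 0.1, 0.1]) yields
+# 80 ones, a 10-step linear decay from 0.9 down to 0.0, then 10 zeros.
+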
+
+def draw_box(img, locations):
+ colors = ["red", "green", "blue", "olive", "orange", "brown", "cyan", "purple"]
+ draw = ImageDraw.Draw(img)
+ WW,HH = img.size
+ for bid, box in enumerate(locations):
+ draw.rectangle([box[0]*WW, box[1]*HH, box[2]*WW, box[3]*HH], outline =colors[bid % len(colors)], width=5)
+ return img
+
+def load_common_ckpt(config, common_ckpt):
+ autoencoder = instantiate_from_config(config.autoencoder).to(device).eval()
+ text_encoder = instantiate_from_config(config.text_encoder).to(device).eval()
+ diffusion = instantiate_from_config(config.diffusion).to(device)
+
+ autoencoder.load_state_dict( common_ckpt["autoencoder"] )
+ text_encoder.load_state_dict( common_ckpt["text_encoder"] )
+ diffusion.load_state_dict( common_ckpt["diffusion"] )
+
+ return [autoencoder, text_encoder, diffusion]
+
+def load_ckpt(config, state_dict, common_instances):
+
+ model = instantiate_from_config(config.model).to(device)
+
+ model.load_state_dict(state_dict['model'])
+ set_alpha_scale(model, config.alpha_scale)
+
+ print("ckpt is loaded")
+
+ return [model] + common_instances
+
+
+
+
+def project(x, projection_matrix):
+ """
+    x (Batch, 768) should be the penultimate CLIP feature (before projection).
+    projection_matrix (768, 768) is the CLIP projection matrix, i.e. the weight.data of the
+    Linear layer defined in CLIP with shape (out_dim, in_dim); hence the transpose below.
+    This function returns the CLIP feature (without normalization).
+ """
+ return x@torch.transpose(projection_matrix, 0, 1)
+
+@torch.no_grad()
+def get_clip_feature(model, processor, input, is_image=False):
+ feature_type = ['before','after_reproject'] # text feature, image feature
+
+ if is_image:
+ image = input #Image.open(input).convert("RGB")
+ inputs = processor(images=[image], return_tensors="pt", padding=True)
+ inputs['pixel_values'] = inputs['pixel_values'].cuda() # we use our own preprocessing without center_crop
+ inputs['input_ids'] = torch.tensor([[0,1,2,3]]).cuda() # placeholder
+ outputs = model(**inputs)
+ feature = outputs.image_embeds
+ if feature_type[1] == 'after_renorm':
+ feature = feature*28.7
+ if feature_type[1] == 'after_reproject':
+ feature = project( feature, torch.load('gligen/projection_matrix.pth').cuda().T ).squeeze(0)
+ feature = ( feature / feature.norm() ) * 28.7
+ feature = feature.unsqueeze(0)
+ else:
+ inputs = processor(text=input, return_tensors="pt", padding=True)
+ inputs['input_ids'] = inputs['input_ids'].cuda()
+ inputs['pixel_values'] = torch.ones(1,3,224,224).cuda() # placeholder
+ inputs['attention_mask'] = inputs['attention_mask'].cuda()
+ outputs = model(**inputs)
+ feature = outputs.text_embeds if feature_type[0] == 'after' else outputs.text_model_output.pooler_output
+ return feature
+
+
+
+def complete_mask(has_mask, max_objs):
+ mask = torch.ones(1,max_objs)
+ if type(has_mask) == int or type(has_mask) == float:
+ return mask * has_mask
+ else:
+ for idx, value in enumerate(has_mask):
+ mask[0,idx] = value
+ return mask
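+# Note: complete_mask broadcasts a scalar has_mask to all max_objs slots; given a list it
+# only overwrites the first len(has_mask) entries and leaves the rest at 1 (unused slots
+# are zeroed out anyway when multiplied by the per-object `masks` tensor in fire_clip).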
+
+
+
+@torch.no_grad()
+def fire_clip(text_encoder, meta, batch=1, max_objs=30, clip_model=None):
+ # import pdb; pdb.set_trace()
+ phrases = meta["phrases"]
+ images = meta["images"]
+
+ if clip_model is None:
+ version = "openai/clip-vit-large-patch14"
+ model = CLIPModel.from_pretrained(version).cuda()
+ processor = CLIPProcessor.from_pretrained(version)
+ else:
+ version = "openai/clip-vit-large-patch14"
+ assert clip_model['version'] == version
+ model = clip_model['model']
+ processor = clip_model['processor']
+
+ boxes = torch.zeros(max_objs, 4)
+ masks = torch.zeros(max_objs)
+ text_embeddings = torch.zeros(max_objs, 768)
+ image_embeddings = torch.zeros(max_objs, 768)
+
+
+ text_features = []
+ image_features = []
+ for phrase, image in zip(phrases,images):
+ text_features.append( get_clip_feature(model, processor, phrase, is_image=False) )
+ image_features.append( get_clip_feature(model, processor, image, is_image=True) )
+
+ if len(text_features) > 0:
+ text_features = torch.cat(text_features, dim=0)
+ image_features = torch.cat(image_features, dim=0)
+
+ for idx, (box, text_feature, image_feature) in enumerate(zip( meta['locations'], text_features, image_features)):
+ boxes[idx] = torch.tensor(box)
+ masks[idx] = 1
+ text_embeddings[idx] = text_feature
+ image_embeddings[idx] = image_feature
+
+
+ out = {
+ "boxes" : boxes.unsqueeze(0).repeat(batch,1,1),
+ "masks" : masks.unsqueeze(0).repeat(batch,1),
+ "text_masks" : masks.unsqueeze(0).repeat(batch,1)*complete_mask( meta["has_text_mask"], max_objs ),
+ "image_masks" : masks.unsqueeze(0).repeat(batch,1)*complete_mask( meta["has_image_mask"], max_objs ),
+ "text_embeddings" : text_embeddings.unsqueeze(0).repeat(batch,1,1),
+ "image_embeddings" : image_embeddings.unsqueeze(0).repeat(batch,1,1)
+ }
+ return batch_to_device(out, device)
+
+def remove_numbers(text):
+ result = ''.join([char for char in text if not char.isdigit()])
+ return result
+def process_box_phrase(names, bboxes):
+ d = {}
+ for i, phrase in enumerate(names):
+ phrase = phrase.replace('_',' ')
+ list_noun = phrase.split(' ')
+ for n in list_noun:
+ n = remove_numbers(n)
+ if not n in d.keys():
+ d.update({n:[np.array(bboxes[i])]})
+ else:
+ d[n].append(np.array(bboxes[i]))
+ return d
+
+def Pharse2idx_2(prompt, name_box):
+ prompt = prompt.replace('.','')
+ prompt = prompt.replace(',','')
+ prompt_list = prompt.strip('.').split(' ')
+ object_positions = []
+ bbox_to_self_att = []
+ for obj in name_box.keys():
+ obj_position = []
+ in_prompt = False
+ for word in obj.split(' '):
+ if word in prompt_list:
+ obj_first_index = prompt_list.index(word) + 1
+ obj_position.append(obj_first_index)
+ in_prompt = True
+ elif word +'s' in prompt_list:
+ obj_first_index = prompt_list.index(word+'s') + 1
+ obj_position.append(obj_first_index)
+ in_prompt = True
+ elif word +'es' in prompt_list:
+ obj_first_index = prompt_list.index(word+'es') + 1
+ obj_position.append(obj_first_index)
+ in_prompt = True
+ if in_prompt :
+ bbox_to_self_att.append(np.array(name_box[obj]))
+
+ object_positions.append(obj_position)
+
+ return object_positions, bbox_to_self_att
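+# Example: with prompt "a cat on the right of a dog" and
+# name_box = process_box_phrase(['cat', 'dog'], [[0.6, 0.2, 0.9, 0.8], [0.1, 0.2, 0.4, 0.8]]),
+# Pharse2idx_2 returns object_positions [[2], [8]] (the 1-based word positions of each
+# grounded noun in the prompt) together with the corresponding boxes.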
+
+
+
+
+
+# @torch.no_grad()
+def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs):
+
+ # -------------- prepare model and misc --------------- #
+
+ model, autoencoder, text_encoder, diffusion = loaded_model_list
+
+ batch_size = instruction["batch_size"]
+    is_inpaint = "input_image" in instruction
+ save_folder = os.path.join("create_samples", instruction["save_folder_name"])
+
+
+ # -------------- set seed if required --------------- #
+ if instruction.get('fix_seed', False):
+ random_seed = instruction['rand_seed']
+ random.seed(random_seed)
+ np.random.seed(random_seed)
+ torch.manual_seed(random_seed)
+
+ # ------------- prepare input for the model ------------- #
+ with torch.no_grad():
+ batch = fire_clip(text_encoder, instruction, batch_size, clip_model=kwargs.get('clip_model', None))
+ context = text_encoder.encode( [instruction["prompt"]]*batch_size )
+ uc = text_encoder.encode( batch_size*[""] )
+ name_box = process_box_phrase(instruction['phrases'], instruction['locations'])
+
+ position, box_att = Pharse2idx_2(instruction['prompt'],name_box )
+ input = dict(x = None,
+ timesteps = None,
+ context = context,
+ boxes = batch['boxes'],
+ masks = batch['masks'],
+ text_masks = batch['text_masks'],
+ image_masks = batch['image_masks'],
+ text_embeddings = batch["text_embeddings"],
+ image_embeddings = batch["image_embeddings"],
+ boxes_att=box_att,
+ object_position = position )
+
+ inpainting_mask = x0 = None # used for inpainting
+ if is_inpaint:
+ input_image = F.pil_to_tensor( instruction["input_image"] )
+ input_image = ( input_image.float().unsqueeze(0).cuda() / 255 - 0.5 ) / 0.5
+ x0 = autoencoder.encode( input_image )
+ if instruction["actual_mask"] is not None:
+ inpainting_mask = instruction["actual_mask"][None, None].expand(batch['boxes'].shape[0], -1, -1, -1).cuda()
+ else:
+ actual_boxes = [instruction['inpainting_boxes_nodrop'] for _ in range(batch['boxes'].shape[0])]
+ inpainting_mask = draw_masks_from_boxes(actual_boxes, (x0.shape[-2], x0.shape[-1]) ).cuda()
+ masked_x0 = x0*inpainting_mask
+ inpainting_extra_input = torch.cat([masked_x0,inpainting_mask], dim=1)
+ input["inpainting_extra_input"] = inpainting_extra_input
+
+
+ # ------------- prepare sampler ------------- #
+ alpha_generator_func = partial(alpha_generator, type=instruction["alpha_type"])
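+    # Set the condition below to True to use the 250-step DDIM sampler instead of the
+    # default 50-step PLMS sampler.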
+ if False:
+ sampler = DDIMSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
+ steps = 250
+ else:
+ sampler = PLMSSampler(diffusion, model, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)
+ steps = 50
+
+ # ------------- run sampler ... ------------- #
+ shape = (batch_size, model.in_channels, model.image_size, model.image_size)
+ samples_fake = sampler.sample(S=steps, shape=shape, input=input, uc=uc, guidance_scale=instruction['guidance_scale'], mask=inpainting_mask, x0=x0)
+ with torch.no_grad():
+ samples_fake = autoencoder.decode(samples_fake)
+
+
+ # ------------- other logistics ------------- #
+
+ sample_list = []
+ for sample in samples_fake:
+ sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
+ sample = sample.cpu().numpy().transpose(1,2,0) * 255
+ sample = Image.fromarray(sample.astype(np.uint8))
+ sample_list.append(sample)
+
+ return sample_list, None
+
+
+
+# if __name__ == "__main__":
+
+
+# parser = argparse.ArgumentParser()
+# parser.add_argument("--folder", type=str, default="create_samples", help="path to OUTPUT")
+# parser.add_argument("--official_ckpt", type=str, default='../../../data/sd-v1-4.ckpt', help="")
+
+# parser.add_argument("--batch_size", type=int, default=10, help="This will overwrite the one in yaml.")
+# parser.add_argument("--no_plms", action='store_true')
+# parser.add_argument("--guidance_scale", type=float, default=5, help="")
+# parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
+# args = parser.parse_args()
+
+# assert "sd-v1-4.ckpt" in args.official_ckpt, "only support for stable-diffusion model"
+
+
+# grounded_generation(args)
+
+
+
+
+
diff --git a/gligen/trainer.py b/gligen/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0815b7b25579001be44674e6fa2afa2a7d9e79b0
--- /dev/null
+++ b/gligen/trainer.py
@@ -0,0 +1,456 @@
+import torch
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
+import numpy as np
+import random
+import time
+from dataset.concat_dataset import ConCatDataset #, collate_fn
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+import os
+import shutil
+import torchvision
+import math
+from torch.nn.parallel import DistributedDataParallel as DDP
+from tqdm import tqdm
+from distributed import get_rank, synchronize, get_world_size
+from transformers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
+from copy import deepcopy
+try:
+ from apex import amp
+except ImportError:
+ pass
+# = = = = = = = = = = = = = = = = = = useful functions = = = = = = = = = = = = = = = = = #
+
+class ImageCaptionSaver:
+ def __init__(self, base_path, nrow=8, normalize=True, scale_each=True, range=(-1,1) ):
+ self.base_path = base_path
+ self.nrow = nrow
+ self.normalize = normalize
+ self.scale_each = scale_each
+ self.range = range
+
+ def __call__(self, images, real, captions, seen):
+
+ save_path = os.path.join(self.base_path, str(seen).zfill(8)+'.png')
+ torchvision.utils.save_image( images, save_path, nrow=self.nrow, normalize=self.normalize, scale_each=self.scale_each, range=self.range )
+
+ save_path = os.path.join(self.base_path, str(seen).zfill(8)+'_real.png')
+ torchvision.utils.save_image( real, save_path, nrow=self.nrow)
+
+ assert images.shape[0] == len(captions)
+
+ save_path = os.path.join(self.base_path, 'captions.txt')
+ with open(save_path, "a") as f:
+ f.write( str(seen).zfill(8) + ':\n' )
+ for cap in captions:
+ f.write( cap + '\n' )
+ f.write( '\n' )
+
+
+
+def read_official_ckpt(ckpt_path):
+ "Read offical pretrained ckpt and convert into my style"
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+ out = {}
+ out["model"] = {}
+ out["text_encoder"] = {}
+ out["autoencoder"] = {}
+ out["unexpected"] = {}
+ out["diffusion"] = {}
+
+ for k,v in state_dict.items():
+ if k.startswith('model.diffusion_model'):
+ out["model"][k.replace("model.diffusion_model.", "")] = v
+ elif k.startswith('cond_stage_model'):
+ out["text_encoder"][k.replace("cond_stage_model.", "")] = v
+ elif k.startswith('first_stage_model'):
+ out["autoencoder"][k.replace("first_stage_model.", "")] = v
+ elif k in ["model_ema.decay", "model_ema.num_updates"]:
+ out["unexpected"][k] = v
+ else:
+ out["diffusion"][k] = v
+ return out
+
+
+def batch_to_device(batch, device):
+ for k in batch:
+ if isinstance(batch[k], torch.Tensor):
+ batch[k] = batch[k].to(device)
+ return batch
+
+
+def sub_batch(batch, num=1):
+ # choose first num in given batch
+ num = num if num > 1 else 1
+ for k in batch:
+ batch[k] = batch[k][0:num]
+ return batch
+
+
+def wrap_loader(loader):
+ while True:
+        for batch in loader: # TODO: it seems the sample order is identical across epochs?
+ yield batch
+
+
+def disable_grads(model):
+ for p in model.parameters():
+ p.requires_grad = False
+
+
+def count_params(params):
+ total_trainable_params_count = 0
+ for p in params:
+ total_trainable_params_count += p.numel()
+ print("total_trainable_params_count is: ", total_trainable_params_count)
+
+
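+# Exponential moving average of the model weights: targ <- rate * targ + (1 - rate) * src.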
+def update_ema(target_params, source_params, rate=0.99):
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def create_expt_folder_with_auto_resuming(OUTPUT_ROOT, name):
+ #curr_folder_name = os.getcwd().split("/")[-1]
+ name = os.path.join( OUTPUT_ROOT, name )
+ writer = None
+ checkpoint = None
+
+ if os.path.exists(name):
+ all_tags = os.listdir(name)
+ all_existing_tags = [ tag for tag in all_tags if tag.startswith('tag') ]
+ all_existing_tags.sort()
+ all_existing_tags = all_existing_tags[::-1]
+ for previous_tag in all_existing_tags:
+ potential_ckpt = os.path.join( name, previous_tag, 'checkpoint_latest.pth' )
+ if os.path.exists(potential_ckpt):
+ checkpoint = potential_ckpt
+ if get_rank() == 0:
+ print('ckpt found '+ potential_ckpt)
+ break
+ curr_tag = 'tag'+str(len(all_existing_tags)).zfill(2)
+ name = os.path.join( name, curr_tag ) # output/name/tagxx
+ else:
+ name = os.path.join( name, 'tag00' ) # output/name/tag00
+
+ if get_rank() == 0:
+ os.makedirs(name)
+ os.makedirs( os.path.join(name,'Log') )
+ writer = SummaryWriter( os.path.join(name,'Log') )
+
+ return name, writer, checkpoint
+
+
+
+# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
+# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
+# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
+
+
+
+
+
+
+class Trainer:
+ def __init__(self, config):
+
+ self.config = config
+ self.device = torch.device("cuda")
+
+ self.l_simple_weight = 1
+ self.name, self.writer, checkpoint = create_expt_folder_with_auto_resuming(config.OUTPUT_ROOT, config.name)
+ if get_rank() == 0:
+ shutil.copyfile(config.yaml_file, os.path.join(self.name, "train_config_file.yaml") )
+ torch.save( vars(config), os.path.join(self.name, "config_dict.pth") )
+
+ # = = = = = = = = = = create model and diffusion = = = = = = = = = = #
+ self.model = instantiate_from_config(config.model).to(self.device)
+ self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
+ self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
+ self.diffusion = instantiate_from_config(config.diffusion).to(self.device)
+
+
+ state_dict = read_official_ckpt( os.path.join(config.DATA_ROOT, config.official_ckpt_name) )
+ missing_keys, unexpected_keys = self.model.load_state_dict( state_dict["model"], strict=False )
+ assert unexpected_keys == []
+ original_params_names = list( state_dict["model"].keys() )
+ self.autoencoder.load_state_dict( state_dict["autoencoder"] )
+ self.text_encoder.load_state_dict( state_dict["text_encoder"] )
+ self.diffusion.load_state_dict( state_dict["diffusion"] )
+
+ self.autoencoder.eval()
+ self.text_encoder.eval()
+ disable_grads(self.autoencoder)
+ disable_grads(self.text_encoder)
+
+
+
+ # = = load from ckpt: (usually second stage whole model finetune) = = #
+ if self.config.ckpt is not None:
+ first_stage_ckpt = torch.load(self.config.ckpt, map_location="cpu")
+ self.model.load_state_dict(first_stage_ckpt["model"])
+
+
+
+
+ # = = = = = = = = = = create opt = = = = = = = = = = #
+ print(" ")
+ print("IMPORTANT: following code decides which params trainable!")
+ print(" ")
+
+ if self.config.whole:
+ print("Entire model is trainable")
+ params = list(self.model.parameters())
+ else:
+ print("Only new added components will be updated")
+ params = []
+ trainable_names = []
+ for name, p in self.model.named_parameters():
+ if ("transformer_blocks" in name) and ("fuser" in name):
+ params.append(p)
+ trainable_names.append(name)
+ elif "position_net" in name:
+ params.append(p)
+ trainable_names.append(name)
+ else:
+                    # all newly added trainable params have to be handled above,
+                    # otherwise the assert below will fail
+ assert name in original_params_names, name
+
+ all_params_name = list( self.model.state_dict().keys() )
+ assert set(all_params_name) == set(trainable_names + original_params_names)
+
+ self.opt = torch.optim.AdamW(params, lr=config.base_learning_rate, weight_decay=config.weight_decay)
+ count_params(params)
+
+        self.master_params = list(self.model.parameters()) # note: do not use `params` from above as master_params, since it holds only the trainable subset
+
+ if config.enable_ema:
+ self.ema = deepcopy(self.model)
+ self.ema_params = list(self.ema.parameters())
+ self.ema.eval()
+
+ # = = = = = = = = = = create scheduler = = = = = = = = = = #
+ if config.scheduler_type == "cosine":
+ self.scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_iters)
+ elif config.scheduler_type == "constant":
+ self.scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps)
+ else:
+ assert False
+
+
+
+ # = = = = = = = = = = create data = = = = = = = = = = #
+ train_dataset_repeats = config.train_dataset_repeats if 'train_dataset_repeats' in config else None
+ dataset_train = ConCatDataset(config.train_dataset_names, config.DATA_ROOT, config.which_embedder, train=True, repeats=train_dataset_repeats)
+ sampler = DistributedSampler(dataset_train) if config.distributed else None
+ loader_train = DataLoader( dataset_train, batch_size=config.batch_size,
+ shuffle=(sampler is None),
+ num_workers=config.workers,
+ pin_memory=True,
+ sampler=sampler)
+ self.dataset_train = dataset_train
+ self.loader_train = wrap_loader(loader_train)
+
+ if get_rank() == 0:
+ total_image = dataset_train.total_images()
+ print("Total training images: ", total_image)
+
+
+ # = = = = = = = = = = load from autoresuming ckpt = = = = = = = = = = #
+ self.starting_iter = 0
+ if checkpoint is not None:
+ checkpoint = torch.load(checkpoint, map_location="cpu")
+ self.model.load_state_dict(checkpoint["model"])
+ if config.enable_ema:
+ self.ema.load_state_dict(checkpoint["ema"])
+ self.opt.load_state_dict(checkpoint["opt"])
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
+ self.starting_iter = checkpoint["iters"]
+ if self.starting_iter >= config.total_iters:
+ synchronize()
+ print("Training finished. Start exiting")
+ exit()
+
+
+ # = = = = = misc = = = = = #
+ if get_rank() == 0:
+ print("Actual total need see images is: ", config.total_iters*config.total_batch_size)
+ print("Equivalent training epoch is: ", (config.total_iters*config.total_batch_size) / len(dataset_train) )
+ self.image_caption_saver = ImageCaptionSaver(self.name)
+ # self.counter = Counter(config.total_batch_size, config.save_every_images)
+
+ if config.use_o2:
+ self.model, self.opt = amp.initialize(self.model, self.opt, opt_level="O2")
+ self.model.use_o2 = True
+
+
+ # = = = = = wrap into ddp = = = = = #
+ if config.distributed:
+ self.model = DDP( self.model, device_ids=[config.local_rank], output_device=config.local_rank, broadcast_buffers=False )
+
+
+
+
+
+ @torch.no_grad()
+ def get_input(self, batch):
+
+ z = self.autoencoder.encode( batch["image"] )
+
+ context = self.text_encoder.encode( batch["caption"] )
+
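+        # Timesteps are drawn with a power-law skew: resample_step_gamma > 1 concentrates
+        # t near 0 (low-noise steps), gamma < 1 concentrates it near 999, and gamma == 1
+        # recovers uniform sampling over [0, 999].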
+ _t = torch.rand(z.shape[0]).to(z.device)
+ t = (torch.pow(_t, self.config.resample_step_gamma) * 1000).long()
+ t = torch.where(t!=1000, t, 999) # if 1000, then replace it with 999
+
+ return z, t, context
+
+
+ def run_one_step(self, batch):
+ x_start, t, context = self.get_input(batch)
+ noise = torch.randn_like(x_start)
+ x_noisy = self.diffusion.q_sample(x_start=x_start, t=t, noise=noise)
+
+ input = dict(x = x_noisy,
+ timesteps = t,
+ context = context,
+ boxes = batch['boxes'],
+ masks = batch['masks'],
+ text_masks = batch['text_masks'],
+ image_masks = batch['image_masks'],
+ text_embeddings = batch["text_embeddings"],
+ image_embeddings = batch["image_embeddings"] )
+ model_output = self.model(input)
+
+ loss = torch.nn.functional.mse_loss(model_output, noise) * self.l_simple_weight
+
+ self.loss_dict = {"loss": loss.item()}
+
+ return loss
+
+
+
+ def start_training(self):
+
+ if not self.config.use_o2:
+            # use PyTorch native mixed-precision training, which is similar to O1 but faster
+ scaler = torch.cuda.amp.GradScaler()
+
+
+ iterator = tqdm(range(self.starting_iter, self.config.total_iters), desc='Training progress', disable=get_rank() != 0 )
+ self.model.train()
+ for iter_idx in iterator: # note: iter_idx is not from 0 if resume training
+ self.iter_idx = iter_idx
+
+ self.opt.zero_grad()
+ batch = next(self.loader_train)
+ batch_to_device(batch, self.device)
+
+ if self.config.use_o2:
+ loss = self.run_one_step(batch)
+ with amp.scale_loss(loss, self.opt) as scaled_loss:
+ scaled_loss.backward()
+ self.opt.step()
+ else:
+ enabled = True if self.config.use_mixed else False
+ with torch.cuda.amp.autocast(enabled=enabled): # with torch.autocast(enabled=True):
+ loss = self.run_one_step(batch)
+ scaler.scale(loss).backward()
+ scaler.step(self.opt)
+ scaler.update()
+
+
+ self.scheduler.step()
+
+ if self.config.enable_ema:
+ update_ema(self.ema_params, self.master_params, self.config.ema_rate)
+
+
+ if (get_rank() == 0):
+ if (iter_idx % 10 == 0):
+ self.log_loss()
+ if (iter_idx == 0) or ( iter_idx % self.config.save_every_iters == 0 ) or (iter_idx == self.config.total_iters-1):
+ self.save_ckpt_and_result()
+ synchronize()
+
+
+ synchronize()
+ print("Training finished. Start exiting")
+ exit()
+
+
+ def log_loss(self):
+ for k, v in self.loss_dict.items():
+            self.writer.add_scalar( k, v, self.iter_idx+1 ) # log against the 1-based iteration index
+
+
+ @torch.no_grad()
+ def save_ckpt_and_result(self):
+
+ model_wo_wrapper = self.model.module if self.config.distributed else self.model
+
+        iter_name = self.iter_idx + 1 # use the 1-based iteration index in file names
+
+ if not self.config.disable_inference_in_training:
+ # Do a quick inference on one training batch
+ batch_here = self.config.batch_size
+ batch = sub_batch( next(self.loader_train), batch_here)
+ batch_to_device(batch, self.device)
+
+
+            real_images_with_box_drawing = [] # we save this during training for better visualization
+ for i in range(batch_here):
+ temp_data = {"image": batch["image"][i], "boxes":batch["boxes"][i]}
+ im = self.dataset_train.datasets[0].vis_getitem_data(out=temp_data, return_tensor=True, print_caption=False)
+ real_images_with_box_drawing.append(im)
+ real_images_with_box_drawing = torch.stack(real_images_with_box_drawing)
+
+
+ uc = self.text_encoder.encode( batch_here*[""] )
+ context = self.text_encoder.encode( batch["caption"] )
+
+ ddim_sampler = PLMSSampler(self.diffusion, model_wo_wrapper)
+ shape = (batch_here, model_wo_wrapper.in_channels, model_wo_wrapper.image_size, model_wo_wrapper.image_size)
+ input = dict( x = None,
+ timesteps = None,
+ context = context,
+ boxes = batch['boxes'],
+ masks = batch['masks'],
+ text_masks = batch['text_masks'],
+ image_masks = batch['image_masks'],
+ text_embeddings = batch["text_embeddings"],
+ image_embeddings = batch["image_embeddings"] )
+ samples = ddim_sampler.sample(S=50, shape=shape, input=input, uc=uc, guidance_scale=5)
+
+ # old
+ # autoencoder_wo_wrapper = self.autoencoder # Note itself is without wrapper since we do not train that.
+ # autoencoder_wo_wrapper = autoencoder_wo_wrapper.cpu() # To save GPU
+ # samples = autoencoder_wo_wrapper.decode(samples.cpu())
+ # autoencoder_wo_wrapper = autoencoder_wo_wrapper.to(self.device)
+
+ # new
+ autoencoder_wo_wrapper = self.autoencoder # Note itself is without wrapper since we do not train that.
+ samples = autoencoder_wo_wrapper.decode(samples).cpu()
+
+ self.image_caption_saver(samples, real_images_with_box_drawing, batch["caption"], iter_name)
+
+ ckpt = dict(model = model_wo_wrapper.state_dict(),
+ opt = self.opt.state_dict(),
+ scheduler= self.scheduler.state_dict(),
+ iters = self.iter_idx+1 )
+ if self.config.enable_ema:
+ ckpt["ema"] = self.ema.state_dict()
+ torch.save( ckpt, os.path.join(self.name, "checkpoint_"+str(iter_name).zfill(8)+".pth") )
+ torch.save( ckpt, os.path.join(self.name, "checkpoint_latest.pth") )
+
+
+
+
+
+
+
+
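For reference, the checkpoint dict written by save_ckpt_and_result (keys `model`, `opt`, `scheduler`, `iters`, and optionally `ema`) could be restored roughly as follows. This is a hedged sketch with a hypothetical helper name; the repository's actual resume logic lives elsewhere in the trainer:

```python
import os
import torch

def resume_from_checkpoint_sketch(name, model, opt, scheduler, ema=None):
    # Load the most recent checkpoint written by save_ckpt_and_result.
    ckpt = torch.load(os.path.join(name, "checkpoint_latest.pth"), map_location="cpu")
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["opt"])
    scheduler.load_state_dict(ckpt["scheduler"])
    if ema is not None and "ema" in ckpt:
        ema.load_state_dict(ckpt["ema"])
    return ckpt["iters"]  # iteration count to resume from (iter_idx + 1 at save time)
```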
diff --git a/guide_imgs/0_A_train_on_top_of_a_surfboard..jpg b/guide_imgs/0_A_train_on_top_of_a_surfboard..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1318f6f48ee175b459f23437c9d87e5057a605a2
Binary files /dev/null and b/guide_imgs/0_A_train_on_top_of_a_surfboard..jpg differ
diff --git a/guide_imgs/0_a_bus_on_the_left_of_a_car.jpg b/guide_imgs/0_a_bus_on_the_left_of_a_car.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..24d35534cb7d736d63835bec13769c4a7fc78275
Binary files /dev/null and b/guide_imgs/0_a_bus_on_the_left_of_a_car.jpg differ
diff --git a/guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg b/guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ec9c39c682b3a66a37a6ab8dce8d0e034f278956
Binary files /dev/null and b/guide_imgs/0_a_cat_on_the_right_of_a_dog.jpg differ
diff --git a/guide_imgs/0_two_cups_filled_with_steaming_hot_coffee_sit_side-by-side_on_a_wooden_table..jpg b/guide_imgs/0_two_cups_filled_with_steaming_hot_coffee_sit_side-by-side_on_a_wooden_table..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..deb50c00c256df9620d455c33935b8461d4f8c99
Binary files /dev/null and b/guide_imgs/0_two_cups_filled_with_steaming_hot_coffee_sit_side-by-side_on_a_wooden_table..jpg differ
diff --git a/guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg b/guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..29770c541f2fc67ae62bed1d5733729fe16119e6
Binary files /dev/null and b/guide_imgs/10_A_banana_on_the_left_of_an_apple..jpg differ
diff --git a/guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg b/guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5247285078fd13bc156407c134e8a78600d2a219
Binary files /dev/null and b/guide_imgs/15_A_pizza_on_the_right_of_a_suitcase..jpg differ
diff --git a/guide_imgs/1_A_book_on_the_right_of_a_vase_on_the_table._There_is_a_cup_of_tea_on_the_top_of_the_book.jpg b/guide_imgs/1_A_book_on_the_right_of_a_vase_on_the_table._There_is_a_cup_of_tea_on_the_top_of_the_book.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..17aeb56bafcd0b7152eea7335f2343bc44fe919a
Binary files /dev/null and b/guide_imgs/1_A_book_on_the_right_of_a_vase_on_the_table._There_is_a_cup_of_tea_on_the_top_of_the_book.jpg differ
diff --git a/guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg b/guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..00e1ee1373f4f6e480835152cd460da27878abd3
Binary files /dev/null and b/guide_imgs/1_A_wine_glass_on_top_of_a_dog..jpg differ
diff --git a/guide_imgs/1_Two_cars_on_the_street..jpg b/guide_imgs/1_Two_cars_on_the_street..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cecd67700703cf466915e8d62a80a9b9f49a26fd
Binary files /dev/null and b/guide_imgs/1_Two_cars_on_the_street..jpg differ
diff --git a/guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg b/guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f5fec00f44ade8070ac5392ac078fdea7189afab
Binary files /dev/null and b/guide_imgs/2_A_bicycle_on_top_of_a_boat..jpg differ
diff --git a/guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg b/guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..44ece39f17bbace51284f6b6251550773749db89
Binary files /dev/null and b/guide_imgs/4_A_laptop_on_top_of_a_teddy_bear..jpg differ
diff --git a/guide_imgs/70_two_cats_are_curled_up_together_on_a_sunny_windowsill,_purring_contentedly..jpg b/guide_imgs/70_two_cats_are_curled_up_together_on_a_sunny_windowsill,_purring_contentedly..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a0f9ea1efb86e911b6ea8d4d747152ecb63a93cf
Binary files /dev/null and b/guide_imgs/70_two_cats_are_curled_up_together_on_a_sunny_windowsill,_purring_contentedly..jpg differ
diff --git a/guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg b/guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5e4d2b6e47f4e11a0242e7309cd02dfb136b0cbc
Binary files /dev/null and b/guide_imgs/80_two_apples_lay_side_by_side_on_a_wooden_table,_their_glossy_red_and_green_skins_glinting_in_the_sunlight..jpg differ
diff --git a/images/arg_corgis.jpeg b/images/arg_corgis.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..04ab99bf419862226b30d64c048781ff4ba07362
Binary files /dev/null and b/images/arg_corgis.jpeg differ
diff --git a/images/blank.png b/images/blank.png
new file mode 100644
index 0000000000000000000000000000000000000000..e30ec31e4e12b52e579dcf606826e8d21cb19a03
Binary files /dev/null and b/images/blank.png differ
diff --git a/images/cat_dog.jpg b/images/cat_dog.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a27416e5a6c7befe1e956667bb1eb4caac6e042d
Binary files /dev/null and b/images/cat_dog.jpg differ
diff --git a/images/flower_beach.jpg b/images/flower_beach.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..413e0324fbddea3664976ba143b69d44c232b982
Binary files /dev/null and b/images/flower_beach.jpg differ
diff --git a/images/img.png b/images/img.png
new file mode 100644
index 0000000000000000000000000000000000000000..c71d32c25b47b434f85912881bca51a0799709de
Binary files /dev/null and b/images/img.png differ
diff --git a/images/red_bird.jpg b/images/red_bird.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..83e40c08dab493698f15015dd51ed711695dd956
Binary files /dev/null and b/images/red_bird.jpg differ
diff --git a/images/style_cloudpurple.png b/images/style_cloudpurple.png
new file mode 100644
index 0000000000000000000000000000000000000000..8388ad8e38588e3811410222a53b1a45e3ac65f5
Binary files /dev/null and b/images/style_cloudpurple.png differ
diff --git a/images/style_gold.png b/images/style_gold.png
new file mode 100644
index 0000000000000000000000000000000000000000..9484c23df501c33d333babc27c5ef25ac49cf95b
Binary files /dev/null and b/images/style_gold.png differ
diff --git a/images/teddy.jpg b/images/teddy.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ee5469f2d28ddf9ea8da5fdbbf4b3f7d9c272c59
Binary files /dev/null and b/images/teddy.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2175ca35aa128ec133d8bf09dea9eb4598559220
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+torch==1.13.1
+torchvision==0.14.1
+xformers==0.0.16
+omegaconf==2.1.1
+albumentations==1.3.0
+opencv-python
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+pytorch-lightning==1.4.2
+test-tube>=0.7.5
+streamlit==1.17.0
+einops==0.3.0
+git+https://github.com/openai/CLIP.git
+protobuf~=3.20.1
+torchmetrics==0.6.0
+transformers==4.19.2
+kornia==0.6.0
+gradio==3.19.1
\ No newline at end of file