ortha / mixofshow /data /pil_transform.py
ujin-song's picture
upload mixofshow and orthogonal_mats folder
8e12b4e verified
import inspect
import random
from copy import deepcopy
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.transforms import CenterCrop, Normalize, RandomCrop, RandomHorizontalFlip, Resize
from torchvision.transforms.functional import InterpolationMode
from mixofshow.utils.registry import TRANSFORM_REGISTRY
def build_transform(opt):
"""Build performance evaluator from options.
Args:
opt (dict): Configuration.
"""
opt = deepcopy(opt)
transform_type = opt.pop('type')
transform = TRANSFORM_REGISTRY.get(transform_type)(**opt)
return transform
TRANSFORM_REGISTRY.register(Normalize)
TRANSFORM_REGISTRY.register(Resize)
TRANSFORM_REGISTRY.register(RandomHorizontalFlip)
TRANSFORM_REGISTRY.register(CenterCrop)
TRANSFORM_REGISTRY.register(RandomCrop)
@TRANSFORM_REGISTRY.register()
class BILINEARResize(Resize):
def __init__(self, size):
super(BILINEARResize,
self).__init__(size, interpolation=InterpolationMode.BILINEAR)
@TRANSFORM_REGISTRY.register()
class PairRandomCrop(nn.Module):
def __init__(self, size):
super().__init__()
if isinstance(size, int):
self.height, self.width = size, size
else:
self.height, self.width = size
def forward(self, img, **kwargs):
img_width, img_height = img.size
mask_width, mask_height = kwargs['mask'].size
assert img_height >= self.height and img_height == mask_height
assert img_width >= self.width and img_width == mask_width
x = random.randint(0, img_width - self.width)
y = random.randint(0, img_height - self.height)
img = F.crop(img, y, x, self.height, self.width)
kwargs['mask'] = F.crop(kwargs['mask'], y, x, self.height, self.width)
return img, kwargs
@TRANSFORM_REGISTRY.register()
class ToTensor(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, pic):
return F.to_tensor(pic)
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
@TRANSFORM_REGISTRY.register()
class PairRandomHorizontalFlip(torch.nn.Module):
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img, **kwargs):
if torch.rand(1) < self.p:
kwargs['mask'] = F.hflip(kwargs['mask'])
return F.hflip(img), kwargs
return img, kwargs
@TRANSFORM_REGISTRY.register()
class PairResize(nn.Module):
def __init__(self, size):
super().__init__()
self.resize = Resize(size=size)
def forward(self, img, **kwargs):
kwargs['mask'] = self.resize(kwargs['mask'])
img = self.resize(img)
return img, kwargs
class PairCompose(nn.Module):
def __init__(self, transforms):
super().__init__()
self.transforms = transforms
def __call__(self, img, **kwargs):
for t in self.transforms:
if len(inspect.signature(t.forward).parameters
) == 1: # count how many args, not count self
img = t(img)
else:
img, kwargs = t(img, **kwargs)
return img, kwargs
def __repr__(self) -> str:
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += f' {t}'
format_string += '\n)'
return format_string
@TRANSFORM_REGISTRY.register()
class HumanResizeCropFinalV3(nn.Module):
def __init__(self, size, crop_p=0.5):
super().__init__()
self.size = size
self.crop_p = crop_p
self.random_crop = RandomCrop(size=size)
self.paired_random_crop = PairRandomCrop(size=size)
def forward(self, img, **kwargs):
# step 1: short edge resize to 512
img = F.resize(img, size=self.size)
if 'mask' in kwargs:
kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)
# step 2: random crop
width, height = img.size
if random.random() < self.crop_p:
if height > width:
crop_pos = random.randint(0, height - width)
img = F.crop(img, 0, 0, width + crop_pos, width)
if 'mask' in kwargs:
kwargs['mask'] = F.crop(kwargs['mask'], 0, 0, width + crop_pos, width)
else:
if 'mask' in kwargs:
img, kwargs = self.paired_random_crop(img, **kwargs)
else:
img = self.random_crop(img)
else:
img = img
# step 3: long edge resize
img = F.resize(img, size=self.size - 1, max_size=self.size)
if 'mask' in kwargs:
kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)
new_width, new_height = img.size
img = np.array(img)
if 'mask' in kwargs:
kwargs['mask'] = np.array(kwargs['mask']) / 255
new_width = min(new_width, kwargs['mask'].shape[1])
new_height = min(new_height, kwargs['mask'].shape[0])
start_y = random.randint(0, 512 - new_height)
start_x = random.randint(0, 512 - new_width)
res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
res_mask = np.zeros((self.size, self.size))
res_img_mask = np.zeros((self.size, self.size))
res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img[:new_height, :new_width]
if 'mask' in kwargs:
res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask'][:new_height, :new_width]
kwargs['mask'] = res_mask
res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
kwargs['img_mask'] = res_img_mask
img = Image.fromarray(res_img)
if 'mask' in kwargs:
kwargs['mask'] = cv2.resize(kwargs['mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
kwargs['mask'] = torch.from_numpy(kwargs['mask'])
kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask'])
return img, kwargs
@TRANSFORM_REGISTRY.register()
class ResizeFillMaskNew(nn.Module):
def __init__(self, size, crop_p, scale_ratio):
super().__init__()
self.size = size
self.crop_p = crop_p
self.scale_ratio = scale_ratio
self.random_crop = RandomCrop(size=size)
self.paired_random_crop = PairRandomCrop(size=size)
def forward(self, img, **kwargs):
# width, height = img.size
# step 1: short edge resize to 512
img = F.resize(img, size=self.size)
if 'mask' in kwargs:
kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)
# step 2: random crop
if random.random() < self.crop_p:
if 'mask' in kwargs:
img, kwargs = self.paired_random_crop(img, **kwargs) # 51
else:
img = self.random_crop(img) # 512
else:
# long edge resize
img = F.resize(img, size=self.size - 1, max_size=self.size)
if 'mask' in kwargs:
kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)
# step 3: random aspect ratio
width, height = img.size
ratio = random.uniform(*self.scale_ratio)
img = F.resize(img, size=(int(height * ratio), int(width * ratio)))
if 'mask' in kwargs:
kwargs['mask'] = F.resize(kwargs['mask'], size=(int(height * ratio), int(width * ratio)), interpolation=0)
# step 4: random place
new_width, new_height = img.size
img = np.array(img)
if 'mask' in kwargs:
kwargs['mask'] = np.array(kwargs['mask']) / 255
start_y = random.randint(0, 512 - new_height)
start_x = random.randint(0, 512 - new_width)
res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
res_mask = np.zeros((self.size, self.size))
res_img_mask = np.zeros((self.size, self.size))
res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img
if 'mask' in kwargs:
res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask']
kwargs['mask'] = res_mask
res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
kwargs['img_mask'] = res_img_mask
img = Image.fromarray(res_img)
if 'mask' in kwargs:
kwargs['mask'] = cv2.resize(kwargs['mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
kwargs['mask'] = torch.from_numpy(kwargs['mask'])
kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask'])
return img, kwargs
@TRANSFORM_REGISTRY.register()
class ShuffleCaption(nn.Module):
def __init__(self, keep_token_num):
super().__init__()
self.keep_token_num = keep_token_num
def forward(self, img, **kwargs):
prompts = kwargs['prompts'].strip()
fixed_tokens = []
flex_tokens = [t.strip() for t in prompts.strip().split(',')]
if self.keep_token_num > 0:
fixed_tokens = flex_tokens[:self.keep_token_num]
flex_tokens = flex_tokens[self.keep_token_num:]
random.shuffle(flex_tokens)
prompts = ', '.join(fixed_tokens + flex_tokens)
kwargs['prompts'] = prompts
return img, kwargs
@TRANSFORM_REGISTRY.register()
class EnhanceText(nn.Module):
def __init__(self, enhance_type='object'):
super().__init__()
STYLE_TEMPLATE = [
'a painting in the style of {}',
'a rendering in the style of {}',
'a cropped painting in the style of {}',
'the painting in the style of {}',
'a clean painting in the style of {}',
'a dirty painting in the style of {}',
'a dark painting in the style of {}',
'a picture in the style of {}',
'a cool painting in the style of {}',
'a close-up painting in the style of {}',
'a bright painting in the style of {}',
'a cropped painting in the style of {}',
'a good painting in the style of {}',
'a close-up painting in the style of {}',
'a rendition in the style of {}',
'a nice painting in the style of {}',
'a small painting in the style of {}',
'a weird painting in the style of {}',
'a large painting in the style of {}',
]
OBJECT_TEMPLATE = [
'a photo of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a photo of a clean {}',
'a photo of a dirty {}',
'a dark photo of the {}',
'a photo of my {}',
'a photo of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a photo of the {}',
'a good photo of the {}',
'a photo of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a photo of the clean {}',
'a rendition of a {}',
'a photo of a nice {}',
'a good photo of a {}',
'a photo of the nice {}',
'a photo of the small {}',
'a photo of the weird {}',
'a photo of the large {}',
'a photo of a cool {}',
'a photo of a small {}',
]
HUMAN_TEMPLATE = [
'a photo of a {}', 'a photo of one {}', 'a photo of the {}',
'the photo of a {}', 'a rendering of a {}',
'a rendition of the {}', 'a rendition of a {}',
'a cropped photo of the {}', 'a cropped photo of a {}',
'a bad photo of the {}', 'a bad photo of a {}',
'a photo of a weird {}', 'a weird photo of a {}',
'a bright photo of the {}', 'a good photo of the {}',
'a photo of a nice {}', 'a good photo of a {}',
'a photo of a cool {}', 'a bright photo of the {}'
]
if enhance_type == 'object':
self.templates = OBJECT_TEMPLATE
elif enhance_type == 'style':
self.templates = STYLE_TEMPLATE
elif enhance_type == 'human':
self.templates = HUMAN_TEMPLATE
else:
raise NotImplementedError
def forward(self, img, **kwargs):
concept_token = kwargs['prompts'].strip()
kwargs['prompts'] = random.choice(self.templates).format(concept_token)
return img, kwargs