Spaces:
Sleeping
Sleeping
# Adapted from DiT
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DiT: https://github.com/facebookresearch/DiT
# --------------------------------------------------------
import numpy as np | |
import torchvision.transforms as transforms | |
from PIL import Image | |
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Resize ``pil_image`` so its shorter side equals ``image_size``, then return
    the central ``image_size`` x ``image_size`` square as a new PIL image.

    Args:
        pil_image: Source PIL image.
        image_size: Side length, in pixels, of the square output. Must be > 0.

    Returns:
        A new PIL image of size ``(image_size, image_size)``.

    Raises:
        ValueError: If ``image_size`` is not positive.
    """
    if image_size <= 0:
        # Without this guard the halving loop below can never terminate:
        # min(size) >= 2 * image_size holds for every size when image_size <= 0.
        raise ValueError(f"image_size must be positive, got {image_size!r}")

    # While the image is at least twice the target, halve it with a box filter.
    # This follows ADM, which does it by hand to improve downsampling quality
    # before the final bicubic resize.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Rescale so the shorter side becomes image_size (aspect ratio preserved).
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    # Crop the centered square. numpy arrays from PIL are (height, width[, C]).
    # When the excess along an axis is odd, one more pixel is trimmed from the
    # bottom/right than from the top/left.
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
def get_transforms_image(image_size=256):
    """
    Build the image preprocessing pipeline.

    Steps: center-crop to an ``image_size`` x ``image_size`` square via
    ``center_crop_arr``, random horizontal flip, conversion to a tensor, and
    per-channel normalization with mean 0.5 / std 0.5. This maps
    ``ToTensor``'s [0, 1] range to [-1, 1].

    NOTE(review): the 3-element mean/std assumes 3-channel (e.g. RGB) input.
    Confirm that callers convert images before applying this transform.

    Args:
        image_size: Side length of the square crop. Defaults to 256.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    from functools import partial

    transform = transforms.Compose(
        [
            # Use functools.partial of the module-level function instead of a
            # lambda. A partial is picklable, so datasets holding this transform
            # can be sent to DataLoader workers started with "spawn" (the
            # default on Windows and macOS).
            transforms.Lambda(partial(center_crop_arr, image_size=image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ToTensor returns a fresh tensor, so normalizing in place is safe.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ]
    )
    return transform