Spaces:
Sleeping
Sleeping
import math | |
import torch | |
from torch.nn import functional as F | |
def translate_mat(t_x, t_y):
    """Build a batch of 3x3 homogeneous translation matrices.

    Args:
        t_x: 1-D tensor of x offsets, one per matrix.
        t_y: 1-D tensor of y offsets, same length as ``t_x``.

    Returns:
        Tensor of shape ``(len(t_x), 3, 3)``: the identity with
        ``(t_x[i], t_y[i])`` written into the top two rows of the last column.
    """
    n = t_x.shape[0]
    offsets = torch.stack([t_x, t_y], dim=1)  # (n, 2)
    out = torch.eye(3).repeat(n, 1, 1)
    out[:, :2, 2] = offsets
    return out
def rotate_mat(theta):
    """Build a batch of 3x3 homogeneous rotation matrices.

    Args:
        theta: 1-D tensor of angles in radians, one per matrix.

    Returns:
        Tensor of shape ``(len(theta), 3, 3)`` whose top-left 2x2 block is
        ``[[cos, -sin], [sin, cos]]`` and whose remaining entries are identity.
    """
    n = theta.shape[0]
    c, s = torch.cos(theta), torch.sin(theta)
    block = torch.stack([c, -s, s, c], dim=1).view(n, 2, 2)
    out = torch.eye(3).repeat(n, 1, 1)
    out[:, :2, :2] = block
    return out
def scale_mat(s_x, s_y):
    """Build a batch of 3x3 homogeneous axis-aligned scaling matrices.

    Args:
        s_x: 1-D tensor of x scale factors; its length sets the batch size.
        s_y: y scale factors, assigned into the ``[1, 1]`` entry of each matrix.

    Returns:
        Tensor of shape ``(len(s_x), 3, 3)``, i.e. ``diag(s_x[i], s_y[i], 1)``.
    """
    n = s_x.shape[0]
    out = torch.eye(3).repeat(n, 1, 1)
    out[:, 0, 0] = s_x
    out[:, 1, 1] = s_y
    return out
def lognormal_sample(size, mean=0, std=1):
    """Return a tensor of shape ``size`` whose entries are exp(N(mean, std**2))."""
    draws = torch.empty(size)
    draws.log_normal_(mean=mean, std=std)
    return draws
def category_sample(size, categories):
    """Draw ``size`` values uniformly, with replacement, from ``categories``.

    Returns a 1-D tensor whose dtype is inferred from ``categories``.
    """
    choices = torch.tensor(categories)
    picks = torch.randint(0, len(categories), (size,))
    return choices[picks]
def uniform_sample(size, low, high):
    """Return a tensor of shape ``size`` filled uniformly between ``low`` and ``high``."""
    draws = torch.empty(size)
    draws.uniform_(low, high)
    return draws
def normal_sample(size, mean=0, std=1):
    """Return a tensor of shape ``size`` with i.i.d. N(mean, std**2) entries."""
    draws = torch.empty(size)
    draws.normal_(mean=mean, std=std)
    return draws
def bernoulli_sample(size, p):
    """Return a float tensor of shape ``size`` with i.i.d. Bernoulli(p) 0/1 entries."""
    coins = torch.empty(size)
    coins.bernoulli_(p)
    return coins
def random_affine_apply(p, transform, prev, eye):
    """Randomly left-compose ``transform`` onto ``prev``, per batch element.

    For each index ``i`` in ``range(transform.shape[0])``, a Bernoulli(p) draw
    chooses ``transform[i]`` (on 1) or ``eye[i]`` (on 0). The chosen matrix is
    then left-multiplied onto ``prev[i]``.

    Args:
        p: probability of selecting ``transform`` for each element.
        transform: batch of candidate matrices, shape ``(n, k, k)``.
        prev: batch of matrices accumulated so far.
        eye: matrices used when an element is not selected (the caller in this
            file passes identities).

    Returns:
        ``chosen @ prev`` where ``chosen`` blends ``transform`` and ``eye``.
    """
    n = transform.shape[0]
    # Shape the 0/1 mask as (n, 1, 1) so it broadcasts over each matrix.
    mask = bernoulli_sample(n, p).view(n, 1, 1)
    chosen = mask * transform + (1 - mask) * eye
    return chosen @ prev
def sample_affine(p, size, height, width):
    """Sample a batch of random 3x3 homogeneous affine matrices.

    The stages are modeled on the geometric augmentations of StyleGAN2-ADA:
    x-flip, 90-degree rotation, integer translation, isotropic scale,
    pre-rotation, anisotropic scale, post-rotation and fractional translation.
    Each stage is enabled independently per sample. Most stages use
    probability ``p``. The two arbitrary-angle rotations each use
    ``p_rot = 1 - sqrt(1 - p)``, so at least one of them fires with
    probability ``p``. Every enabled stage is left-multiplied onto the running
    matrix, so stages are applied in the order listed.

    The matrices act on the normalized coordinates used by ``F.affine_grid``,
    where each image axis spans [-1, 1]. One pixel along an axis of ``n``
    pixels is therefore a distance of ``2 / n``.

    Args:
        p: probability of enabling each augmentation stage.
        size: number of matrices (batch size) to sample.
        height: image height in pixels (used to snap the integer translation).
        width: image width in pixels (used to snap the integer translation).

    Returns:
        Tensor of shape ``(size, 3, 3)``.
    """
    G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)
    eye = G  # identity for disabled stages; G is rebound below, never mutated

    # x-flip: mirror horizontally when the sampled bit is 1.
    param = category_sample(size, (0, 1))
    Gc = scale_mat(1 - 2.0 * param, torch.ones(size))
    G = random_affine_apply(p, Gc, G, eye)

    # Rotation by k * 90 degrees with k uniform over all four multiples,
    # so that 90 and 180 degrees can occur as well as 0 and 270.
    param = category_sample(size, (0, 1, 2, 3))
    Gc = rotate_mat(-math.pi / 2 * param)
    G = random_affine_apply(p, Gc, G, eye)

    # Integer translation: independent x/y offsets drawn in normalized units.
    # Each offset is snapped to a multiple of 2 / n, i.e. whole pixels. Snapping
    # to 1 / n would give half-pixel shifts that blur under bilinear sampling.
    # NOTE(review): StyleGAN2-ADA states translation ranges relative to the
    # image size. The +/-0.125 here is in normalized units, i.e. +/-1/16 of the
    # image. Confirm the intended strength.
    param = uniform_sample((2, size), -0.125, 0.125)
    param_height = torch.round(param[0] * height / 2) * 2 / height
    param_width = torch.round(param[1] * width / 2) * 2 / width
    Gc = translate_mat(param_width, param_height)
    G = random_affine_apply(p, Gc, G, eye)

    # Isotropic scale s with log2(s) ~ N(0, 0.2^2).
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, param)
    G = random_affine_apply(p, Gc, G, eye)

    # The arbitrary rotation is split over two stages: 1 - (1 - p_rot)^2 == p.
    p_rot = 1 - math.sqrt(1 - p)

    # Pre-rotation by an angle uniform between -pi and pi.
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param)
    G = random_affine_apply(p_rot, Gc, G, eye)

    # Anisotropic scale: x by s and y by 1/s, with log2(s) ~ N(0, 0.2^2).
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, 1 / param)
    G = random_affine_apply(p, Gc, G, eye)

    # Post-rotation by an angle uniform between -pi and pi.
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param)
    G = random_affine_apply(p_rot, Gc, G, eye)

    # Fractional translation: independent x/y offsets ~ N(0, 0.125^2) in
    # normalized units. The axes must not share a draw, or shifts would be
    # confined to the diagonal.
    param = normal_sample((2, size), std=0.125)
    Gc = translate_mat(param[1], param[0])
    G = random_affine_apply(p, Gc, G, eye)

    return G
def apply_affine(img, G):
    """Warp a batch of images by the affine matrices ``G``.

    ``G`` is treated as the forward map in ``F.affine_grid``'s normalized
    coordinates. ``affine_grid`` expects the map from output positions back
    to input sampling positions, so the inverse of ``G`` is passed. Pixels
    are resampled bilinearly, and out-of-bounds samples are filled by
    reflection.

    Args:
        img: 4-D image batch ``(N, C, H, W)`` as required by ``affine_grid``.
        G: batch of 3x3 homogeneous matrices, shape ``(N, 3, 3)``.

    Returns:
        The warped images, same shape as ``img``.
    """
    # Match img's dtype/device and keep the 2x3 part that affine_grid consumes.
    theta = torch.inverse(G).to(img)[:, :2, :]
    sampling_grid = F.affine_grid(theta, img.shape, align_corners=False)
    return F.grid_sample(
        img,
        sampling_grid,
        mode="bilinear",
        padding_mode="reflection",
        align_corners=False,
    )