# PaintTransformer/train/models/painter_model.py
# akhaliq3
# spaces demo
# 035e10c
import torch
import numpy as np
from .base_model import BaseModel
from . import networks
from util import morphology
from scipy.optimize import linear_sum_assignment
from PIL import Image
class PainterModel(BaseModel):
    """Stroke-prediction training model that learns from synthetic strokes.

    Each training step works on generated data only:

    * ``set_input`` samples random stroke parameters, rasterises them into an
      "old" canvas (``self.old``) and a target canvas (``self.render``).
    * ``forward`` asks ``net_g`` for strokes and composites them onto
      ``self.old`` to obtain ``self.rec``.
    * ``optimize_parameters`` matches predicted strokes to ground-truth
      strokes (Hungarian matching) and performs one optimizer step.

    A stroke is a vector of ``self.d == 12`` numbers
    ``(xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A)``; the first
    ``self.d_shape == 5`` entries describe its shape.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add painter-specific options to ``parser`` and return it.

        ``is_train`` is accepted for interface compatibility but is not used
        here. The default ``dataset_mode`` is switched to ``'null'``.
        """
        parser.set_defaults(dataset_mode='null')
        parser.add_argument('--used_strokes', type=int, default=8,
                            help='actually generated strokes number')
        parser.add_argument('--num_blocks', type=int, default=3,
                            help='number of transformer blocks for stroke generator')
        parser.add_argument('--lambda_w', type=float, default=10.0, help='weight for w loss of stroke shape')
        parser.add_argument('--lambda_pixel', type=float, default=10.0, help='weight for pixel-level L1 loss')
        parser.add_argument('--lambda_gt', type=float, default=1.0, help='weight for ground-truth loss')
        parser.add_argument('--lambda_decision', type=float, default=10.0, help='weight for stroke decision loss')
        parser.add_argument('--lambda_recall', type=float, default=10.0,
                            help='weight of recall for stroke decision loss')
        return parser

    def __init__(self, opt):
        """Build the stroke generator, load the brush templates, set up losses.

        Reads ``opt.used_strokes``, ``opt.ngf``, ``opt.num_blocks``,
        ``opt.init_type``, ``opt.init_gain``, ``opt.lambda_recall`` and, when
        training, ``opt.lr`` and ``opt.beta1``.

        The two brush images are loaded from the relative paths
        ``brush/brush_large_vertical.png`` and
        ``brush/brush_large_horizontal.png``, i.e. they are resolved against
        the current working directory.

        NOTE(review): ``self.device``, ``self.gpu_ids``, ``self.isTrain`` and
        ``self.optimizers`` are presumably provided by ``BaseModel.__init__``
        -- confirm against ``base_model.py``.
        """
        BaseModel.__init__(self, opt)
        self.loss_names = ['pixel', 'gt', 'w', 'decision']
        self.visual_names = ['old', 'render', 'rec']
        self.model_names = ['g']
        self.d = 12  # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A
        self.d_shape = 5

        def read_img(img_path, img_type='RGB'):
            """Load an image as a float tensor of shape (1, C, H, W) in [0, 1]."""
            # The context manager guarantees the file handle is released even
            # if the conversion fails.
            with Image.open(img_path) as img_file:
                img = np.array(img_file.convert(img_type))
            if img.ndim == 2:
                # Single-channel images come back as (H, W); add a channel axis.
                img = np.expand_dims(img, axis=-1)
            img = img.transpose((2, 0, 1))
            img = torch.from_numpy(img).unsqueeze(0).float() / 255.
            return img

        brush_large_vertical = read_img('brush/brush_large_vertical.png', 'L').to(self.device)
        brush_large_horizontal = read_img('brush/brush_large_horizontal.png', 'L').to(self.device)
        # Index 0 is the vertical brush, index 1 the horizontal one.
        self.meta_brushes = torch.cat(
            [brush_large_vertical, brush_large_horizontal], dim=0)

        net_g = networks.Painter(self.d_shape, opt.used_strokes, opt.ngf,
                                 n_enc_layers=opt.num_blocks, n_dec_layers=opt.num_blocks)
        self.net_g = networks.init_net(net_g, opt.init_type, opt.init_gain, self.gpu_ids)

        # Tensors filled in by set_input() / forward().
        self.old = None
        self.render = None
        self.rec = None
        self.gt_param = None
        self.pred_param = None
        self.gt_decision = None
        self.pred_decision = None
        self.patch_size = 32

        self.loss_pixel = torch.tensor(0., device=self.device)
        self.loss_gt = torch.tensor(0., device=self.device)
        self.loss_w = torch.tensor(0., device=self.device)
        self.loss_decision = torch.tensor(0., device=self.device)
        self.criterion_pixel = torch.nn.L1Loss().to(self.device)
        self.criterion_decision = torch.nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(opt.lambda_recall)).to(self.device)
        if self.isTrain:
            self.optimizer = torch.optim.Adam(self.net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer)

    def param2stroke(self, param, H, W):
        """Rasterise stroke parameters into colour and alpha images.

        Args:
            param: tensor of shape (b, 12), one stroke per row. ``theta`` is
                expressed in units of pi radians. The last entry (A) is
                ignored.
            H, W: output height and width in pixels.

        Returns:
            ``(brush, alphas)``, both of shape (b, 3, H, W): the coloured
            stroke and its alpha mask, warped to the stroke's position, size
            and rotation.
        """
        # param: b, 12
        b = param.shape[0]
        param_list = torch.split(param, 1, dim=1)
        x0, y0, w, h, theta = [item.squeeze(-1) for item in param_list[:5]]
        R0, G0, B0, R2, G2, B2, _ = param_list[5:]
        sin_theta = torch.sin(theta * np.pi)
        cos_theta = torch.cos(theta * np.pi)

        # Taller-than-wide strokes use the vertical brush (0), the rest the
        # horizontal one (1).
        index = (h <= w).long()
        brush = self.meta_brushes[index]
        alphas = torch.cat([brush, brush, brush], dim=1)
        alphas = (alphas > 0).float()

        # Linear colour ramp from (R0, G0, B0) to (R2, G2, B2) along the
        # brush's row axis.
        t = torch.arange(0, brush.shape[2], device=param.device).unsqueeze(0) / brush.shape[2]
        color_map = torch.stack([R0 * (1 - t) + R2 * t, G0 * (1 - t) + G2 * t, B0 * (1 - t) + B2 * t], dim=1)
        color_map = color_map.unsqueeze(-1).repeat(1, 1, 1, brush.shape[3])
        brush = brush * color_map

        # 2x3 affine matrix mapping output coordinates back into brush space.
        warp_00 = cos_theta / w
        warp_01 = sin_theta * H / (W * w)
        warp_02 = (1 - 2 * x0) * cos_theta / w + (1 - 2 * y0) * sin_theta * H / (W * w)
        warp_10 = -sin_theta * W / (H * h)
        warp_11 = cos_theta / h
        warp_12 = (1 - 2 * y0) * cos_theta / h - (1 - 2 * x0) * sin_theta * W / (H * h)
        warp_0 = torch.stack([warp_00, warp_01, warp_02], dim=1)
        warp_1 = torch.stack([warp_10, warp_11, warp_12], dim=1)
        warp = torch.stack([warp_0, warp_1], dim=1)
        grid = torch.nn.functional.affine_grid(warp, torch.Size((b, 3, H, W)), align_corners=False)
        brush = torch.nn.functional.grid_sample(brush, grid, align_corners=False)
        alphas = torch.nn.functional.grid_sample(alphas, grid, align_corners=False)
        return brush, alphas

    def _sample_stroke_params(self, n):
        """Draw random stroke parameters of shape (n, used_strokes, d)."""
        param = torch.rand(n, self.opt.used_strokes, self.d, device=self.device)
        # Position and size (xc, yc, w, h) are squeezed into [0.2, 0.7).
        param[:, :, :4] = param[:, :, :4] * 0.5 + 0.2
        # End colour (R2, G2, B2) copies start colour (R0, G0, B0).
        param[:, :, -4:-1] = param[:, :, -7:-4]
        return param

    def _render_strokes(self, param, size):
        """Rasterise strokes on a ``size`` x ``size`` canvas.

        Returns ``(foregrounds, alphas)``, each of shape
        (batch, used_strokes, 3, size, size); foregrounds are dilated and
        alphas eroded with ``m=1``.
        """
        flat_param = param.view(-1, self.d).contiguous()
        foregrounds, alphas = self.param2stroke(flat_param, size, size)
        foregrounds = morphology.Dilation2d(m=1)(foregrounds)
        alphas = morphology.Erosion2d(m=1)(alphas)
        # foreground, alpha: b * stroke_per_patch, 3, output_size, output_size
        shape = (-1, self.opt.used_strokes, 3, size, size)
        # foreground, alpha: b, stroke_per_patch, 3, output_size, output_size
        return foregrounds.view(*shape).contiguous(), alphas.view(*shape).contiguous()

    def set_input(self, input_dict):
        """Generate a fresh synthetic training batch.

        Only ``input_dict['A_paths']`` is read from the input; all image data
        is synthesised. Sets ``self.old``, ``self.render``, ``self.gt_param``
        (shape part only) and ``self.gt_decision``.

        Raises:
            ValueError: if ``opt.batch_size`` is not a multiple of 4. Each
                background canvas is rendered at twice the patch size and cut
                into 2 x 2 patches, so ``batch_size // 4`` canvases must
                yield exactly ``batch_size`` patches.
        """
        self.image_paths = input_dict['A_paths']
        batch_size = self.opt.batch_size
        if batch_size % 4 != 0:
            raise ValueError(
                'batch_size must be a multiple of 4, got {}'.format(batch_size))
        n_canvas = batch_size // 4
        n_strokes = self.opt.used_strokes
        patch = self.patch_size
        with torch.no_grad():
            # --- background canvas ("old"), rendered at 2x patch size ---
            old_param = self._sample_stroke_params(n_canvas)
            foregrounds, alphas = self._render_strokes(old_param, patch * 2)
            old = torch.zeros(n_canvas, 3, patch * 2, patch * 2, device=self.device)
            for i in range(n_strokes):
                foreground = foregrounds[:, i, :, :, :]
                alpha = alphas[:, i, :, :, :]
                old = foreground * alpha + old * (1 - alpha)
            # Cut every 2P x 2P canvas into four P x P patches.
            old = old.view(n_canvas, 3, 2, patch, 2, patch).contiguous()
            old = old.permute(0, 2, 4, 1, 3, 5).contiguous()
            self.old = old.view(batch_size, 3, patch, patch).contiguous()

            # --- ground-truth strokes painted on top of the patches ---
            gt_param = self._sample_stroke_params(batch_size)
            self.gt_param = gt_param[:, :, :self.d_shape]
            foregrounds, alphas = self._render_strokes(gt_param, patch)
            self.render = self.old.clone()
            gt_decision = torch.ones(batch_size, n_strokes, device=self.device)
            for i in range(n_strokes):
                foreground = foregrounds[:, i, :, :, :]
                alpha = alphas[:, i, :, :, :]
                for j in range(i):
                    # Overlap of stroke i with earlier stroke j, normalised by
                    # stroke j's own alpha mass (not a symmetric IoU).
                    overlap = (torch.sum(alpha * alphas[:, j, :, :, :], dim=(-3, -2, -1)) + 1e-5) / (
                        torch.sum(alphas[:, j, :, :, :], dim=(-3, -2, -1)) + 1e-5)
                    # Stroke i stays valid only if, for every earlier stroke,
                    # it covers less than 75% of it or that stroke was itself
                    # dropped.
                    gt_decision[:, i] = ((overlap < 0.75) | (~gt_decision[:, j].bool())).float() \
                        * gt_decision[:, i]
                decision = gt_decision[:, i].view(batch_size, 1, 1, 1).contiguous()
                self.render = foreground * alpha * decision + self.render * (1 - alpha * decision)
            self.gt_decision = gt_decision

    def forward(self):
        """Predict strokes for (render, old) and composite them into ``self.rec``.

        Also stores ``self.pred_decision`` (b, used_strokes) and
        ``self.pred_param`` (shape part of the predicted parameters).
        """
        param, decisions = self.net_g(self.render, self.old)
        # stroke_param: b, stroke_per_patch, param_per_stroke
        # decision: b, stroke_per_patch, 1
        n_strokes = self.opt.used_strokes
        self.pred_decision = decisions.view(-1, n_strokes).contiguous()
        self.pred_param = param[:, :, :self.d_shape]
        foregrounds, alphas = self._render_strokes(param, self.patch_size)
        decisions = networks.SignWithSigmoidGrad.apply(
            decisions.view(-1, n_strokes, 1, 1, 1).contiguous())
        self.rec = self.old.clone()
        for j in range(foregrounds.shape[1]):
            foreground = foregrounds[:, j, :, :, :]
            alpha = alphas[:, j, :, :, :]
            decision = decisions[:, j, :, :, :]
            self.rec = foreground * alpha * decision + self.rec * (1 - alpha * decision)

    @staticmethod
    def get_sigma_sqrt(w, h, theta):
        """Return the 2x2 matrix R diag(w/2, h/2) R^T for rotation ``theta``.

        ``theta`` is in radians. Inputs broadcast element-wise; the result has
        two trailing dimensions of size 2.
        """
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        sigma_00 = w * (cos_t ** 2) / 2 + h * (sin_t ** 2) / 2
        sigma_01 = (w - h) * cos_t * sin_t / 2
        sigma_11 = h * (cos_t ** 2) / 2 + w * (sin_t ** 2) / 2
        sigma_0 = torch.stack([sigma_00, sigma_01], dim=-1)
        sigma_1 = torch.stack([sigma_01, sigma_11], dim=-1)
        sigma = torch.stack([sigma_0, sigma_1], dim=-2)
        return sigma

    @staticmethod
    def get_sigma(w, h, theta):
        """Return the 2x2 matrix R diag(w^2/4, h^2/4) R^T for rotation ``theta``.

        ``theta`` is in radians. This is the square of ``get_sigma_sqrt``.
        """
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        sigma_00 = w * w * (cos_t ** 2) / 4 + h * h * (sin_t ** 2) / 4
        sigma_01 = (w * w - h * h) * cos_t * sin_t / 4
        sigma_11 = h * h * (cos_t ** 2) / 4 + w * w * (sin_t ** 2) / 4
        sigma_0 = torch.stack([sigma_00, sigma_01], dim=-1)
        sigma_1 = torch.stack([sigma_01, sigma_11], dim=-1)
        sigma = torch.stack([sigma_0, sigma_1], dim=-2)
        return sigma

    def gaussian_w_distance(self, param_1, param_2):
        """Squared 2-Wasserstein distance between strokes seen as Gaussians.

        Each stroke ``(xc, yc, w, h, theta)`` (last dimension of size 5,
        ``theta`` in units of pi) is treated as a 2-D Gaussian with mean
        ``(xc, yc)`` and covariance ``get_sigma(w, h, theta)``. Leading
        dimensions of the two inputs must match; the result drops the last
        dimension.
        """
        mu_1, w_1, h_1, theta_1 = torch.split(param_1, (2, 1, 1, 1), dim=-1)
        w_1 = w_1.squeeze(-1)
        h_1 = h_1.squeeze(-1)
        theta_1 = theta_1.squeeze(-1) * np.pi
        trace_1 = (w_1 ** 2 + h_1 ** 2) / 4
        mu_2, w_2, h_2, theta_2 = torch.split(param_2, (2, 1, 1, 1), dim=-1)
        w_2 = w_2.squeeze(-1)
        h_2 = h_2.squeeze(-1)
        theta_2 = theta_2.squeeze(-1) * np.pi
        trace_2 = (w_2 ** 2 + h_2 ** 2) / 4
        sigma_1_sqrt = self.get_sigma_sqrt(w_1, h_1, theta_1)
        sigma_2 = self.get_sigma(w_2, h_2, theta_2)
        # M = S1^(1/2) S2 S1^(1/2); for a 2x2 PSD matrix,
        # trace(sqrt(M)) = sqrt(trace(M) + 2 * sqrt(det(M))).
        trace_12 = torch.matmul(torch.matmul(sigma_1_sqrt, sigma_2), sigma_1_sqrt)
        trace_12 = torch.sqrt(trace_12[..., 0, 0] + trace_12[..., 1, 1] + 2 * torch.sqrt(
            trace_12[..., 0, 0] * trace_12[..., 1, 1] - trace_12[..., 0, 1] * trace_12[..., 1, 0]))
        return torch.sum((mu_1 - mu_2) ** 2, dim=-1) + trace_1 + trace_2 - 2 * trace_12

    def optimize_parameters(self):
        """Run one training step: forward, match strokes, backprop, update.

        Predicted strokes are paired with the valid ground-truth strokes of
        the same sample by minimising an L1 + Wasserstein + decision cost
        with ``linear_sum_assignment``. Matched pairs feed ``loss_gt`` and
        ``loss_w``; whether a prediction was matched is the target for
        ``loss_decision``.
        """
        self.forward()
        self.loss_pixel = self.criterion_pixel(self.rec, self.render) * self.opt.lambda_pixel
        n_pred = self.pred_param.shape[1]
        cur_valid_gt_size = 0
        # The assignment itself is not differentiated; only the indices are
        # used afterwards.
        with torch.no_grad():
            r_idx = []
            c_idx = []
            for i in range(self.gt_param.shape[0]):
                is_valid_gt = self.gt_decision[i].bool()
                valid_gt_param = self.gt_param[i, is_valid_gt]
                n_valid = valid_gt_param.shape[0]
                cost_matrix_l1 = torch.cdist(self.pred_param[i], valid_gt_param, p=1)
                pred_param_broad = self.pred_param[i].unsqueeze(1).contiguous().repeat(
                    1, n_valid, 1)
                valid_gt_param_broad = valid_gt_param.unsqueeze(0).contiguous().repeat(
                    n_pred, 1, 1)
                cost_matrix_w = self.gaussian_w_distance(pred_param_broad, valid_gt_param_broad)
                # NOTE(review): pred_decision is later fed to
                # BCEWithLogitsLoss, so it is used here un-squashed.
                decision = self.pred_decision[i]
                cost_matrix_decision = (1 - decision).unsqueeze(-1).repeat(1, n_valid)
                r, c = linear_sum_assignment((cost_matrix_l1 + cost_matrix_w + cost_matrix_decision).cpu())
                # Offsets turn per-sample indices into indices of the
                # flattened prediction / valid-ground-truth tensors below.
                r_idx.append(torch.tensor(r + n_pred * i, device=self.device))
                c_idx.append(torch.tensor(c + cur_valid_gt_size, device=self.device))
                cur_valid_gt_size += n_valid
            r_idx = torch.cat(r_idx, dim=0)
            c_idx = torch.cat(c_idx, dim=0)
            paired_gt_decision = torch.zeros(self.gt_decision.shape[0] * self.gt_decision.shape[1],
                                             device=self.device)
            paired_gt_decision[r_idx] = 1.
        all_valid_gt_param = self.gt_param[self.gt_decision.bool(), :]
        all_pred_param = self.pred_param.view(-1, self.pred_param.shape[2]).contiguous()
        all_pred_decision = self.pred_decision.view(-1).contiguous()
        paired_gt_param = all_valid_gt_param[c_idx, :]
        paired_pred_param = all_pred_param[r_idx, :]
        self.loss_gt = self.criterion_pixel(paired_pred_param, paired_gt_param) * self.opt.lambda_gt
        self.loss_w = self.gaussian_w_distance(paired_pred_param, paired_gt_param).mean() * self.opt.lambda_w
        self.loss_decision = self.criterion_decision(all_pred_decision, paired_gt_decision) \
            * self.opt.lambda_decision
        loss = self.loss_pixel + self.loss_gt + self.loss_w + self.loss_decision
        loss.backward()
        self.optimizer.step()
        # Gradients are cleared after the step so the next call starts clean.
        self.optimizer.zero_grad()