import importlib
from typing import Union

import PIL
import PIL.Image
import pytorch_lightning as pl
import torch.utils.data
import wandb
from torchvision import transforms

from utils_.loss import VGGPerceptualLoss
from utils_.visualization import *

# These stay after the wildcard import above (as in the original file) so that
# nothing it exports can shadow ``F``, ``plt`` or ``np``.
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np


class Model(pl.LightningModule):
    """Encoder/decoder pair whose ``forward`` renders a keypoint visualisation.

    The encoder and decoder classes are looked up dynamically as
    ``models.<hparams.encoder>.Encoder`` and ``models.<hparams.decoder>.Decoder``.
    """

    def __init__(self, **kwargs):
        """Store the hyperparameters and build the sub-modules.

        :param kwargs: hyperparameters, captured by ``save_hyperparameters``.
            This constructor reads ``encoder`` and ``decoder`` (module names
            under the ``models`` package) and ``batch_size``.
        """
        super().__init__()
        self.save_hyperparameters()

        encoder_module = importlib.import_module('models.' + self.hparams.encoder)
        decoder_module = importlib.import_module('models.' + self.hparams.decoder)
        self.encoder = encoder_module.Encoder(self.hparams)
        self.decoder = decoder_module.Decoder(self.hparams)

        self.batch_size = self.hparams.batch_size
        self.vgg_loss = VGGPerceptualLoss()

        # ToTensor scales to [0, 1]; Normalize(0.5, 0.5) then maps to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])

    def forward(self, x: PIL.Image.Image) -> np.ndarray:
        """Render the predicted edge map and keypoints on top of the input image.

        The edge map produced by the decoder is normalised by its per-sample
        maximum, resized to the input resolution, added to a half-brightness
        copy of the input and clamped to [0, 1]. The keypoints are then drawn
        on that overlay with matplotlib and the rendered canvas is returned.

        :param x: a PIL image
        :return: ``uint8`` array of shape ``(rows, cols, 3)`` holding the RGB
            pixels of the rendered figure. The figure is created as
            ``1 x h/w`` inches at ``dpi=w``, i.e. nominally ``w x h`` pixels.
        """
        w, h = x.size
        x = self.transform(x).unsqueeze(0)
        x = x.to(self.device)

        # Everything below is detached and converted to an image, so there is
        # no reason to record an autograd graph.
        with torch.no_grad():
            kp = self.encoder({'img': x})['keypoints']
            edge_map = self.decoder.rasterize(kp, output_size=64)

            # Normalise each sample by its own maximum (epsilon avoids 0/0).
            bs = edge_map.shape[0]
            peak = edge_map.reshape(bs, 1, -1).max(dim=2, keepdim=True)[0]
            edge_map = edge_map / (1e-8 + peak.reshape(bs, 1, 1, 1))

            # Repeat along the channel axis so it can be added to the image.
            # NOTE(review): assumes a single-channel edge map and a
            # three-channel input -- confirm against the decoder.
            edge_map = torch.cat([edge_map] * 3, dim=1)
            edge_map = F.interpolate(edge_map, size=(h, w), mode='bilinear', align_corners=False)

            # Undo the [-1, 1] normalisation, dim the image by half, add edges.
            overlay = torch.clamp(edge_map + (x * 0.5 + 0.5) * 0.5, min=0, max=1)

        overlay_img = transforms.ToPILImage()(overlay[0].detach().cpu())

        # Rescale keypoints from (presumably) [-1, 1] to pixel units.
        # Column 0 is scaled by the height and column 1 by the width.
        kp = kp[0].detach().cpu() * 0.5 + 0.5
        kp[:, 1] *= w
        kp[:, 0] *= h

        fig = plt.figure(figsize=(1, h/w), dpi=w)
        try:
            # NOTE(review): this runs before any axes exist, so it probably has
            # no effect on the layout; kept in place to preserve the output.
            fig.tight_layout(pad=0)

            ax = fig.gca()
            ax.axis('off')
            ax.imshow(overlay_img)
            ax.scatter(kp[:, 1], kp[:, 0], s=min(w/h, min(1, h/w)), marker='o')

            fig.canvas.draw()
            # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later
            # removed; read the RGBA buffer and drop the alpha channel instead.
            # ``copy`` detaches the result from the canvas before it is closed.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            plot = rgba[..., :3].copy()
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)

        return plot