|
import importlib |
|
import PIL |
|
import pytorch_lightning as pl |
|
import torch.utils.data |
|
import wandb |
|
from typing import Union |
|
from torchvision import transforms |
|
from utils_.loss import VGGPerceptualLoss |
|
from utils_.visualization import * |
|
import torch.nn.functional as F |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class Model(pl.LightningModule):
    """Lightning module pairing a keypoint encoder with a rasterizing decoder.

    The encoder and decoder classes are looked up dynamically from
    ``models.<name>`` modules, where the names come from the ``encoder`` and
    ``decoder`` hyperparameters.
    """

    def __init__(self, **kwargs):
        """Build the encoder, decoder, perceptual loss and input transform.

        :param kwargs: hyperparameters; this constructor reads ``encoder``,
            ``decoder`` and ``batch_size``. The complete hyperparameter object
            is also passed to the ``Encoder`` and ``Decoder`` constructors.
        """
        super().__init__()
        # Makes the constructor kwargs available as ``self.hparams``.
        self.save_hyperparameters()

        # Each module under ``models.`` is expected to expose an ``Encoder`` /
        # ``Decoder`` class taking the hyperparameters as its only argument.
        self.encoder = importlib.import_module('models.' + self.hparams.encoder).Encoder(self.hparams)
        self.decoder = importlib.import_module('models.' + self.hparams.decoder).Decoder(self.hparams)
        self.batch_size = self.hparams.batch_size

        self.vgg_loss = VGGPerceptualLoss()

        # PIL image -> float tensor, then (x - 0.5) / 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])

    def forward(self, x: PIL.Image.Image) -> "np.ndarray":
        """Render the predicted edge map and keypoints on top of an image.

        :param x: a PIL image
        :return: a ``uint8`` RGB array of shape ``(rows, cols, 3)`` holding the
            rendered figure: the input blended with the max-normalized edge
            map, with the predicted keypoints scattered on top. The figure is
            created as 1 x h/w inches at ``dpi=w``, i.e. sized to match the
            input resolution.
        """
        w, h = x.size
        x = self.transform(x).unsqueeze(0)
        x = x.to(self.device)

        kp = self.encoder({'img': x})['keypoints']
        edge_map = self.decoder.rasterize(kp, output_size=64)

        # Normalize every sample by its own maximum; the epsilon guards
        # against division by zero for an all-zero map.
        bs = edge_map.shape[0]
        peak = edge_map.reshape(bs, 1, -1).max(dim=2, keepdim=True)[0].reshape(bs, 1, 1, 1)
        edge_map = edge_map / (1e-8 + peak)

        # Replicate along the channel axis and resize to the input resolution.
        edge_map = torch.cat([edge_map] * 3, dim=1)
        edge_map = F.interpolate(edge_map, size=(h, w), mode='bilinear', align_corners=False)

        # Undo the input normalization (x * 0.5 + 0.5), dim the image by half
        # and add the edge map on top, clamped to the valid range.
        x = torch.clamp(edge_map + (x * 0.5 + 0.5) * 0.5, min=0, max=1)
        x = transforms.ToPILImage()(x[0].detach().cpu())

        # Map keypoints from [-1, 1] to pixel coordinates. Column 0 is scaled
        # by the height and column 1 by the width, i.e. they are used as
        # (row, col) -- NOTE(review): confirm against the encoder's convention.
        kp = kp[0].detach().cpu() * 0.5 + 0.5
        kp[:, 1] *= w
        kp[:, 0] *= h

        fig = plt.figure(figsize=(1, h / w), dpi=w)
        try:
            fig.tight_layout(pad=0)
            # Draw through an explicit axes handle instead of pyplot's
            # implicit "current axes" state.
            ax = fig.gca()
            ax.axis('off')
            ax.imshow(x)
            ax.scatter(kp[:, 1], kp[:, 0], s=min(w / h, 1, h / w), marker='o')
            fig.canvas.draw()
            # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later
            # removed; ``buffer_rgba`` exposes the rendered RGBA buffer with
            # its real shape. Drop alpha and copy so the result does not
            # alias the canvas memory once the figure is closed.
            plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)
        return plot
|
|