AutoLink / models /model.py
xingzhehe's picture
try first commit
91fc62a
import importlib
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import PIL
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.utils.data
import wandb
from torchvision import transforms

from utils_.loss import VGGPerceptualLoss
from utils_.visualization import *
class Model(pl.LightningModule):
    """Keypoint encoder/decoder pair with an image-in, rendered-overlay-out ``forward``.

    The encoder and decoder implementations are resolved dynamically: the
    hyperparameters ``encoder`` and ``decoder`` name modules inside the
    ``models`` package, which must expose ``Encoder`` and ``Decoder`` classes.
    """

    def __init__(self, **kwargs):
        """Record ``kwargs`` as hyperparameters and build the sub-modules.

        :param kwargs: hyperparameters; ``encoder``, ``decoder`` and
            ``batch_size`` are read here, and the whole hyperparameter object
            is handed to the ``Encoder`` and ``Decoder`` constructors.
        """
        super().__init__()
        self.save_hyperparameters()

        # 'models.<name>' is imported at runtime so the architecture can be
        # selected purely through hyperparameters.
        self.encoder = importlib.import_module('models.' + self.hparams.encoder).Encoder(self.hparams)
        self.decoder = importlib.import_module('models.' + self.hparams.decoder).Decoder(self.hparams)
        self.batch_size = self.hparams.batch_size

        self.vgg_loss = VGGPerceptualLoss()

        # PIL image -> float tensor, then (v - 0.5) / 0.5 per channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])

    def forward(self, x: PIL.Image.Image) -> np.ndarray:
        """Detect keypoints on a PIL image and render them on top of it.

        The rasterized edge map (divided by its per-sample maximum) is added to
        the half-intensity input image, the sum is clamped to [0, 1], and the
        detected keypoints are scattered over the result.

        :param x: a PIL image
        :return: a ``uint8`` RGB array of shape ``(rows, cols, 3)`` holding the
            rendered figure; the figure is created as 1 x h/w inches at
            ``dpi=w``, where ``(w, h) = x.size``
        """
        w, h = x.size

        # Nothing below is back-propagated (results are detached and drawn),
        # so skip building the autograd graph.
        with torch.no_grad():
            img = self.transform(x).unsqueeze(0).to(self.device)
            kp = self.encoder({'img': img})['keypoints']
            edge_map = self.decoder.rasterize(kp, output_size=64)

            # Normalize each sample by its own maximum; the epsilon guards
            # against division by zero for an all-zero map.
            bs = edge_map.shape[0]
            peak = edge_map.reshape(bs, 1, -1).max(dim=2, keepdim=True)[0].reshape(bs, 1, 1, 1)
            edge_map = edge_map / (1e-8 + peak)

            # Repeat along the channel axis three times, then resize to the input size.
            edge_map = torch.cat([edge_map] * 3, dim=1)
            edge_map = F.interpolate(edge_map, size=(h, w), mode='bilinear', align_corners=False)

            # Undo the normalization (img * 0.5 + 0.5), dim the image by half,
            # and add the edge map on top.
            overlay = torch.clamp(edge_map + (img * 0.5 + 0.5) * 0.5, min=0, max=1)

        overlay = transforms.ToPILImage()(overlay[0].detach().cpu())

        # Map keypoints from [-1, 1] to pixel units; column 0 is scaled by the
        # height and column 1 by the width, i.e. they are plotted as (y, x).
        kp = kp[0].detach().cpu() * 0.5 + 0.5
        kp[:, 1] *= w
        kp[:, 0] *= h

        fig = plt.figure(figsize=(1, h / w), dpi=w)
        try:
            # NOTE(review): tight_layout runs before any axes exist, exactly as
            # before; presumably it has no effect at this point -- confirm
            # before moving it, since that would change the rendered margins.
            fig.tight_layout(pad=0)
            ax = fig.gca()
            ax.axis('off')
            ax.imshow(overlay)
            ax.scatter(kp[:, 1], kp[:, 0], s=min(w / h, h / w, 1), marker='o')

            fig.canvas.draw()
            # tostring_rgb() was deprecated in Matplotlib 3.8; read the RGBA
            # buffer instead and drop alpha. Copy so the result does not alias
            # the canvas buffer once the figure is closed.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            plot = rgba[..., :3].copy()
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)
        return plot