# Time-TravelRephotography/losses/color_transfer_loss.py
from typing import List, Optional
import torch
from torch import nn
from torch.nn.functional import smooth_l1_loss


def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
    """Flatten a (B, C, H, W) tensor to (B, C * H * W)."""
    B = im.shape[0]
    return im.reshape(B, -1)


def stddev(x: torch.Tensor) -> torch.Tensor:
    """
    x: (B, -1), assumed to be mean-normalized

    Returns:
        stddev: (B,)
    """
    return torch.sqrt(torch.mean(x * x, dim=-1))


def gram_matrix(input_: torch.Tensor) -> torch.Tensor:
    """Batched Gram matrix of a (B, C, ...) feature tensor.

    Returns:
        G: (B, C, C), normalized by the channel count times the
           number of spatial entries per channel.
    """
    B, C = input_.shape[:2]
    features = input_.view(B, C, -1)
    N = features.shape[-1]
    G = torch.bmm(features, features.transpose(1, 2))  # (B, C, C)
    return G.div(C * N)


class ColorTransferLoss(nn.Module):
    """Penalize the Gram matrix difference between StyleGAN2's ToRGB outputs."""

    def __init__(
        self,
        init_rgbs: List[torch.Tensor],
        scale_rgb: bool = False,
    ):
        super().__init__()
        with torch.no_grad():
            init_feats = [x.detach() for x in init_rgbs]
            # Per-image stddev reshaped to (B, 1, 1, 1) so it broadcasts against
            # (B, C, H, W), or the scalar 1 when no rescaling is requested.
            self.stds = [
                stddev(flatten_CHW(rgb)).reshape(-1, 1, 1, 1) if scale_rgb else 1
                for rgb in init_feats
            ]
            self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]

    def forward(self, rgbs: List[torch.Tensor], level: Optional[int] = None) -> torch.Tensor:
        if level is None:
            level = len(self.grams)

        loss = 0
        for rgb, std, gram in zip(rgbs[:level], self.stds[:level], self.grams[:level]):
            G = gram_matrix(rgb / std)
            loss = loss + smooth_l1_loss(G, gram)
        return loss
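

# Minimal usage sketch (illustrative, not part of the original training code):
# the random tensors below merely stand in for StyleGAN2 ToRGB outputs at a few
# resolutions, which the surrounding project would supply instead.
if __name__ == "__main__":
    B = 2
    # Fake "ToRGB" pyramid: (B, 3, H, W) at increasing resolutions.
    init_rgbs = [torch.randn(B, 3, s, s) for s in (8, 16, 32)]

    criterion = ColorTransferLoss(init_rgbs, scale_rgb=True)

    # Perturbed copies play the role of the current generator outputs.
    rgbs = [rgb + 0.1 * torch.randn_like(rgb) for rgb in init_rgbs]
    print(criterion(rgbs).item())           # loss over all levels
    print(criterion(rgbs, level=2).item())  # loss over the first two levels only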