|
from typing import List, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from omegaconf import ListConfig |
|
|
|
from torchvision.utils import save_image |
|
from ...util import append_dims, instantiate_from_config |
|
|
|
|
|
class StandardDiffusionLoss(nn.Module):
    """Denoising loss: noise the input, run the denoiser, compare to the input.

    The per-sample noise level is drawn from a sigma sampler built with
    ``instantiate_from_config``; the element-wise error is weighted by
    ``denoiser.w(sigmas)``.
    """

    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        """Build the loss.

        Args:
            sigma_sampler_config: Config handed to ``instantiate_from_config``;
                the resulting object is called with the batch size.
            type: One of ``"l2"``, ``"l1"`` or ``"lpips"``.
            offset_noise_level: When > 0, scale of an extra per-sample random
                offset added to the noise.
            batch2model_keys: Batch key(s) whose values are forwarded to the
                denoiser as extra keyword arguments. A single string is
                treated as a one-element list; a falsy value means none.
        """
        super().__init__()

        assert type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        self.type = type
        self.offset_noise_level = offset_noise_level

        # NOTE(review): "lpips" is accepted above, but this class never
        # creates ``self.lpips``. It only works if a metric is attached to the
        # instance from outside; see ``get_diff_loss``.

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch, *args, **kwarg):
        """Compute the diffusion loss for one batch.

        Extra positional/keyword arguments are accepted and ignored.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is the mean over the
            per-sample losses and ``loss_dict`` is ``{"loss": loss}``.
        """
        cond = conditioner(batch)
        # Only the configured keys that are actually present in the batch.
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }

        # One sigma per entry along the first dimension of `input`.
        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # One extra random value per sample, expanded by append_dims to
            # `input.ndim` dimensions (presumably trailing singleton dims, so
            # it broadcasts over the rest of the sample -- confirm in util).
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(input.shape[0], device=input.device), input.ndim
            )
        noised_input = input + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        w = append_dims(denoiser.w(sigmas), input.ndim)

        loss = self.get_diff_loss(model_output, input, w)
        loss = loss.mean()
        loss_dict = {"loss": loss}

        return loss, loss_dict

    def get_diff_loss(self, model_output, target, w):
        """Return the per-sample loss between ``model_output`` and ``target``.

        For ``"l2"``/``"l1"`` the weighted squared/absolute error is flattened
        per sample (first dimension kept) and averaged. For ``"lpips"`` the
        metric stored in ``self.lpips`` is applied and ``w`` is not used.

        Raises:
            AttributeError: ``type`` is ``"lpips"`` but no metric is attached
                as ``self.lpips``.
            ValueError: ``self.type`` is not one of the supported names.
        """
        if self.type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        if self.type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        if self.type == "lpips":
            lpips = getattr(self, "lpips", None)
            if lpips is None:
                # Same exception type as the bare attribute lookup would
                # raise, but with a message that says what is actually wrong.
                raise AttributeError(
                    "loss type 'lpips' needs a perceptual metric stored in "
                    "`self.lpips`, but StandardDiffusionLoss does not create "
                    "one; attach a metric to the instance or use 'l2'/'l1'"
                )
            return lpips(model_output, target).reshape(-1)
        # Reaching here means self.type was changed after construction.
        raise ValueError(f"unknown loss type: {self.type!r}")
|
|
|
|
|
class FullLoss(StandardDiffusionLoss):
    """Diffusion loss plus an attention "local" loss and an optional OCR loss.

    The local loss compares Gaussian-smoothed attention maps cached on the
    network with segmentation maps from the batch. The OCR loss is delegated
    to a predictor built from ``predictor_config``.
    """

    def __init__(
        self,
        seq_len=12,
        kernel_size=3,
        gaussian_sigma=0.5,
        min_attn_size=16,
        lambda_local_loss=0.0,
        lambda_ocr_loss=0.0,
        ocr_enabled=False,
        predictor_config=None,
        *args,
        **kwarg,
    ):
        """Build the loss.

        Args:
            seq_len: Number of channels of the smoothing kernel (one copy of
                the kernel per channel).
            kernel_size: Side length of the square Gaussian kernel.
            gaussian_sigma: Standard deviation of the Gaussian kernel.
            min_attn_size: Cached attention maps whose ``size`` is smaller
                than this are skipped by the local losses.
            lambda_local_loss: Weight of the local loss in the total.
            lambda_ocr_loss: Weight of the OCR loss in the total.
            ocr_enabled: Whether to build the predictor and add the OCR loss.
            predictor_config: Config for ``instantiate_from_config``; only
                read when ``ocr_enabled`` is true.
            *args, **kwarg: Forwarded to ``StandardDiffusionLoss.__init__``.
        """
        super().__init__(*args, **kwarg)

        self.gaussian_kernel_size = kernel_size
        gaussian_kernel = self.get_gaussian_kernel(
            kernel_size=self.gaussian_kernel_size,
            sigma=gaussian_sigma,
            out_channels=seq_len,
        )
        # Buffer: follows the module across devices, is never trained.
        self.register_buffer("g_kernel", gaussian_kernel.requires_grad_(False))

        self.min_attn_size = min_attn_size
        self.lambda_local_loss = lambda_local_loss
        self.lambda_ocr_loss = lambda_ocr_loss

        self.ocr_enabled = ocr_enabled
        if ocr_enabled:
            self.predictor = instantiate_from_config(predictor_config)

    def get_gaussian_kernel(self, kernel_size=3, sigma=1, out_channels=3):
        """Return a normalised 2-D Gaussian kernel, one copy per channel.

        Returns:
            Tensor of shape ``(out_channels, 1, kernel_size, kernel_size)``
            whose every ``kernel_size x kernel_size`` slice sums to 1.
        """
        # Coordinate grid: x_grid[i, j] == j and y_grid[i, j] == i.
        x_coord = torch.arange(kernel_size)
        x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
        y_grid = x_grid.t()
        xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

        mean = (kernel_size - 1) / 2.
        variance = sigma ** 2.

        # Isotropic Gaussian density centred on the middle of the grid.
        gaussian_kernel = (1. / (2. * torch.pi * variance)) * torch.exp(
            -torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance)
        )

        # Normalise so the kernel weights sum to one.
        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

        gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
        gaussian_kernel = gaussian_kernel.tile(out_channels, 1, 1, 1)

        return gaussian_kernel

    def __call__(self, network, denoiser, conditioner, input, batch, first_stage_model, scaler):
        """Compute the combined loss for one batch.

        Reads ``batch["seg"]`` and ``batch["seg_mask"]`` and
        ``network.diffusion_model.attn_map_cache`` for the local loss, and
        ``batch["r_bbox"]`` / ``batch["label"]`` when OCR is enabled.

        Returns:
            ``(loss, loss_dict)``; ``loss_dict`` holds ``loss/diff_loss``,
            ``loss/local_loss``, ``loss/full_loss`` and, when OCR is enabled,
            ``loss/ocr_loss``.
        """
        cond = conditioner(batch)

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(input.shape[0], device=input.device), input.ndim
            )

        noised_input = input + noise * append_dims(sigmas, input.ndim)
        # NOTE(review): `batch2model_keys` is not forwarded to the denoiser
        # here -- confirm that is intended.
        model_output = denoiser(network, noised_input, sigmas, cond)
        w = append_dims(denoiser.w(sigmas), input.ndim)

        diff_loss = self.get_diff_loss(model_output, input, w).mean()
        # The attention cache is read after the denoiser call above
        # (presumably it is filled during that forward pass -- confirm).
        local_loss = self.get_local_loss(
            network.diffusion_model.attn_map_cache, batch["seg"], batch["seg_mask"]
        ).mean()

        loss = diff_loss + self.lambda_local_loss * local_loss

        ocr_loss = None
        if self.ocr_enabled:
            ocr_loss = self.get_ocr_loss(
                model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler
            ).mean()
            loss = loss + self.lambda_ocr_loss * ocr_loss

        loss_dict = {
            "loss/diff_loss": diff_loss,
            "loss/local_loss": local_loss,
            "loss/full_loss": loss,
        }
        if ocr_loss is not None:
            loss_dict["loss/ocr_loss"] = ocr_loss

        return loss, loss_dict

    def get_ocr_loss(self, model_output, r_bbox, label, first_stage_model, scaler):
        """Decode the model output, crop each sample to its box, score crops.

        Each entry of ``r_bbox`` is unpacked as ``(top, bottom, left, right)``
        and used to slice the last two dimensions of the decoded sample.
        The crops and ``label`` are passed to ``self.predictor.calc_loss``.
        """
        # Undo the scaling before decoding with the first-stage model.
        model_output = 1 / scaler * model_output
        model_output_decoded = first_stage_model.decode(model_output)

        model_output_crops = [
            model_output_decoded[i, :, m_top:m_bottom, m_left:m_right]
            for i, (m_top, m_bottom, m_left, m_right) in enumerate(r_bbox)
        ]

        return self.predictor.calc_loss(model_output_crops, label)

    def _smoothed_attn_map(self, item, seg_l):
        """Head-average, truncate and Gaussian-smooth one cached attention map.

        ``item`` supplies ``"heads"``, ``"size"`` and ``"attn_map"``; the map
        is unpacked as ``(batch*heads, n, l)``. Returns a tensor of shape
        ``(-1, seg_l, n)``.
        """
        heads = item["heads"]
        size = item["size"]
        attn_map = item["attn_map"]

        bh, n, l = attn_map.shape
        attn_map = attn_map.reshape((-1, heads, n, l))

        assert seg_l <= l
        # Keep the first seg_l entries of the last axis, move that axis in
        # front of the spatial one, then average over heads.
        attn_map = attn_map[..., :seg_l]
        attn_map = attn_map.permute(0, 1, 3, 2)
        attn_map = attn_map.mean(dim=1)

        # Depth-wise smoothing: each of the seg_l channels gets its own copy
        # of the Gaussian kernel.
        attn_map = attn_map.reshape((-1, seg_l, size, size))
        attn_map = F.conv2d(
            attn_map,
            self.g_kernel,
            padding=self.gaussian_kernel_size // 2,
            groups=seg_l,
        )
        return attn_map.reshape((-1, seg_l, n))

    def get_min_local_loss(self, attn_map_cache, mask, seg_mask):
        """Negated worst-channel peak attention inside ``mask``.

        For every cached map with ``size >= min_attn_size``: take the maximum
        of the smoothed attention inside ``mask`` per channel, add
        ``1 - seg_mask``, take the minimum over channels and negate. The
        result is averaged over the maps used.

        Returns:
            Per-sample loss tensor; zeros when no cached map qualifies.
        """
        seg_l = seg_mask.shape[1]
        per_map_losses = []

        for item in attn_map_cache:
            size = item["size"]
            if size < self.min_attn_size:
                continue

            attn_map = self._smoothed_attn_map(item, seg_l)
            n = attn_map.shape[-1]

            mask_map = F.interpolate(mask, (size, size))
            mask_map = mask_map.tile((1, seg_l, 1, 1))
            mask_map = mask_map.reshape((-1, seg_l, n))

            p_loss = (mask_map * attn_map).max(dim=-1)[0]
            # Adding (1 - seg_mask) raises channels where seg_mask is 0, so
            # they are less likely to be picked by the min below.
            p_loss = p_loss + (1 - seg_mask)
            p_loss = p_loss.min(dim=-1)[0]

            per_map_losses.append(-p_loss)

        if not per_map_losses:
            # No usable attention map: contribute nothing instead of 0 / 0.
            return torch.zeros(seg_mask.shape[:-1], device=seg_mask.device)

        return sum(per_map_losses) / len(per_map_losses)

    def get_local_loss(self, attn_map_cache, seg, seg_mask):
        """Peak attention outside minus inside the segmentation maps.

        For every cached map with ``size >= min_attn_size`` and every channel,
        the maximum smoothed attention inside ``seg`` (positive) and inside
        ``1 - seg`` (negative) is taken. Both are masked by ``seg_mask`` and
        averaged over channels using ``seg_mask.sum``; the loss is
        ``negative - positive``, averaged over the maps used.

        Returns:
            Per-sample loss tensor; zeros when no cached map qualifies.
            A sample whose ``seg_mask`` sums to zero contributes 0.
        """
        seg_l = seg_mask.shape[1]

        # Loop-invariant. A zero denominator always pairs with a zero
        # numerator (everything is multiplied by seg_mask), so replacing it
        # by 1 yields 0 instead of NaN.
        valid_count = seg_mask.sum(dim=-1)
        valid_count = torch.where(
            valid_count == 0, torch.ones_like(valid_count), valid_count
        )

        per_map_losses = []

        for item in attn_map_cache:
            size = item["size"]
            if size < self.min_attn_size:
                continue

            attn_map = self._smoothed_attn_map(item, seg_l)
            n = attn_map.shape[-1]

            seg_map = F.interpolate(seg, (size, size))
            seg_map = seg_map.reshape((-1, seg_l, n))
            n_seg_map = 1 - seg_map

            p_loss = (seg_map * attn_map).max(dim=-1)[0]
            n_loss = (n_seg_map * attn_map).max(dim=-1)[0]

            p_loss = (p_loss * seg_mask).sum(dim=-1) / valid_count
            n_loss = (n_loss * seg_mask).sum(dim=-1) / valid_count

            per_map_losses.append(n_loss - p_loss)

        if not per_map_losses:
            # No usable attention map: contribute nothing instead of 0 / 0.
            return torch.zeros(seg_mask.shape[:-1], device=seg_mask.device)

        return sum(per_map_losses) / len(per_map_losses)