Spaces:
Sleeping
Sleeping
"""Various utilities used in the film_net frame interpolator model."""
from typing import List, Optional | |
import cv2 | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
def pad_batch(batch, align):
    """Zero-pad a batch so its height and width become multiples of ``align``.

    The padding needed on each spatial axis is split in two: the rounded-down
    half goes before the image (top/left) and the remainder after it
    (bottom/right). Axes 0 and 3 are never padded.

    Args:
      batch: 4-D array laid out as (batch, height, width, channels); the pad
        specification below covers exactly four axes.
      align: the multiple that height and width are padded up to.

    Returns:
      ``(padded_batch, crop_region)`` where ``crop_region`` is
      ``[top, left, top + height, left + width]``: the bounds of the original
      pixels inside ``padded_batch``.
    """
    def _missing(extent):
        # Pixels to add so that ``extent`` becomes divisible by ``align``.
        remainder = extent % align
        return align - remainder if remainder != 0 else 0

    height, width = batch.shape[1:3]
    pad_h = _missing(height)
    pad_w = _missing(width)
    top, left = pad_h >> 1, pad_w >> 1
    bottom, right = pad_h - top, pad_w - left
    crop_region = [top, left, height + top, width + left]
    padded = np.pad(batch, ((0, 0), (top, bottom), (left, right), (0, 0)), mode='constant')
    return padded, crop_region
def load_image(path, align=64):
    """Read an image file into a padded float32 RGB batch of one.

    Args:
      path: image file path handed to ``cv2.imread``.
      align: spatial multiple forwarded to ``pad_batch``; height and width are
        zero-padded up to the next multiple of this value.

    Returns:
      ``(image_batch, crop_region)``:
      - ``image_batch`` is a ``(1, H', W', C)`` float32 array with values
        scaled from 8-bit to [0, 1] (C is 3 for cv2.imread's default colour
        mode).
      - ``crop_region`` is the ``[top, left, bottom, right]`` region of the
        original pixels, as returned by ``pad_batch``.

    Raises:
      OSError: if OpenCV cannot read or decode ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Without this check the error surfaces later as an opaque cv2.error
        # from cvtColor that does not mention the path.
        raise OSError(f"Could not read image file: {path!r}")
    # OpenCV loads BGR; convert to RGB and scale to [0, 1] floats.
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
    return image_batch, crop_region
def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Build a resolution pyramid by repeated 2x2 average pooling.

    The first entry is ``image`` itself (the same tensor object). Each later
    entry is the previous one average-pooled with a 2x2 window and stride 2.
    Because the pooling rounds down, an odd trailing row or column is dropped.

    Args:
      image: tensor accepted by ``F.avg_pool2d`` (e.g. ``B x C x H x W``).
      pyramid_levels: number of levels to produce; zero or less yields ``[]``.

    Returns:
      A list of ``pyramid_levels`` tensors ordered from finest to coarsest.
    """
    levels: List[torch.Tensor] = []
    current = image
    for level in range(pyramid_levels):
        if level > 0:
            current = F.avg_pool2d(current, kernel_size=2, stride=2)
        levels.append(current)
    return levels
def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp ``image`` with a dense per-pixel ``flow`` field.

    For batch element ``b`` and output pixel ``(y, x)``, the result is the
    bilinear sample of ``image[b]`` at
    ``(y + flow[b, 1, y, x], x + flow[b, 0, y, x])``. Sample positions outside
    the image take the nearest border value (``padding_mode='border'``).
    This is the PyTorch counterpart of ``tfa_image.dense_image_warp``.

    Args:
      image: ``B x C x H x W`` tensor to sample from (channels-first, as
        required by ``F.grid_sample``).
      flow: ``B x 2 x H x W`` tensor of offsets in pixels. Channel 0 is the
        horizontal offset ``dx`` and channel 1 the vertical offset ``dy``.

    Returns:
      The warped image, reshaped to ``image.shape``. This presumes ``flow``
      has the same spatial size as ``image``.
    """
    # Reorder channels to (dy, dx) and negate; the subtraction when building
    # the grid cancels the sign again.
    neg_yx = -flow.flip(1)
    height, width = neg_yx.shape[2], neg_yx.shape[3]
    dtype, device = neg_yx.dtype, neg_yx.device

    # With align_corners=False the outermost pixel centres sit at
    # +-(1 - 1/size) in normalised coordinates, and one pixel spans 2/size.
    # Dividing a pixel offset by size/2 therefore normalises it.
    x_extent = 1 - 1 / width
    y_extent = 1 - 1 / height
    half_size = torch.tensor([height * .5, width * .5], dtype=dtype, device=device)[None, None, None]
    scaled = neg_yx.permute(0, 2, 3, 1) / half_size

    # Pixel-centre coordinates plus the normalised offsets.
    x_centres = torch.linspace(-x_extent, x_extent, width, dtype=dtype, device=device)[None, None, :]
    y_centres = torch.linspace(-y_extent, y_extent, height, dtype=dtype, device=device)[None, :, None]
    grid_x = x_centres - scaled[..., 1]
    grid_y = y_centres - scaled[..., 0]
    # grid_sample expects the last axis ordered (x, y).
    sampling_grid = torch.stack([grid_x, grid_y], dim=3)

    warped = F.grid_sample(image, sampling_grid,
                           mode='bilinear', padding_mode='border', align_corners=False)
    return warped.reshape(image.shape)
def multiply_pyramid(pyramid: List[torch.Tensor],
                     scalar: torch.Tensor) -> List[torch.Tensor]:
    """Multiply every image batch in ``pyramid`` by a per-sample scalar.

    Args:
      pyramid: pyramid levels; presumably each is ``B x C x H x W`` as
        elsewhere in this module.
      scalar: one multiplier per batch element, shaped ``(B, 1)`` or ``(B,)``.
        Any other shape gets two trailing singleton axes appended and then
        broadcasts normally, as before.

    Returns:
      A new list in which each level is multiplied by its batch element's
      scalar.
    """
    if scalar.dim() == 1:
        # A flat (B,) vector would become (B, 1, 1) below and broadcast
        # against the channel axis of a (B, C, H, W) image rather than the
        # batch axis, so promote it to (B, 1) first.
        scalar = scalar[:, None]
    # Append H and W singleton axes so the factor broadcasts over C x H x W.
    factor = scalar[..., None, None]
    return [image * factor for image in pyramid]
def flow_pyramid_synthesis(
        residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Accumulate a residual flow pyramid into a full flow pyramid.

    The coarsest level (the last entry) is taken as-is. Each finer level is
    built from the accumulated coarser flow: it is doubled, bilinearly resized
    to the finer level's spatial size (axes 2 and 3), and added to that
    level's residual. The doubling presumably compensates for a 2x resolution
    step between levels.

    Args:
      residual_pyramid: residual flows ordered from finest to coarsest.

    Returns:
      The accumulated flows, also ordered from finest to coarsest.
    """
    coarsest = len(residual_pyramid) - 1
    accumulated = residual_pyramid[-1]
    synthesized = [accumulated]  # collected coarse-to-fine, reversed at the end
    for level in range(coarsest - 1, -1, -1):
        residual = residual_pyramid[level]
        upsampled = F.interpolate(2 * accumulated, size=residual.shape[2:4], mode='bilinear')
        accumulated = residual + upsampled
        synthesized.append(accumulated)
    synthesized.reverse()
    return synthesized
def pyramid_warp(feature_pyramid: List[torch.Tensor],
                 flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Backward-warp each feature level with the flow of the same level.

    Levels are paired with ``zip``, so if one pyramid is longer its extra
    levels are ignored.

    Args:
      feature_pyramid: feature tensors, finest level first.
      flow_pyramid: flow fields, finest level first.

    Returns:
      A list with ``warp(features, flow)`` for each paired level.
    """
    return [warp(level_features, level_flow)
            for level_features, level_flow in zip(feature_pyramid, flow_pyramid)]
def concatenate_pyramids(pyramid1: List[torch.Tensor],
                         pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concatenate matching levels of two pyramids along dimension 1.

    Dimension 1 is the channel axis for channels-first tensors. Levels are
    paired with ``zip``, so the result is as long as the shorter pyramid.
    """
    return [torch.cat([lhs, rhs], dim=1) for lhs, rhs in zip(pyramid1, pyramid2)]
class Conv2d(nn.Sequential):
    """2-D convolution with TensorFlow-style 'same' padding and optional activation.

    The wrapped ``nn.Conv2d`` is submodule ``0``, so its parameters are named
    ``0.weight`` and ``0.bias``.
    - Odd kernel sizes use PyTorch's ``padding='same'``.
    - Even kernel sizes are zero-padded explicitly in ``forward``: ``size - 1``
      pixels in total per spatial axis, the smaller half before and the larger
      half after (TF convention). Spatial size is preserved either way.

    Args:
      in_channels: number of input channels.
      out_channels: number of output channels.
      size: square kernel size (an int).
      activation: ``'relu'`` applies ``LeakyReLU(0.2)`` (a leaky ReLU, not a
        plain ReLU) after the convolution; ``None`` applies nothing.

    Raises:
      ValueError: if ``activation`` is neither ``None`` nor ``'relu'``.
    """

    def __init__(self, in_channels, out_channels, size, activation: Optional[str] = 'relu'):
        # Raise instead of assert so validation survives ``python -O``.
        if activation not in (None, 'relu'):
            raise ValueError(f"activation must be None or 'relu', got {activation!r}")
        super().__init__(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=size,
                padding='same' if size % 2 else 0)
        )
        self.size = size
        self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None

    def forward(self, x):
        if not self.size % 2:
            # Even kernel: 'same' output needs size - 1 pixels of padding per
            # axis. Put floor((size - 1) / 2) before and the rest after; for
            # size == 2 this is (0, 1, 0, 1).
            before = (self.size - 1) // 2
            after = self.size - 1 - before
            x = F.pad(x, [before, after, before, after])
        y = self[0](x)
        if self.activation is not None:
            y = self.activation(y)
        return y