| """PyTorch layer for extracting image features for the film_net interpolator. | |
| The feature extractor implemented here converts an image pyramid into a pyramid | |
| of deep features. The feature pyramid serves a similar purpose as U-Net | |
| architecture's encoder, but we use a special cascaded architecture described in | |
| Multi-view Image Fusion [1]. | |
| For completeness, below is a short description of the idea. While the | |
| description is a bit involved, the cascaded feature pyramid can be used just | |
| like any image feature pyramid. | |
| Why cascaded architecture? | |
| ========================== | |
| To understand the concept it is worth reviewing a traditional feature pyramid | |
| first: *A traditional feature pyramid* as in U-net or in many optical flow | |
| networks is built by alternating between convolutions and pooling, starting | |
| from the input image. | |
| It is well known that early features of such an architecture correspond to | |
| low-level concepts such as edges in the image, whereas later layers extract | |
| semantically higher-level concepts such as object classes. In other words, | |
| the meaning of the filters in each resolution level is different. For problems | |
| such as semantic segmentation and many others this is a desirable property. | |
| However, the asymmetric features preclude sharing weights across resolution | |
| levels in the feature extractor itself and in any subsequent neural networks | |
| that follow. This can be a downside, since optical flow prediction, for | |
| instance, is symmetric across resolution levels. The cascaded feature | |
| architecture addresses this shortcoming. | |
| How is it built? | |
| ================ | |
| The *cascaded* feature pyramid contains feature vectors that have constant | |
| length and meaning on each resolution level, except for a few of the finest | |
| ones. The advantage of this is that the subsequent optical flow layer can | |
| learn synergistically from many resolutions. This means that coarse level | |
| prediction can benefit from finer resolution training examples, which can be | |
| useful with moderately sized datasets to avoid overfitting. | |
| The cascaded feature pyramid is built by extracting shallower subtree pyramids, | |
| each one of them similar to the traditional architecture. One subtree | |
| pyramid S_i is extracted starting from each resolution level i: | |
|   image resolution 0 -> S_0 | |
|   image resolution 1 -> S_1 | |
|   image resolution 2 -> S_2 | |
|   ... | |
| If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid | |
| is constructed by concatenating features as follows (assuming subtree depth=3): | |
|   lvl | |
|   feat_0 = concat(                               S_0_0 ) | |
|   feat_1 = concat(                         S_1_0 S_0_1 ) | |
|   feat_2 = concat(                   S_2_0 S_1_1 S_0_2 ) | |
|   feat_3 = concat(             S_3_0 S_2_1 S_1_2       ) | |
|   feat_4 = concat(       S_4_0 S_3_1 S_2_2             ) | |
|   feat_5 = concat( S_5_0 S_4_1 S_3_2                   ) | |
|   ... | |
| In the above, all levels except feat_0 and feat_1 have the same number of | |
| features with similar semantic meaning. This enables training a single optical | |
| flow predictor module shared by levels 2, 3, 4, 5, and so on. For more details | |
| and evaluation see [1]. | |
| [1] Multi-view Image Fusion, Trinidad et al. 2019 | |
| """ | |
| from typing import List | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from util import Conv2d | |
| class SubTreeExtractor(nn.Module): | |
|     """Extracts a hierarchical set of features from an image. | |
|     This is a conventional, hierarchical image feature extractor that extracts | |
|     [k, k*2, k*4, ...] filters for the image pyramid, where k=channels. | |
|     Each level is followed by average pooling. | |
|     """ | |
|     def __init__(self, in_channels=3, channels=64, n_layers=4): | |
|         super().__init__() | |
|         convs = [] | |
|         for i in range(n_layers): | |
|             # Level i uses channels * 2**i filters in both of its convolutions. | |
|             convs.append(nn.Sequential( | |
|                 Conv2d(in_channels, (channels << i), 3), | |
|                 Conv2d((channels << i), (channels << i), 3) | |
|             )) | |
|             in_channels = channels << i | |
|         self.convs = nn.ModuleList(convs) | |
|     def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]: | |
|         """Extracts a pyramid of features from the image. | |
|         Args: | |
|             image: torch.Tensor with shape BATCH_SIZE x CHANNELS x HEIGHT x WIDTH. | |
|             n: number of pyramid levels to extract. This can be less than or | |
|                 equal to n_layers given in __init__. | |
|         Returns: | |
|             The pyramid of features, starting from the finest level. Each element | |
|             contains the output after the last convolution on the corresponding | |
|             pyramid level. | |
|         """ | |
|         head = image | |
|         pyramid = [] | |
|         for i, layer in enumerate(self.convs): | |
|             if i >= n: | |
|                 # Stop after n levels so that no unused features are computed. | |
|                 break | |
|             head = layer(head) | |
|             pyramid.append(head) | |
|             if i < n - 1: | |
|                 # Halve the resolution before moving to the next, coarser level. | |
|                 head = F.avg_pool2d(head, kernel_size=2, stride=2) | |
|         return pyramid | |
| class FeatureExtractor(nn.Module): | |
|     """Extracts features from an image pyramid using a cascaded architecture. | |
|     """ | |
|     def __init__(self, in_channels=3, channels=64, sub_levels=4): | |
|         super().__init__() | |
|         self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels) | |
|         self.sub_levels = sub_levels | |
|     def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: | |
|         """Extracts a cascaded feature pyramid. | |
|         Args: | |
|             image_pyramid: Image pyramid as a list, starting from the finest level. | |
|         Returns: | |
|             A pyramid of cascaded features. | |
|         """ | |
|         sub_pyramids: List[List[torch.Tensor]] = [] | |
|         for i in range(len(image_pyramid)): | |
|             # At each level of the image pyramid, create a sub_pyramid of features | |
|             # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor. | |
|             # We use the same instance since we want to share the weights. | |
|             # | |
|             # However, we cap the depth of the sub_pyramid so we don't create | |
|             # features that are beyond the coarsest level of the cascaded feature | |
|             # pyramid we want to generate. | |
|             capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels) | |
|             sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels)) | |
|         # Below we generate the cascades of features on each level of the feature | |
|         # pyramid. Assuming sub_levels=3, the layout of the features will be | |
|         # as shown in the example in the module docstring above. | |
|         feature_pyramid: List[torch.Tensor] = [] | |
|         for i in range(len(image_pyramid)): | |
|             features = sub_pyramids[i][0] | |
|             for j in range(1, self.sub_levels): | |
|                 if j <= i: | |
|                     # Concatenate along the channel dimension (dim=1). | |
|                     features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) | |
|             feature_pyramid.append(features) | |
|         return feature_pyramid | |
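The concatenation pattern used by `FeatureExtractor.forward` can be checked without PyTorch. The following is a minimal standard-library sketch, not part of the original module: the helper names `cascade_layout` and `cascade_channels` are hypothetical, and the channel counts assume that subtree level j carries `channels << j` filters, as in `SubTreeExtractor.__init__`.

```python
from typing import List, Tuple


def cascade_layout(num_levels: int, sub_levels: int) -> List[List[Tuple[int, int]]]:
    """Hypothetical helper: lists the (subtree i, subtree level j) pairs per level.

    Mirrors the index arithmetic of FeatureExtractor.forward: cascaded level
    `level` concatenates S_(level-j)_j for every j with j <= level and
    j < sub_levels.
    """
    layout = []
    for level in range(num_levels):
        deepest = min(level, sub_levels - 1)
        layout.append([(level - j, j) for j in range(deepest + 1)])
    return layout


def cascade_channels(num_levels: int, channels: int, sub_levels: int) -> List[int]:
    """Hypothetical helper: channel count of each cascaded level.

    Assumes subtree level j has `channels << j` filters.
    """
    return [
        sum(channels << j for _, j in entries)
        for entries in cascade_layout(num_levels, sub_levels)
    ]


if __name__ == "__main__":
    # Reproduces the layout from the module docstring (subtree depth 3).
    for level, entries in enumerate(cascade_layout(num_levels=6, sub_levels=3)):
        names = " ".join(f"S_{i}_{j}" for i, j in entries)
        print(f"feat_{level} = concat( {names} )")
    # With channels=64, only levels 0 and 1 differ from the rest.
    print(cascade_channels(num_levels=6, channels=64, sub_levels=3))
    # [64, 192, 448, 448, 448, 448]
```

From level `sub_levels - 1` onward the channel count stays constant, which is the property that lets a single predictor module be shared across the coarser levels.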