import torch
import torch.nn as nn
import torch.nn.functional as F

from .BilateralCorrelation_NN import bilateralcorrelation_nn as bicorr_nn


def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def bilinear_sampler(img, coords, mask=False):
    """Wrapper for grid_sample that takes pixel coordinates."""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        # Mark samples that fall strictly inside the image bounds as valid.
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
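
# Minimal usage sketch for bilinear_sampler (shapes are illustrative
# assumptions, not part of this module): sample a feature map at the
# locations given by a pixel-coordinate grid.
#
#   fmap = torch.randn(1, 64, 32, 48)               # (B, C, H, W)
#   coords = coords_grid(1, 32, 48, fmap.device)    # (B, 2, H, W), (x, y) channels
#   coords = coords.permute(0, 2, 3, 1)             # (B, H, W, 2) as grid_sample expects
#   warped, valid = bilinear_sampler(fmap, coords, mask=True)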


def coords_grid(batch, ht, wd, device):
    coords = torch.meshgrid(torch.arange(ht, device=device),
                            torch.arange(wd, device=device),
                            indexing='ij')
    # Reverse the (row, col) order from meshgrid so the channels are (x, y).
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
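
# Example (a sketch): coords_grid(2, 4, 5, 'cpu') returns a (2, 2, 4, 5)
# tensor whose channel 0 holds x (column) indices and channel 1 holds
# y (row) indices, so coords[0, :, 1, 3] == (3., 1.).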


class SmallUpdateBlock(nn.Module):
    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2
        self.scale_factor = scale_factor

        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(corr_dim + flow_dim, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        # Downsample the context features to the flow/correlation resolution.
        net = resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            # Upsample the updates back to the input resolution; flow values
            # are rescaled along with the spatial resolution.
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)

        return delta_net, delta_flow
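
# Usage sketch (dimensions are assumptions for illustration): `flow` carries
# the bidirectional flow as 4 channels (two 2-channel fields) and `corr`
# carries the concatenated forward/backward correlation volumes, which is
# why convc1 takes 2 * cor_planes input channels.
#
#   block = SmallUpdateBlock(cdim=32, hidden_dim=96, flow_dim=32,
#                            corr_dim=96, fc_dim=128, corr_levels=4,
#                            radius=3, scale_factor=None)
#   # net: (B, 32, H, W), flow: (B, 4, H, W), corr: (B, 2*4*(2*3+1)**2, H, W)
#   delta_net, delta_flow = block(net, flow, corr)  # (B, 32, H, W), (B, 4, H, W)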


class BasicUpdateBlock(nn.Module):
    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        cor_planes = (2 * radius + 1) ** 2 * corr_levels

        self.scale_factor = scale_factor
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        # Downsample the context features to the flow/correlation resolution.
        net = resize(net, 1 / self.scale_factor) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        cor = self.lrelu(self.convc2(cor))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)

        return delta_net, delta_flow
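
# BasicUpdateBlock follows the same calling convention as the sketch for
# SmallUpdateBlock above; the extra convc2 stage only adds a correlation
# bottleneck, and `out_num` scales the flow head to 4 * out_num channels.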


class BidirCorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.fmap1_pyramid = [fmap1]
        self.fmap2_pyramid = [fmap2]

        # Build a feature pyramid by repeated 2x average pooling.
        for _ in range(self.num_levels - 1):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.fmap1_pyramid.append(fmap1)
            self.fmap2_pyramid.append(fmap2)

    def __call__(self, flowt0, flowt1, time_step):
        out_pyramid = []
        out_pyramid_T = []
        flowt0 = flowt0.contiguous()
        flowt1 = flowt1.contiguous()
        for i in range(self.num_levels):
            fmap1 = self.fmap1_pyramid[i]
            fmap2 = self.fmap2_pyramid[i]
            # Bilateral correlation in both temporal directions at this level.
            corr0 = bicorr_nn.apply(fmap2, fmap1, flowt0, time_step, self.radius)
            corr1 = bicorr_nn.apply(fmap1, fmap2, flowt1, time_step, self.radius)
            out_pyramid.append(corr0)
            out_pyramid_T.append(corr1)

        out = torch.cat(out_pyramid, dim=1)
        out_T = torch.cat(out_pyramid_T, dim=1)
        return out.contiguous().float(), out_T.contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # All-pairs correlation volume, normalized by sqrt(feature dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
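
# Usage sketch for BidirCorrBlock (shapes, the radius, and the time_step
# value are illustrative assumptions; bicorr_nn is the custom autograd op
# imported from BilateralCorrelation_NN, so this requires the compiled
# extension and CUDA tensors):
#
#   fmap1 = torch.randn(1, 128, 32, 48).cuda()   # features of frame 0
#   fmap2 = torch.randn(1, 128, 32, 48).cuda()   # features of frame 1
#   corr_fn = BidirCorrBlock(fmap1, fmap2, num_levels=4, radius=3)
#   flowt0 = torch.zeros(1, 2, 32, 48).cuda()    # flow from time t toward frame 0
#   flowt1 = torch.zeros(1, 2, 32, 48).cuda()    # flow from time t toward frame 1
#   corr0, corr1 = corr_fn(flowt0, flowt1, time_step=0.5)
#
# Concatenating the two returned volumes along the channel dimension is
# expected to match the 2 * cor_planes input of the update blocks above.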