import torch import torch.nn as nn from torchvision.ops import deform_conv2d class DeformableConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False): super(DeformableConv2d, self).__init__() assert type(kernel_size) == tuple or type(kernel_size) == int kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) self.stride = stride if type(stride) == tuple else (stride, stride) self.padding = padding self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size[0] * kernel_size[1], kernel_size=kernel_size, stride=stride, padding=self.padding, bias=True) nn.init.constant_(self.offset_conv.weight, 0.) nn.init.constant_(self.offset_conv.bias, 0.) self.modulator_conv = nn.Conv2d(in_channels, 1 * kernel_size[0] * kernel_size[1], kernel_size=kernel_size, stride=stride, padding=self.padding, bias=True) nn.init.constant_(self.modulator_conv.weight, 0.) nn.init.constant_(self.modulator_conv.bias, 0.) self.regular_conv = nn.Conv2d(in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=self.padding, bias=bias) def forward(self, x): #h, w = x.shape[2:] #max_offset = max(h, w)/4. offset = self.offset_conv(x)#.clamp(-max_offset, max_offset) modulator = 2. * torch.sigmoid(self.modulator_conv(x)) x = deform_conv2d( input=x, offset=offset, weight=self.regular_conv.weight, bias=self.regular_conv.bias, padding=self.padding, mask=modulator, stride=self.stride, ) return x