File size: 1,504 Bytes
a3d6c18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
Source url: https://github.com/MarcoForte/FBA_Matting
License: MIT License
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
class Conv2d(nn.Conv2d):
    """2-D convolution with weight standardization.

    On every forward pass each output filter of ``self.weight`` is
    standardized: its mean over (in_channels, kH, kW) is subtracted and the
    result is divided by its standard deviation (plus small epsilons). The
    stored parameter is never modified; only the weight handed to
    ``F.conv2d`` is standardized.

    The constructor takes the same arguments as ``nn.Conv2d`` up to ``bias``.
    ``padding_mode`` is not exposed because ``forward`` calls ``F.conv2d``
    directly, i.e. always zero padding.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardized weight.

        Args:
            x: input tensor, passed unchanged to ``F.conv2d``.

        Returns:
            ``F.conv2d`` of ``x`` with the standardized weight and this
            layer's bias, stride, padding, dilation and groups.
        """
        weight = self.weight
        out_channels = weight.size(0)
        # One row per output filter. ``reshape`` (not ``view``) so that a
        # non-contiguous weight, e.g. after converting the model to
        # channels_last, does not raise.
        flat = weight.reshape(out_channels, -1)
        centered = flat - flat.mean(dim=1, keepdim=True)
        # The unbiased variance of a single element is NaN (division by
        # N - 1 = 0), which would turn the whole output into NaN for
        # 1-element filters such as Conv2d(1, C, 1). Only in that case use
        # the biased estimator: the centered weight is then exactly 0, so
        # the standardized weight is 0. Other filter sizes are unaffected.
        unbiased = flat.size(1) > 1
        variance = centered.var(dim=1, unbiased=unbiased, keepdim=True)
        # 1e-12 keeps sqrt (and its gradient) finite at zero variance;
        # 1e-5 keeps the divisor away from zero.
        std = torch.sqrt(variance + 1e-12) + 1e-5
        standardized = (centered / std).reshape(weight.shape)
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
def BatchNorm2d(num_features, num_groups=32):
    """Return a ``nn.GroupNorm`` layer despite the ``BatchNorm2d`` name.

    The returned module normalizes over groups of channels, not over the
    batch.

    Args:
        num_features: number of input channels; must be divisible by
            ``num_groups`` (``nn.GroupNorm`` raises ``ValueError`` otherwise).
        num_groups: number of channel groups. Defaults to 32.

    Returns:
        ``nn.GroupNorm`` with ``num_groups`` groups over ``num_features``
        channels.
    """
    return nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
|