import math
import warnings

import torch
import torch.nn as nn

from utils.seg_opr.conv_2_5d import Conv2_5D_depth, Conv2_5D_disp


def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                  **kwargs):
    """Initialize the conv and norm layers of a single module tree."""
    for name, m in feature.named_modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            conv_init(m.weight, **kwargs)
        elif isinstance(m, (Conv2_5D_depth, Conv2_5D_disp)):
            # The 2.5D conv layers hold three separate kernels, one per
            # depth/disparity plane, so each is initialized individually.
            conv_init(m.weight_0, **kwargs)
            conv_init(m.weight_1, **kwargs)
            conv_init(m.weight_2, **kwargs)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                **kwargs):
    """Apply ``__init_weight`` to a single module or a list of modules."""
    if isinstance(module_list, list):
        for feature in module_list:
            __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
                          **kwargs)
    else:
        __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
                      **kwargs)
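
# A minimal usage sketch; the model and initializer settings below are
# illustrative assumptions, not fixed by this module:
#
#   model = MyNetwork()
#   init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d,
#               bn_eps=1e-5, bn_momentum=0.1,
#               mode='fan_in', nonlinearity='relu')
#
# Extra keyword arguments (here mode/nonlinearity) are forwarded to the
# conv initializer, while bn_eps/bn_momentum configure the norm layers.
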
def group_weight(weight_group, module, norm_layer, lr):
    """Split a module's parameters into weight-decay and no-decay groups.

    Conv, linear, and embedding weights receive weight decay; biases and
    all normalization parameters do not.
    """
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                            nn.ConvTranspose2d, nn.ConvTranspose3d)):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (Conv2_5D_depth, Conv2_5D_disp)):
            group_decay.append(m.weight_0)
            group_decay.append(m.weight_1)
            group_decay.append(m.weight_2)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, (norm_layer, nn.BatchNorm1d, nn.BatchNorm2d,
                            nn.BatchNorm3d, nn.GroupNorm)):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.Embedding):
            # Append the weight tensor rather than the module itself:
            # optimizers expect parameters, and the assert below counts them.
            group_decay.append(m.weight)

    # Every parameter must land in exactly one group. Note that
    # module.modules() yields only nn.Module instances, so bare
    # nn.Parameter attributes are never visited by this loop.
    assert len(list(module.parameters())) == len(group_decay) + len(
        group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group
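
# A minimal usage sketch; the optimizer and hyperparameters are
# illustrative assumptions:
#
#   params = group_weight([], model, nn.BatchNorm2d, lr=1e-2)
#   optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=5e-4)
#
# The decay group omits weight_decay, so it inherits the optimizer-level
# value; the no-decay group pins weight_decay to 0.
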
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Draw from a truncated normal without rejection sampling: sample the
    # CDF range uniformly, then map back through the inverse CDF.
    def norm_cdf(x):
        # CDF of the standard normal distribution.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the lower and upper truncation bounds.
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Fill the tensor uniformly with values in [2l - 1, 2u - 1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Apply the inverse error function to obtain truncated standard
        # normal samples.
        tensor.erfinv_()

        # Scale and shift to the requested mean and standard deviation.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to guard against numerical issues at the boundaries.
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
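
# A minimal usage sketch: transformer-style codebases commonly initialize
# learned embeddings this way (std=.02 is a convention, not something this
# module requires):
#
#   pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
#   trunc_normal_(pos_embed, std=.02)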