feng2022 commited on
Commit
6daddb5
1 Parent(s): 8ea91d1

Update Time_TravelRephotography/torch_utils/ops/bias_act.py

Browse files
Time_TravelRephotography/torch_utils/ops/bias_act.py CHANGED
@@ -88,6 +88,15 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
88
  return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89
  return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90
 
 
 
 
 
 
 
 
 
 
91
  #----------------------------------------------------------------------------
92
 
93
  def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
@@ -121,6 +130,7 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N
121
  x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
122
  return x
123
 
 
124
  #----------------------------------------------------------------------------
125
 
126
  _bias_act_cuda_cache = dict()
 
88
  return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89
  return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90
 
91
+ class bias_act_relu(nn.Module):
92
+ def __init__(self, dim, negative_slope=0.2, scale=2 ** 0.5):
93
+ super().__init__()
94
+ self.bias = nn.Parameter(torch.zeros(dim))
95
+ self.negative_slope = negative_slope
96
+ self.scale = scale
97
+
98
+ def forward(self, input):
99
+ return bias_act(input, b=self.bias, self.negative_slope, self.scale)
100
  #----------------------------------------------------------------------------
101
 
102
  def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
 
130
  x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
131
  return x
132
 
133
+
134
  #----------------------------------------------------------------------------
135
 
136
  _bias_act_cuda_cache = dict()