Spaces:
Runtime error
Runtime error
Update Time_TravelRephotography/torch_utils/ops/bias_act.py
Browse files
Time_TravelRephotography/torch_utils/ops/bias_act.py
CHANGED
@@ -88,6 +88,15 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
|
|
88 |
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
89 |
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
#----------------------------------------------------------------------------
|
92 |
|
93 |
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
@@ -121,6 +130,7 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N
|
|
121 |
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
122 |
return x
|
123 |
|
|
|
124 |
#----------------------------------------------------------------------------
|
125 |
|
126 |
_bias_act_cuda_cache = dict()
|
|
|
88 |
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
89 |
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
90 |
|
91 |
+
class bias_act_relu(nn.Module):
|
92 |
+
def __init__(self, dim, negative_slope=0.2, scale=2 ** 0.5):
|
93 |
+
super().__init__()
|
94 |
+
self.bias = nn.Parameter(torch.zeros(dim))
|
95 |
+
self.negative_slope = negative_slope
|
96 |
+
self.scale = scale
|
97 |
+
|
98 |
+
def forward(self, input):
|
99 |
+
return bias_act(input, b=self.bias, self.negative_slope, self.scale)
|
100 |
#----------------------------------------------------------------------------
|
101 |
|
102 |
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
|
|
130 |
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
131 |
return x
|
132 |
|
133 |
+
|
134 |
#----------------------------------------------------------------------------
|
135 |
|
136 |
_bias_act_cuda_cache = dict()
|