feng2022 committed · Commit 57d3312 · 1 Parent(s): d3ed5fc

Update Time_TravelRephotography/op/upfirdn2d.py

Time_TravelRephotography/op/upfirdn2d.py CHANGED
@@ -6,13 +6,13 @@ from torch.utils.cpp_extension import load
 
 
 module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    'upfirdn2d',
-    sources=[
-        os.path.join(module_path, 'upfirdn2d.cpp'),
-        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-    ],
-)
+#upfirdn2d_op = load(
+#    'upfirdn2d',
+#    sources=[
+#        os.path.join(module_path, 'upfirdn2d.cpp'),
+#        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+#    ],
+#)
 
 
 class UpFirDn2dBackward(Function):
@@ -140,6 +140,47 @@ class UpFirDn2d(Function):
 
         return grad_input, None, None, None, None
 
+@misc.profiled_function
+def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
+    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
+    """
+    # Validate arguments.
+    assert isinstance(x, torch.Tensor) and x.ndim == 4
+    if f is None:
+        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+    assert f.dtype == torch.float32 and not f.requires_grad
+    batch_size, num_channels, in_height, in_width = x.shape
+    upx, upy = _parse_scaling(up)
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+    # Upsample by inserting zeros.
+    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
+    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
+    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
+
+    # Pad or crop.
+    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
+    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
+
+    # Setup filter.
+    f = f * (gain ** (f.ndim / 2))
+    f = f.to(x.dtype)
+    if not flip_filter:
+        f = f.flip(list(range(f.ndim)))
+
+    # Convolve with the filter.
+    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
+    if f.ndim == 4:
+        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
+    else:
+        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
+        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
+
+    # Downsample by throwing away pixels.
+    x = x[:, :, ::downy, ::downx]
+    return x
 
 def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
     r"""Pad, upsample, filter, and downsample a batch of 2D images.