feng2022 commited on
Commit
3f34e8d
1 Parent(s): fe796c9

Update Time_TravelRephotography/op/upfirdn2d.py

Browse files
Time_TravelRephotography/op/upfirdn2d.py CHANGED
@@ -141,12 +141,43 @@ class UpFirDn2d(Function):
141
  return grad_input, None, None, None, None
142
 
143
 
144
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
145
- out = UpFirDn2d.apply(
146
- input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
147
- )
148
-
149
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
 
152
  def upfirdn2d_native(
 
141
  return grad_input, None, None, None, None
142
 
143
 
144
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
145
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
146
+ Performs the following sequence of operations for each channel:
147
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
148
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
149
+ Negative padding corresponds to cropping the image.
150
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
151
+ so that the footprint of all output pixels lies within the input image.
152
+ 4. Downsample the image by keeping every Nth pixel (`down`).
153
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
154
+ The fused op is considerably more efficient than performing the same calculation
155
+ using standard PyTorch ops. It supports gradients of arbitrary order.
156
+ Args:
157
+ x: Float32/float64/float16 input tensor of the shape
158
+ `[batch_size, num_channels, in_height, in_width]`.
159
+ f: Float32 FIR filter of the shape
160
+ `[filter_height, filter_width]` (non-separable),
161
+ `[filter_taps]` (separable), or
162
+ `None` (identity).
163
+ up: Integer upsampling factor. Can be a single int or a list/tuple
164
+ `[x, y]` (default: 1).
165
+ down: Integer downsampling factor. Can be a single int or a list/tuple
166
+ `[x, y]` (default: 1).
167
+ padding: Padding with respect to the upsampled image. Can be a single number
168
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
169
+ (default: 0).
170
+ flip_filter: False = convolution, True = correlation (default: False).
171
+ gain: Overall scaling factor for signal magnitude (default: 1).
172
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
173
+ Returns:
174
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
175
+ """
176
+ assert isinstance(x, torch.Tensor)
177
+ assert impl in ['ref', 'cuda']
178
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
179
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
180
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
181
 
182
 
183
  def upfirdn2d_native(