Spaces:
Runtime error
Runtime error
Update Time_TravelRephotography/op/upfirdn2d.py
Browse files
Time_TravelRephotography/op/upfirdn2d.py
CHANGED
@@ -141,12 +141,43 @@ class UpFirDn2d(Function):
|
|
141 |
return grad_input, None, None, None, None
|
142 |
|
143 |
|
144 |
-
def upfirdn2d(
|
145 |
-
|
146 |
-
|
147 |
-
)
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
|
152 |
def upfirdn2d_native(
|
|
|
141 |
return grad_input, None, None, None, None
|
142 |
|
143 |
|
144 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
145 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
146 |
+
Performs the following sequence of operations for each channel:
|
147 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
148 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
149 |
+
Negative padding corresponds to cropping the image.
|
150 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
151 |
+
so that the footprint of all output pixels lies within the input image.
|
152 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
153 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
154 |
+
The fused op is considerably more efficient than performing the same calculation
|
155 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
156 |
+
Args:
|
157 |
+
x: Float32/float64/float16 input tensor of the shape
|
158 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
159 |
+
f: Float32 FIR filter of the shape
|
160 |
+
`[filter_height, filter_width]` (non-separable),
|
161 |
+
`[filter_taps]` (separable), or
|
162 |
+
`None` (identity).
|
163 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
164 |
+
`[x, y]` (default: 1).
|
165 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
166 |
+
`[x, y]` (default: 1).
|
167 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
168 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
169 |
+
(default: 0).
|
170 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
171 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
172 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
173 |
+
Returns:
|
174 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
175 |
+
"""
|
176 |
+
assert isinstance(x, torch.Tensor)
|
177 |
+
assert impl in ['ref', 'cuda']
|
178 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
179 |
+
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
180 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
181 |
|
182 |
|
183 |
def upfirdn2d_native(
|