maldv's picture
Upload folder using huggingface_hub
24e35df verified
raw
history blame contribute delete
No virus
6.45 kB
import torch
def merge_tensors_fft2(
    v0: torch.Tensor,
    v1: torch.Tensor,
    t: float,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merges two tensors by combining their Fourier coefficients.

    1-D tensors are transformed along their only axis; tensors with two or
    more dimensions are transformed over their last two axes. The spectra
    are then combined component-wise:

    - Real parts: where the real parts of the two spectra have the same sign,
      take the element-wise maximum (note: for two negatives this is the one
      closer to zero); otherwise take whichever has the larger magnitude
      (``v1`` wins ties).
    - Imaginary parts: where the *real* parts have the same sign, linearly
      interpolate ``(1 - t) * imag_v0 + t * imag_v1``; otherwise take
      whichever imaginary part has the larger magnitude (``v1`` wins ties).

    The merged spectrum is inverse-transformed and its real part returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; must have the same shape as v0.
    - t (float): Interpolation weight for the imaginary parts (0 <= t <= 1);
      t = 0 favours v0, t = 1 favours v1.
    - device: Device to compute on. Defaults to "cuda:0" when CUDA is
      available, otherwise the device of v0.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      on ``device``. float16/bfloat16 inputs are computed in float32 and the
      result is cast back to the (promoted) half-precision dtype.

    Raises:
    - ValueError: If the shapes differ or the inputs are 0-dimensional.
    """
    if v0.shape != v1.shape:
        raise ValueError(
            f"shape mismatch: {tuple(v0.shape)} vs {tuple(v1.shape)}"
        )
    if v0.dim() == 0:
        raise ValueError("cannot FFT-merge 0-dimensional tensors")
    if device is None:
        # Keep the historical "cuda:0" target, but do not crash without a GPU.
        device = "cuda:0" if torch.cuda.is_available() else v0.device

    # Put both tensors on the same device AND dtype. torch.fft rejects
    # bfloat16 (and supports float16 only on CUDA for power-of-two sizes),
    # so half precision is upcast for the transform and cast back at the end.
    half_types = (torch.float16, torch.bfloat16)
    common_dtype = torch.promote_types(v0.dtype, v1.dtype)
    compute_dtype = torch.float32 if common_dtype in half_types else common_dtype
    v0 = v0.to(device=device, dtype=compute_dtype)
    v1 = v1.to(device=device, dtype=compute_dtype)

    # 1-D: transform the single axis; otherwise: the last two axes.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    real_v0, real_v1 = fft_v0.real, fft_v1.real
    imag_v0, imag_v1 = fft_v0.imag, fft_v1.imag

    # A single mask derived from the REAL parts drives both components.
    # NOTE(review): the imaginary rule reuses this real-part mask rather than
    # comparing imaginary signs; kept as-is to preserve merge results.
    same_sign = real_v0.sign() == real_v1.sign()

    merged_real = torch.where(
        same_sign,
        torch.maximum(real_v0, real_v1),
        torch.where(real_v0.abs() > real_v1.abs(), real_v0, real_v1),
    )
    merged_imag = torch.where(
        same_sign,
        (1 - t) * imag_v0 + t * imag_v1,
        torch.where(imag_v0.abs() > imag_v1.abs(), imag_v0, imag_v1),
    )

    # Back to the spatial domain; keep only the real part.
    merged = torch.fft.ifftn(torch.complex(merged_real, merged_imag), dim=dims).real
    return merged.to(common_dtype) if common_dtype in half_types else merged
def merge_tensors_fft_shell(
    v0: torch.Tensor,
    v1: torch.Tensor,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merges two tensors by keeping the dominant Fourier coefficients.

    1-D tensors are transformed along their only axis; tensors with two or
    more dimensions are transformed over their last two axes. No
    interpolation takes place; each spectral component is selected:

    - Real parts: where the real parts of the two spectra have the same sign,
      take the element-wise maximum (for two negatives this is the one closer
      to zero); otherwise take whichever has the larger magnitude (``v1``
      wins ties).
    - Imaginary parts: where the *real* parts have the same sign, take the
      element-wise maximum of the imaginary parts; otherwise take whichever
      imaginary part has the larger magnitude (``v1`` wins ties).

    The merged spectrum is inverse-transformed and its real part returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; must have the same shape as v0.
    - device: Device to compute on. Defaults to "cuda:0" when CUDA is
      available, otherwise the device of v0.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      on ``device``. float16/bfloat16 inputs are computed in float32 and the
      result is cast back to the (promoted) half-precision dtype.

    Raises:
    - ValueError: If the shapes differ or the inputs are 0-dimensional.
    """
    if v0.shape != v1.shape:
        raise ValueError(
            f"shape mismatch: {tuple(v0.shape)} vs {tuple(v1.shape)}"
        )
    if v0.dim() == 0:
        raise ValueError("cannot FFT-merge 0-dimensional tensors")
    if device is None:
        # Keep the historical "cuda:0" target, but do not crash without a GPU.
        device = "cuda:0" if torch.cuda.is_available() else v0.device

    # Put both tensors on the same device AND dtype. torch.fft rejects
    # bfloat16 (and supports float16 only on CUDA for power-of-two sizes),
    # so half precision is upcast for the transform and cast back at the end.
    half_types = (torch.float16, torch.bfloat16)
    common_dtype = torch.promote_types(v0.dtype, v1.dtype)
    compute_dtype = torch.float32 if common_dtype in half_types else common_dtype
    v0 = v0.to(device=device, dtype=compute_dtype)
    v1 = v1.to(device=device, dtype=compute_dtype)

    # 1-D: transform the single axis; otherwise: the last two axes.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    real_v0, real_v1 = fft_v0.real, fft_v1.real
    imag_v0, imag_v1 = fft_v0.imag, fft_v1.imag

    # A single mask derived from the REAL parts drives both components.
    # NOTE(review): the imaginary rule reuses this real-part mask rather than
    # comparing imaginary signs; kept as-is to preserve merge results.
    same_sign = real_v0.sign() == real_v1.sign()

    merged_real = torch.where(
        same_sign,
        torch.maximum(real_v0, real_v1),
        torch.where(real_v0.abs() > real_v1.abs(), real_v0, real_v1),
    )
    merged_imag = torch.where(
        same_sign,
        torch.maximum(imag_v0, imag_v1),
        torch.where(imag_v0.abs() > imag_v1.abs(), imag_v0, imag_v1),
    )

    # Back to the spatial domain; keep only the real part.
    merged = torch.fft.ifftn(torch.complex(merged_real, merged_imag), dim=dims).real
    return merged.to(common_dtype) if common_dtype in half_types else merged