import torch


def _resolve_device(v0: torch.Tensor, device) -> torch.device:
    """Return the device the merge should run on.

    An explicit ``device`` always wins. Otherwise ``cuda:0`` is used when CUDA
    is available (the historical hard-coded target), and ``v0``'s own device
    when it is not, so CPU-only hosts no longer fail on ``.to("cuda:0")``.
    """
    if device is not None:
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return v0.device


def _merge_spectra(
    fft_v0: torch.Tensor, fft_v1: torch.Tensor, imag_when_signs_agree
) -> torch.Tensor:
    """Combine two complex spectra coefficient by coefficient.

    The selection mask is derived from the REAL parts only: a coefficient
    "agrees" when ``sign(real_v0) == sign(real_v1)``. Because ``sign(0) == 0``,
    a zero against a nonzero value counts as a disagreement, and NaN never
    agrees.

    * Real part: if the signs agree, take the element-wise maximum. Otherwise
      take whichever value has the larger magnitude (ties go to ``v1``).
    * Imaginary part: if the real signs agree, use
      ``imag_when_signs_agree(imag_v0, imag_v1)``. Otherwise take whichever
      imaginary value has the larger magnitude (ties go to ``v1``).

    Returns:
        A new complex tensor with the same shape and dtype as ``fft_v0``.
    """
    real_v0, real_v1 = fft_v0.real, fft_v1.real
    imag_v0, imag_v1 = fft_v0.imag, fft_v1.imag

    signs_agree = real_v0.sign() == real_v1.sign()

    real = torch.where(
        signs_agree,
        torch.maximum(real_v0, real_v1),
        torch.where(real_v0.abs() > real_v1.abs(), real_v0, real_v1),
    )
    # NOTE(review): the imaginary part reuses the mask built from the REAL
    # parts' signs; this mirrors the original implementation — confirm intent.
    imag = torch.where(
        signs_agree,
        imag_when_signs_agree(imag_v0, imag_v1),
        torch.where(imag_v0.abs() > imag_v1.abs(), imag_v0, imag_v1),
    )
    return torch.complex(real, imag)


def _merge_via_fft(v0, v1, imag_when_signs_agree, device) -> torch.Tensor:
    """Align the inputs, FFT them, merge the spectra, and inverse-FFT the result."""
    target = _resolve_device(v0, device)
    # Bring both inputs to one dtype so the two spectra share a complex dtype.
    # Without this, a v1 wider than v0 (e.g. float64 vs float32) made the
    # masked writes into a v0-typed buffer fail.
    common_dtype = torch.promote_types(v0.dtype, v1.dtype)
    v0 = v0.to(device=target, dtype=common_dtype)
    v1 = v1.to(device=target, dtype=common_dtype)

    # 1-D inputs are transformed along their only axis. Anything with two or
    # more dims is transformed over the last two axes, so leading dims are
    # handled independently.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    merged = _merge_spectra(fft_v0, fft_v1, imag_when_signs_agree)
    # Only the real part of the inverse transform is returned; any imaginary
    # residue produced by the merge is discarded.
    return torch.fft.ifftn(merged, dim=dims).real


def merge_tensors_fft2(
    v0: torch.Tensor,
    v1: torch.Tensor,
    t: float,
    *,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """Merge two tensors by combining their Fourier coefficients.

    1-D inputs are transformed with a 1-D FFT. Inputs with two or more
    dimensions are transformed over their last two axes. Per coefficient:

    * Where the real parts have the same sign, the real part is the
      element-wise max and the imaginary part is ``(1 - t) * imag_v0 + t * imag_v1``.
    * Elsewhere, the real and imaginary parts are each taken from whichever
      input has the larger magnitude (ties go to ``v1``).

    NOTE(review): ``t`` only affects imaginary parts; the real part is never
    interpolated, although an earlier inline comment said it was.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; assumed to have the same
      shape as ``v0``.
    - t (float): Interpolation weight for the imaginary parts (intended
      0 <= t <= 1; not validated).
    - device: Where to compute. The default is ``cuda:0`` when CUDA is
      available, otherwise ``v0``'s device.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      on the compute device, in the promoted dtype of the inputs.
    """
    return _merge_via_fft(
        v0,
        v1,
        lambda imag_v0, imag_v1: (1 - t) * imag_v0 + t * imag_v1,
        device,
    )


def merge_tensors_fft_shell(
    v0: torch.Tensor,
    v1: torch.Tensor,
    *,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """Merge two tensors by keeping the "larger" Fourier coefficients.

    Uses the same transform and selection as ``merge_tensors_fft2``, except
    that where the real parts have the same sign, the imaginary part is the
    element-wise max instead of an interpolation. There is no ``t``
    parameter.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; assumed to have the same
      shape as ``v0``.
    - device: Where to compute. The default is ``cuda:0`` when CUDA is
      available, otherwise ``v0``'s device.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      on the compute device, in the promoted dtype of the inputs.
    """
    return _merge_via_fft(v0, v1, torch.maximum, device)