"""Beamformer module.""" from typing import Sequence, Tuple, Union import torch from packaging.version import parse as V from torch_complex import functional as FC from torch_complex.tensor import ComplexTensor EPS = torch.finfo(torch.double).eps is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0") is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") def new_complex_like( ref: Union[torch.Tensor, ComplexTensor], real_imag: Tuple[torch.Tensor, torch.Tensor], ): if isinstance(ref, ComplexTensor): return ComplexTensor(*real_imag) elif is_torch_complex_tensor(ref): return torch.complex(*real_imag) else: raise ValueError( "Please update your PyTorch version to 1.9+ for complex support." ) def is_torch_complex_tensor(c): return ( not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c) ) def is_complex(c): return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c) def to_double(c): if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c): return c.to(dtype=torch.complex128) else: return c.double() def to_float(c): if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c): return c.to(dtype=torch.complex64) else: return c.float() def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs): if not isinstance(seq, (list, tuple)): raise TypeError( "cat(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor" ) if isinstance(seq[0], ComplexTensor): return FC.cat(seq, *args, **kwargs) else: return torch.cat(seq, *args, **kwargs) def complex_norm( c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False ) -> torch.Tensor: if not is_complex(c): raise TypeError("Input is not a complex tensor.") if is_torch_complex_tensor(c): return torch.norm(c, dim=dim, keepdim=keepdim) else: if dim is None: return torch.sqrt((c.real**2 + c.imag**2).sum() + EPS) else: return torch.sqrt( (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS ) def einsum(equation, *operands): # NOTE: Do not mix ComplexTensor and torch.complex in the input! # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support # mixed input with complex and real tensors. if len(operands) == 1: if isinstance(operands[0], (tuple, list)): operands = operands[0] complex_module = FC if isinstance(operands[0], ComplexTensor) else torch return complex_module.einsum(equation, *operands) elif len(operands) != 2: op0 = operands[0] same_type = all(op.dtype == op0.dtype for op in operands[1:]) if same_type: _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum return _einsum(equation, *operands) else: raise ValueError("0 or More than 2 operands are not supported.") a, b = operands if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor): return FC.einsum(equation, a, b) elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)): if not torch.is_complex(a): o_real = torch.einsum(equation, a, b.real) o_imag = torch.einsum(equation, a, b.imag) return torch.complex(o_real, o_imag) elif not torch.is_complex(b): o_real = torch.einsum(equation, a.real, b) o_imag = torch.einsum(equation, a.imag, b) return torch.complex(o_real, o_imag) else: return torch.einsum(equation, a, b) else: return torch.einsum(equation, a, b) def inverse( c: Union[torch.Tensor, ComplexTensor] ) -> Union[torch.Tensor, ComplexTensor]: if isinstance(c, ComplexTensor): return c.inverse2() else: return c.inverse() def matmul( a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor] ) -> Union[torch.Tensor, ComplexTensor]: # NOTE: Do not mix ComplexTensor and torch.complex in the input! # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support # multiplication between complex and real tensors. if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor): return FC.matmul(a, b) elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)): if not torch.is_complex(a): o_real = torch.matmul(a, b.real) o_imag = torch.matmul(a, b.imag) return torch.complex(o_real, o_imag) elif not torch.is_complex(b): o_real = torch.matmul(a.real, b) o_imag = torch.matmul(a.imag, b) return torch.complex(o_real, o_imag) else: return torch.matmul(a, b) else: return torch.matmul(a, b) def trace(a: Union[torch.Tensor, ComplexTensor]): # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not # support bacth processing. Use FC.trace() as fallback. return FC.trace(a) def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0): if isinstance(a, ComplexTensor): return FC.reverse(a, dim=dim) else: return torch.flip(a, dims=(dim,)) def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]): """Solve the linear equation ax = b.""" # NOTE: Do not mix ComplexTensor and torch.complex in the input! # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support # mixed input with complex and real tensors. if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor): if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): return FC.solve(b, a, return_LU=False) else: return matmul(inverse(a), b) elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)): if torch.is_complex(a) and torch.is_complex(b): return torch.linalg.solve(a, b) else: return matmul(inverse(a), b) else: if is_torch_1_8_plus: return torch.linalg.solve(a, b) else: return torch.solve(b, a)[0] def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs): if not isinstance(seq, (list, tuple)): raise TypeError( "stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor" ) if isinstance(seq[0], ComplexTensor): return FC.stack(seq, *args, **kwargs) else: return torch.stack(seq, *args, **kwargs)