Spaces:
Sleeping
Sleeping
File size: 6,750 Bytes
78e32cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
"""Beamformer module."""
from typing import Sequence, Tuple, Union
import torch
from packaging.version import parse as V
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
# Machine epsilon of double precision. complex_norm() adds it under the
# square root when the input is a ComplexTensor.
EPS = torch.finfo(torch.double).eps
# Feature flags derived from the installed PyTorch version, evaluated once
# at import time. solve() uses the 1.8 flag to choose between
# torch.linalg.solve and the older torch.solve; the 1.9 flag gates every
# code path that handles native torch complex tensors.
is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0")
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    """Build a complex tensor of the same flavour as ``ref``.

    Args:
        ref: reference tensor; only its type is inspected, its values are
            never read.
        real_imag: the parts of the new tensor, unpacked positionally into
            the constructor that matches ``ref``.

    Returns:
        A ``ComplexTensor`` when ``ref`` is one, otherwise the result of
        ``torch.complex`` when ``ref`` is a native complex tensor.

    Raises:
        ValueError: if ``ref`` is neither a ``ComplexTensor`` nor recognised
            as a native complex tensor by ``is_torch_complex_tensor``.
    """
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    if not is_torch_complex_tensor(ref):
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )
    return torch.complex(*real_imag)
def is_torch_complex_tensor(c):
    """Tell whether ``c`` is a native (non-``ComplexTensor``) complex tensor.

    Always false for ``ComplexTensor`` instances and whenever the
    ``is_torch_1_9_plus`` flag is off; otherwise defers to
    ``torch.is_complex``.
    """
    # ComplexTensor is checked first so torch.is_complex() is never called
    # on an object that is not a torch.Tensor.
    if isinstance(c, ComplexTensor):
        return False
    if not is_torch_1_9_plus:
        return False
    return torch.is_complex(c)
def is_complex(c):
    """Tell whether ``c`` is complex in either supported representation.

    True for ``ComplexTensor`` instances; for anything else the answer comes
    from ``is_torch_complex_tensor``.
    """
    if isinstance(c, ComplexTensor):
        return True
    return is_torch_complex_tensor(c)
def to_double(c):
    """Cast ``c`` to double precision.

    Native complex tensors (only detected when ``is_torch_1_9_plus`` is set)
    are converted to ``torch.complex128``; every other input is converted
    through its own ``.double()`` method.
    """
    # The isinstance test comes first so torch.is_complex() is never called
    # on a ComplexTensor.
    if (
        isinstance(c, ComplexTensor)
        or not is_torch_1_9_plus
        or not torch.is_complex(c)
    ):
        return c.double()
    return c.to(dtype=torch.complex128)
def to_float(c):
    """Cast ``c`` to single precision.

    Native complex tensors (only detected when ``is_torch_1_9_plus`` is set)
    are converted to ``torch.complex64``; every other input is converted
    through its own ``.float()`` method.
    """
    # The isinstance test comes first so torch.is_complex() is never called
    # on a ComplexTensor.
    if (
        isinstance(c, ComplexTensor)
        or not is_torch_1_9_plus
        or not torch.is_complex(c)
    ):
        return c.float()
    return c.to(dtype=torch.complex64)
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    """Concatenate a list/tuple of tensors, dispatching on the element type.

    Args:
        seq: list or tuple of tensors. The type of the first element selects
            the backend: ``FC.cat`` for ``ComplexTensor``, ``torch.cat``
            otherwise.
        *args: forwarded unchanged to the selected backend.
        **kwargs: forwarded unchanged to the selected backend.

    Returns:
        Whatever the selected backend returns.

    Raises:
        TypeError: if ``seq`` is not a list or tuple.
    """
    if not isinstance(seq, (list, tuple)):
        # Report the type that was actually passed; for a torch.Tensor this
        # yields the same "not Tensor" text as before.
        raise TypeError(
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
            f"not {type(seq).__name__}"
        )
    # An empty sequence has no first element to inspect. Hand it to torch.cat
    # so the backend reports the problem instead of a bare IndexError.
    if seq and isinstance(seq[0], ComplexTensor):
        return FC.cat(seq, *args, **kwargs)
    return torch.cat(seq, *args, **kwargs)
def complex_norm(
    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
    """Compute the norm of a complex tensor.

    Args:
        c: complex input, either a ``ComplexTensor`` or a native complex
            tensor.
        dim: dimension to reduce. ``None`` reduces over all elements.
        keepdim: forwarded to the reduction. It is not applied in the
            ``ComplexTensor`` path when ``dim`` is ``None``.

    Returns:
        ``torch.norm(c, dim=dim, keepdim=keepdim)`` for native complex
        tensors. For ``ComplexTensor`` inputs, the square root of the summed
        squared real and imaginary parts plus ``EPS``.

    Raises:
        TypeError: if ``c`` is not complex according to ``is_complex``.
    """
    if not is_complex(c):
        raise TypeError("Input is not a complex tensor.")
    if is_torch_complex_tensor(c):
        return torch.norm(c, dim=dim, keepdim=keepdim)

    # ComplexTensor path: reduce |c|^2 by hand. EPS is added under the
    # square root in this path only.
    squared_magnitude = c.real**2 + c.imag**2
    if dim is None:
        reduced = squared_magnitude.sum()
    else:
        reduced = squared_magnitude.sum(dim=dim, keepdim=keepdim)
    return torch.sqrt(reduced + EPS)
def einsum(equation, *operands):
    """Evaluate an einsum over real and/or complex operands.

    Args:
        equation: subscript string, forwarded unchanged to the backend.
        *operands: the tensors to contract, passed either individually or as
            one list/tuple.

    Returns:
        The backend einsum result. ``FC.einsum`` is used when a
        ``ComplexTensor`` is involved, ``torch.einsum`` otherwise. For two
        native operands of which exactly one is complex, the real and
        imaginary parts are contracted separately and recombined with
        ``torch.complex``.

    Raises:
        ValueError: if no operand is given (including an empty list/tuple),
            or if more than two operands with differing dtypes are given.
    """
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
    #   mixed input with complex and real tensors.
    unsupported = "0 or More than 2 operands are not supported."

    # Reject the empty call up front; indexing operands[0] below would
    # otherwise surface as an unrelated IndexError.
    if not operands:
        raise ValueError(unsupported)

    if len(operands) == 1:
        if isinstance(operands[0], (tuple, list)):
            # Operands were supplied as a single list/tuple: unwrap it.
            operands = operands[0]
            if not operands:
                raise ValueError(unsupported)
        complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)

    if len(operands) != 2:
        # More than two operands are only handled when all share one dtype.
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if not same_type:
            raise ValueError(unsupported)
        _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
        return _einsum(equation, *operands)

    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    if is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            # real x complex: contract against each part of b separately.
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        if not torch.is_complex(b):
            # complex x real: contract each part of a separately.
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
    return torch.einsum(equation, a, b)
def inverse(
    c: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    """Invert ``c`` using the method that matches its type.

    ``ComplexTensor`` inputs go through ``inverse2()``; every other input
    goes through its own ``inverse()`` method.
    """
    if not isinstance(c, ComplexTensor):
        return c.inverse()
    return c.inverse2()
def matmul(
    a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    """Matrix-multiply ``a`` and ``b``, tolerating mixed real/complex inputs.

    Returns:
        ``FC.matmul(a, b)`` when either operand is a ``ComplexTensor``. For
        native tensors of which exactly one is complex (and
        ``is_torch_1_9_plus`` is set), the real and imaginary parts are
        multiplied separately and recombined with ``torch.complex``.
        Otherwise plain ``torch.matmul(a, b)``.
    """
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
    #   multiplication between complex and real tensors.
    if any(isinstance(operand, ComplexTensor) for operand in (a, b)):
        return FC.matmul(a, b)
    if not is_torch_1_9_plus:
        return torch.matmul(a, b)

    a_is_complex = torch.is_complex(a)
    b_is_complex = torch.is_complex(b)
    if a_is_complex == b_is_complex:
        # Both real or both complex: torch handles this directly.
        return torch.matmul(a, b)
    if b_is_complex:
        # real @ complex
        real_part = torch.matmul(a, b.real)
        imag_part = torch.matmul(a, b.imag)
    else:
        # complex @ real
        real_part = torch.matmul(a.real, b)
        imag_part = torch.matmul(a.imag, b)
    return torch.complex(real_part, imag_part)
def trace(a: Union[torch.Tensor, ComplexTensor]):
    """Return the trace of ``a`` as computed by ``FC.trace``.

    The same backend routine is used for native tensors and for
    ``ComplexTensor`` inputs.
    """
    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    #   support batch processing. Use FC.trace() as fallback.
    result = FC.trace(a)
    return result
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
    """Reverse ``a`` along dimension ``dim``.

    ``ComplexTensor`` inputs are handled by ``FC.reverse``; everything else
    is flipped with ``torch.flip`` along that single dimension.
    """
    if not isinstance(a, ComplexTensor):
        return torch.flip(a, dims=(dim,))
    return FC.reverse(a, dim=dim)
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
    """Solve the linear equation ax = b.

    Note the argument order: the right-hand side ``b`` comes first and the
    coefficient matrix ``a`` second.

    Returns:
        The solution ``x``. When ``a`` and ``b`` share a representation
        (both ``ComplexTensor``, both native complex, or both real), a
        dedicated solver is used. When exactly one of them is complex, the
        result is ``matmul(inverse(a), b)``.
    """
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
    #   mixed input with complex and real tensors.
    a_is_ct = isinstance(a, ComplexTensor)
    b_is_ct = isinstance(b, ComplexTensor)
    if a_is_ct or b_is_ct:
        if a_is_ct and b_is_ct:
            return FC.solve(b, a, return_LU=False)
        # Only one side is a ComplexTensor: use the explicit inverse.
        return matmul(inverse(a), b)

    if is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if torch.is_complex(a) and torch.is_complex(b):
            return torch.linalg.solve(a, b)
        # Only one side is native complex: use the explicit inverse.
        return matmul(inverse(a), b)

    # Neither branch above applied; choose the solver API by torch version.
    if is_torch_1_8_plus:
        return torch.linalg.solve(a, b)
    return torch.solve(b, a)[0]
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    """Stack a list/tuple of tensors, dispatching on the element type.

    Args:
        seq: list or tuple of tensors. The type of the first element selects
            the backend: ``FC.stack`` for ``ComplexTensor``, ``torch.stack``
            otherwise.
        *args: forwarded unchanged to the selected backend.
        **kwargs: forwarded unchanged to the selected backend.

    Returns:
        Whatever the selected backend returns.

    Raises:
        TypeError: if ``seq`` is not a list or tuple.
    """
    if not isinstance(seq, (list, tuple)):
        # Report the type that was actually passed; for a torch.Tensor this
        # yields the same "not Tensor" text as before.
        raise TypeError(
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
            f"not {type(seq).__name__}"
        )
    # An empty sequence has no first element to inspect. Hand it to
    # torch.stack so the backend reports the problem instead of a bare
    # IndexError.
    if seq and isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, *args, **kwargs)
    return torch.stack(seq, *args, **kwargs)