|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from . import _Tensor, Tensor |
|
from .reference import _dims, _enable_layers, llist, ltuple |
|
|
|
|
|
class DelayedMulTensor(_Tensor):
    """Product of two operands whose multiplication is deferred.

    The elementwise product is only computed when ``_batchtensor`` is read
    (``_tensor`` and ``ndim`` are built on top of it).  ``sum`` never
    computes the product: it contracts the two operands directly with a
    single ``torch.einsum`` call.
    """

    def __init__(self, lhs, rhs):
        """Store both operands and reset every lazily filled cache.

        Args:
            lhs: Left operand; must expose ``_levels``, ``_tensor``,
                ``_batchtensor`` and ``_has_device``.
            rhs: Right operand exposing the same attributes.
        """
        self._lhs, self._rhs = lhs, rhs
        # Not read anywhere in this class; presumably consumed by the
        # ``_Tensor`` base class -- TODO confirm.
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        """Levels of ``lhs`` followed by the levels only ``rhs`` has.

        The merged result is computed once and cached as an ``ltuple``.
        """
        if self._levels_data is None:
            # Membership is deliberately tested on an ``llist`` (from
            # ``.reference``) rather than a set/dict; presumably it
            # customises how levels are compared -- keep using it.
            levels = llist(self._lhs._levels)
            for level in self._rhs._levels:
                if level not in levels:
                    levels.append(level)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        """Product of the operands' ``_batchtensor`` values (cached).

        Reading this property is the fallback path that actually performs
        the multiplication.
        """
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                # NOTE(review): this writes to stdout every time the product
                # is computed for the first time.  It looks like a diagnostic
                # for the non-fused path; kept as is because it is runtime
                # output -- consider routing it through ``logging``.
                print("bt multiply fallback")
                self._batchtensor_data = (
                    self._lhs._batchtensor * self._rhs._batchtensor
                )
        return self._batchtensor_data

    @property
    def _tensor(self):
        """``_tensor`` of the ``Tensor`` rebuilt from ``_batchtensor`` (cached)."""
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        """``ndim`` of the computed product (forces the multiplication)."""
        return self._batchtensor.ndim

    @property
    def dims(self):
        """The base class ``dims`` wrapped in an ``ltuple``."""
        return ltuple(super().dims)

    def sum(self, dim):
        """Sum the product over ``dim`` without computing the product first.

        Every level of the combined ``_levels`` is assigned one einsum
        subscript letter; the operands are then contracted with
        ``torch.einsum`` so that the levels named by ``dim`` are summed away.

        Args:
            dim: Dimension(s) to reduce, normalised through
                ``_dims(dim, 0, False, False)`` (presumably a single dim or
                an iterable of dims -- see ``_dims``).

        Returns:
            The result of ``Tensor.from_positional`` applied to the einsum
            output and the levels that were not reduced.

        Raises:
            RuntimeError: If more than 52 distinct levels are involved,
                since einsum subscripts are limited to ``[a-zA-Z]``.
        """
        dims = _dims(dim, 0, False, False)
        all_levels = self._levels

        def to_char(level):
            # Lower-case letters first so equations for <= 26 levels are
            # unchanged; upper-case letters extend the range to 52 levels.
            index = all_levels.index(level)
            if index < 26:
                return chr(ord("a") + index)
            if index < 52:
                return chr(ord("A") + index - 26)
            raise RuntimeError(
                "DelayedMulTensor.sum supports at most 52 distinct levels "
                f"(einsum subscripts are limited to [a-zA-Z]), got {len(all_levels)}"
            )

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        # Levels that survive the reduction, in combined-level order.
        new_levels = [level for level in all_levels if level not in dims]
        fmt = "".join(
            [
                *(to_char(level) for level in levelslhs),
                ",",
                *(to_char(level) for level in levelsrhs),
                "->",
                *(to_char(level) for level in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        # NOTE(review): the third argument is the literal ``True`` rather
        # than ``self._has_device``; presumably the has-device flag --
        # confirm this is intended.
        return Tensor.from_positional(result_data, new_levels, True)
|
|