|
|
|
|
|
|
|
|
|
|
|
import dis |
|
import inspect |
|
|
|
from dataclasses import dataclass |
|
from typing import Union |
|
|
|
from . import DimList |
|
|
|
# Stack of LevelInfo entries, one per vmap level entered when a Dim's size is
# bound (appended by the Dim.size setter, unwound from the top in Dim.__del__).
_vmap_levels = []
|
|
|
|
|
@dataclass
class LevelInfo:
    """Bookkeeping entry for one vmap level entered by a bound ``Dim``."""

    # Level value returned by _vmap_increment_nesting when the Dim was bound.
    level: int
    # Set to False by Dim.__del__ once the owning Dim is collected; dead
    # entries at the top of _vmap_levels may then be popped.
    alive: bool = True
|
|
|
|
|
class Dim:
    """A named dimension that can be bound to a concrete size once.

    Binding a size (through the constructor or the ``size`` setter) calls
    ``_vmap_increment_nesting(size, "same")`` and pushes a ``LevelInfo`` for
    the returned level onto ``_vmap_levels``. When the ``Dim`` is collected
    its entry is marked dead, and dead entries at the top of the stack are
    unwound.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        """Create a dimension called ``name``, binding ``size`` if given."""
        self.name = name
        self._size = None
        # vmap level entered for this dim; None until a size is bound.
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
        """Mark this dim's level dead and unwind dead levels off the stack top."""
        # getattr: __del__ also runs for objects whose __init__ never completed.
        if getattr(self, "_vmap_level", None) is None:
            return
        _vmap_levels[self._vmap_stack].alive = False
        # Pop entries from the top only while the top entry is dead and equals
        # current_level(); the loop stops at the first live entry, so live
        # entries (and everything below them) keep their stack indices.
        # NOTE(review): current_level / _vmap_decrement_nesting are not
        # imported in this file — presumably torch._C._functorch; confirm.
        while (
            _vmap_levels
            and not _vmap_levels[-1].alive
            and current_level() == _vmap_levels[-1].level
        ):
            _vmap_decrement_nesting()
            _vmap_levels.pop()

    @property
    def size(self):
        """The bound size; asserts that the dim has been bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind this dim to ``size``.

        The first binding enters a new vmap level and records it on
        ``_vmap_levels``. Re-binding to the same size is a no-op.

        Raises:
            DimensionBindError: if already bound to a different size.
        """
        from . import DimensionBindError

        if self._size is None:
            # Enter the level before recording the size, so a failure here
            # leaves the dim unbound instead of bound-without-a-level.
            level = _vmap_increment_nesting(size, "same")
            self._size = size
            self._vmap_level = level
            # Index of this dim's entry in _vmap_levels, used by __del__.
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        """True once a size has been bound."""
        return self._size is not None

    def __repr__(self):
        return self.name
|
|
|
|
|
def extract_name(inst):
    """Return the variable name written by a plain store instruction.

    ``inst`` is a ``dis.Instruction``; only STORE_FAST and STORE_NAME are
    accepted (anything else trips the assertion).
    """
    accepted_ops = ("STORE_FAST", "STORE_NAME")
    assert inst.opname in accepted_ops
    return inst.argval
|
|
|
|
|
# Per-call-site cache for dims(): maps (caller code object, caller f_lasti) to
# a zero-argument factory that builds fresh Dim/DimList objects for that site,
# so the caller's bytecode is disassembled only once per call site.
_cache = {}
|
|
|
|
|
def dims(lists=0):
    """Create dimension objects named after the variables they are assigned to.

    The caller's bytecode is inspected at the call site, so ``x = dims()``
    returns one object named ``"x"`` and ``a, b = dims()`` returns a tuple
    whose elements are named ``"a"`` and ``"b"``.

    Args:
        lists: how many of the trailing targets become ``DimList`` instead of
            ``Dim``. For a single target, any non-zero value gives a ``DimList``.

    Returns:
        A single ``Dim``/``DimList`` for a plain assignment, otherwise a tuple
        of them in target order. Fresh objects are built on every call; only
        the parsed call site is cached (in ``_cache``).

    Raises:
        AssertionError: if the result is not stored directly into plain names
            (STORE_FAST / STORE_NAME, optionally after UNPACK_SEQUENCE).
    """
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    # Release our own frame immediately; keeping it in a local forms a
    # reference cycle (see the inspect module's interpreter-stack notes).
    del frame
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    key = (code, lasti)
    if key not in _cache:
        instructions = list(dis.get_instructions(code))
        # Locate the call by byte offset rather than via ``lasti // 2``: since
        # CPython 3.11, inline CACHE entries occupy bytecode but are not
        # yielded by get_instructions(), so list index != offset // 2.
        call_index = next(
            (i for i, inst in enumerate(instructions) if inst.offset == lasti),
            None,
        )
        assert call_index is not None
        first = call_index + 1
        unpack = instructions[first]

        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
            # `x = dims()`: the result goes straight into one name.
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            # `a, b, ... = dims()`: UNPACK_SEQUENCE then one store per target.
            # NOTE(review): newer CPython may fuse consecutive stores into
            # superinstructions (e.g. STORE_FAST_STORE_FAST), which
            # extract_name rejects — confirm for supported interpreter versions.
            assert unpack.opname == "UNPACK_SEQUENCE"
            ndims = unpack.argval
            names = tuple(
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
            )
            # Targets at index >= first_list become DimList.
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(
                Dim(n) if i < first_list else DimList(name=n)
                for i, n in enumerate(names)
            )
    return _cache[key]()
|
|
|
|
|
def _dim_set(positional, arg): |
|
def convert(a): |
|
if isinstance(a, Dim): |
|
return a |
|
else: |
|
assert isinstance(a, int) |
|
return positional[a] |
|
|
|
if arg is None: |
|
return positional |
|
elif not isinstance(arg, (Dim, int)): |
|
return tuple(convert(a) for a in arg) |
|
else: |
|
return (convert(arg),) |
|
|