from functools import partial | |
import torch | |
def multi_apply(func, *args, **kwargs):
    """Call ``func`` on corresponding items of ``args`` and regroup the outputs.

    ``func`` is invoked once per position, as ``func(args[0][i], args[1][i], ...,
    **kwargs)``. Iteration stops at the shortest iterable in ``args``. Each call
    is expected to return an iterable of outputs, for example a tuple. The
    outputs are transposed, so the k-th returned list holds the k-th output of
    every call, in call order.

    Args:
        func: Callable applied to each group of corresponding items.
        *args: Iterables whose elements are passed positionally to ``func``.
        **kwargs: Extra keyword arguments bound to every call of ``func``.

    Returns:
        A tuple of lists, one list per output position of ``func``. If there
        are no calls, the result is ``()``, so unpacking it into several names
        will fail in that case.
    """
    fn = func
    if kwargs:
        # Bind the shared keyword arguments once, rather than on every call.
        fn = partial(func, **kwargs)
    per_call = list(map(fn, *args))
    # Transpose: zip groups the k-th output of every call together.
    # It truncates to the shortest per-call result.
    return tuple(list(column) for column in zip(*per_call))
def torch_to_numpy(x):
    """Convert a PyTorch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU before
    conversion. This lets tensors that require grad, or that live on an
    accelerator, be converted.

    NOTE: for a tensor already on the CPU, PyTorch's ``Tensor.numpy()`` returns
    an array that may share memory with the tensor. Copy the result if it must
    be independent.

    Args:
        x: The tensor to convert.

    Returns:
        A ``numpy.ndarray`` holding the tensor's values.

    Raises:
        TypeError: If ``x`` is not a ``torch.Tensor``.
    """
    # Use an explicit check rather than ``assert``, because asserts are removed
    # under ``python -O``.
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"expected a torch.Tensor, got {type(x).__name__}")
    return x.detach().cpu().numpy()