Spaces:
Build error
Build error
import collections
import collections.abc

import numpy as np
import torch
from torch.autograd import Variable
# Public helpers exported by ``from <this module> import *``.
__all__ = [
    'as_variable',
    'as_numpy',
    'mark_volatile',
]
def as_variable(obj):
    """Recursively wrap the leaves of ``obj`` with ``Variable``.

    * A ``Variable`` is returned unchanged.
    * A sequence (list, tuple, ...) becomes a ``list`` of converted elements.
    * A mapping becomes a ``dict`` with the same keys and converted values.
    * Anything else is passed to ``Variable(obj)``.

    Strings are deliberately not treated as sequences. Iterating a ``str``
    yields one-character ``str`` objects, so recursing into them never
    terminates. A string therefore reaches ``Variable(obj)`` directly.
    """
    if isinstance(obj, Variable):
        return obj
    # The ``collections.Sequence``/``collections.Mapping`` aliases were removed
    # in Python 3.10; the ABCs live in ``collections.abc``.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [as_variable(v) for v in obj]
    if isinstance(obj, collections.abc.Mapping):
        return {k: as_variable(v) for k, v in obj.items()}
    return Variable(obj)
def as_numpy(obj):
    """Recursively convert ``obj`` into NumPy arrays.

    * A sequence (list, tuple, ...) becomes a ``list`` of converted elements.
    * A mapping becomes a ``dict`` with the same keys and converted values.
    * A ``Variable`` is converted via ``obj.data.cpu().numpy()``.
    * Any other tensor is converted via ``obj.cpu().numpy()``.
    * Anything else is wrapped with ``np.array(obj)``.

    Strings are not iterated: recursing into a ``str`` never terminates,
    so a string is handed to ``np.array`` as a single value.
    """
    # ``collections.abc`` is required: the bare ``collections`` ABC aliases
    # were removed in Python 3.10.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [as_numpy(v) for v in obj]
    if isinstance(obj, collections.abc.Mapping):
        return {k: as_numpy(v) for k, v in obj.items()}
    if isinstance(obj, Variable):
        return obj.data.cpu().numpy()
    if torch.is_tensor(obj):
        return obj.cpu().numpy()
    return np.array(obj)
def mark_volatile(obj):
    """Recursively set ``no_grad = True`` on every tensor/``Variable`` in ``obj``.

    * A tensor is first wrapped with ``Variable(obj)``.
    * A ``Variable`` gets the attribute ``no_grad`` set to ``True`` and is
      returned.
    * A mapping becomes a ``dict`` with the same keys and converted values.
    * A sequence (excluding strings) becomes a ``list`` of converted elements.
    * Any other value, including a string, is returned unchanged.
    """
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        # NOTE(review): ``no_grad`` looks like an ad-hoc attribute rather than
        # an autograd flag (the legacy flag was ``volatile``), so this likely
        # does not stop gradient tracking. If that is the intent, prefer
        # ``torch.no_grad()`` or ``detach()``; confirm with callers first.
        obj.no_grad = True
        return obj
    # ``collections.abc`` is required: the bare ``collections`` ABC aliases
    # were removed in Python 3.10.
    if isinstance(obj, collections.abc.Mapping):
        return {k: mark_volatile(o) for k, o in obj.items()}
    # Strings are excluded: iterating a ``str`` yields ``str``s, so recursing
    # into one never terminates.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [mark_volatile(o) for o in obj]
    return obj