Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections import abc | |
import numpy as np | |
import torch | |
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensor in inputs from src_type to dst_type.

    Only tensors whose dtype equals ``src_type`` are converted. Tensors of any
    other dtype (e.g. integer tensors) are returned unchanged, so a
    float -> half cast does not also convert them.

    Args:
        inputs: Inputs that to be casted. May be a tensor, a mapping, an
            iterable (list, tuple, namedtuple, set, ...) nesting such values,
            or any other object, which is returned as-is.
        src_type (torch.dtype): Source type. Only tensors of exactly this
            dtype are cast.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type with inputs, but every contained Tensor whose dtype is
        ``src_type`` has been cast to ``dst_type``.
    """
    if isinstance(inputs, torch.nn.Module):
        # Container modules (e.g. nn.Sequential) are iterable but cannot be
        # rebuilt from a generator; modules are never cast by this helper.
        return inputs
    elif isinstance(inputs, torch.Tensor):
        # Cast only tensors of the requested source dtype.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    elif isinstance(inputs, (str, np.ndarray)):
        # Strings are iterable but must not be rebuilt element-wise; numpy
        # arrays are deliberately left untouched.
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
        # namedtuple constructors take their fields positionally rather than
        # a single iterable argument.
        return type(inputs)(
            *(cast_tensor_type(item, src_type, dst_type) for item in inputs))
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    return inputs