HaMeR / mmpose /core /fp16 /utils.py
geopavlakos's picture
Initial commit
d7a991a
raw
history blame
1.02 kB
# Copyright (c) OpenMMLab. All rights reserved.
from collections import abc
import numpy as np
import torch
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in ``inputs`` from src_type to dst_type.

    Only tensors whose dtype equals ``src_type`` are converted; tensors of
    any other dtype (e.g. integer index or label tensors) are returned
    unchanged. Containers are rebuilt with the same type as the input.

    Args:
        inputs: Inputs to be cast. May be a Tensor, a str, a numpy array,
            a mapping, an iterable, or any other object; mappings and
            iterables are traversed recursively.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        An object of the same type as ``inputs`` in which every contained
        Tensor of dtype ``src_type`` has been cast to ``dst_type``.
    """
    if isinstance(inputs, torch.Tensor):
        # Respect src_type: casting every tensor unconditionally would also
        # turn integer/bool tensors into dst_type.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    elif isinstance(inputs, (str, np.ndarray)):
        # Strings are iterable but must not be traversed character by
        # character; numpy arrays are passed through untouched.
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        converted = (
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
        if isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
            # namedtuple constructors take positional fields, not a single
            # iterable argument.
            return type(inputs)(*converted)
        return type(inputs)(converted)
    return inputs