Apollo / look2hear /utils /torch_utils.py
Serhiy Stetskovych
Initial code
78e32cc
raw
history blame
1.43 kB
###
# Author: Kai Li
# Date: 2021-06-18 17:29:21
# LastEditors: Kai Li
# LastEditTime: 2021-06-21 23:52:52
###
import torch
import torch.nn as nn
def pad_x_to_y(x, y, axis: int = -1):
    """Right-pad ``x`` along its last dimension to the length of ``y``.

    Args:
        x (:class:`torch.Tensor`): Tensor whose last dimension is adjusted.
        y (:class:`torch.Tensor`): Tensor providing the target length.
        axis (int): Dimension to adjust. Only ``-1`` is supported.

    Returns:
        :class:`torch.Tensor`: ``x`` with ``y.shape[axis] - x.shape[axis]``
        elements appended at the end of its last dimension. When that
        difference is negative, the pad call removes elements from the
        end instead.

    Raises:
        NotImplementedError: If ``axis`` is anything other than ``-1``.
    """
    if axis == -1:
        target_len = y.shape[axis]
        current_len = x.shape[axis]
        # The pad spec is (left, right) for the last dimension only.
        right_pad = target_len - current_len
        return nn.functional.pad(x, [0, right_pad])
    raise NotImplementedError
def shape_reconstructed(reconstructed, size):
    """Squeeze dimension 0 of ``reconstructed`` when ``size`` has one entry.

    Args:
        reconstructed (:class:`torch.Tensor`): Tensor to reshape.
        size: Sized object; only its length is inspected.
            NOTE(review): presumably the shape of the original input
            before a leading dimension was added -- confirm against callers.

    Returns:
        :class:`torch.Tensor`: ``reconstructed.squeeze(0)`` if
        ``len(size) == 1``, otherwise ``reconstructed`` itself.
    """
    single_dim = len(size) == 1
    return reconstructed.squeeze(0) if single_dim else reconstructed
def tensors_to_device(tensors, device):
    """Recursively move tensors found in a nested structure to ``device``.

    Args:
        tensors: A :class:`torch.Tensor`, a list/tuple or dict containing
            tensors (possibly nested), or any other object.
        device (:class:`torch.device`): The device where to place the tensors.

    Returns:
        The transferred structure:

        * a tensor is returned as ``tensor.to(device)``;
        * a list or tuple is returned as a new **list** whose items were
          processed recursively (tuples are not preserved);
        * a dict is updated **in place**, value by value, and the same
          dict object is returned;
        * anything else is returned untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, (list, tuple)):
        return [tensors_to_device(item, device) for item in tensors]
    if isinstance(tensors, dict):
        # Only existing keys are reassigned, so the dict never changes
        # size while it is being iterated.
        for name in tensors:
            tensors[name] = tensors_to_device(tensors[name], device)
    # Reached for dicts (now updated in place) and for unsupported types.
    return tensors