from typing import List, Union

import torch
from torch import Tensor, nn


def pad_1D(inputs: List[Tensor], pad_value: float = 0.0) -> Tensor:
    r"""Right-pad a list of 1D tensors to a common length and stack them.

    Args:
        inputs (List[torch.Tensor]): List of 1D tensors to pad.
        pad_value (float): Value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: 2D tensor of shape (len(inputs), max_len), where
        max_len is the length of the longest input tensor.
    """
    max_len = max(x.size(0) for x in inputs)
    padded_inputs = [
        nn.functional.pad(x, (0, max_len - x.size(0)), value=pad_value)
        for x in inputs
    ]
    return torch.stack(padded_inputs)


def pad_2D(
    inputs: List[Tensor],
    maxlen: Union[int, None] = None,
    pad_value: float = 0.0,
) -> Tensor:
    r"""Right-pad a list of 2D tensors along their last axis and stack them.

    Only the last axis (``x.size(1)``) is padded; the first axis is left
    untouched, so it must already be the same size for every input in order
    for the final stack to succeed.

    Args:
        inputs (List[torch.Tensor]): List of 2D tensors to pad.
        maxlen (Union[int, None]): Target size of the last axis. If None,
            the largest last-axis size among the inputs is used.
            Default is None.
        pad_value (float): Value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: 3D tensor of shape (len(inputs), input_dim, max_len),
        where input_dim is the shared first-axis size of the inputs and
        max_len is the padded last-axis size.
    """
    max_len = max(x.size(1) for x in inputs) if maxlen is None else maxlen
    # The pad spec is (last_left, last_right, first_left, first_right):
    # only the right-hand side of the last axis grows.
    padded_inputs = [
        nn.functional.pad(x, (0, max_len - x.size(1), 0, 0), value=pad_value)
        for x in inputs
    ]
    return torch.stack(padded_inputs)


def pad_3D(inputs: Union[Tensor, List[Tensor]], B: int, T: int, L: int) -> Tensor:
    r"""Zero-pad a 3D tensor, or a list of 2D tensors, to shape (B, T, L).

    The data is copied into the leading corner of a freshly allocated zero
    tensor, so every target dimension must be at least as large as the
    corresponding input dimension.

    Args:
        inputs (Union[torch.Tensor, List[torch.Tensor]]): Either a single 3D
            tensor, or a list of 2D tensors where item ``i`` fills
            ``output[i]``.
        B (int): Batch size to pad to.
        T (int): Size of the second dimension (time steps) to pad to.
        L (int): Size of the third dimension (length) to pad to.

    Returns:
        torch.Tensor: Zero-padded tensor of shape (B, T, L) with the dtype
        of the input (of the first item when a list is given).

    Raises:
        TypeError: If ``inputs`` is neither a list nor a ``torch.Tensor``.
    """
    if isinstance(inputs, list):
        inputs_padded = torch.zeros(B, T, L, dtype=inputs[0].dtype)
        for i, input_ in enumerate(inputs):
            inputs_padded[i, : input_.size(0), : input_.size(1)] = input_
    elif isinstance(inputs, torch.Tensor):
        inputs_padded = torch.zeros(B, T, L, dtype=inputs.dtype)
        inputs_padded[
            : inputs.size(0), : inputs.size(1), : inputs.size(2)
        ] = inputs
    else:
        # Without this branch the function fell through to the return with
        # `inputs_padded` unassigned and failed with UnboundLocalError.
        raise TypeError(
            "inputs must be a torch.Tensor or a list of torch.Tensor, "
            f"got {type(inputs).__name__}"
        )

    return inputs_padded