nickovchinnikov's picture
Init
9d61c9b
raw
history blame contribute delete
645 Bytes
import torch
from torch import Tensor
def sample_wise_min_max(x: Tensor) -> Tensor:
    r"""Applies sample-wise min-max normalization to a tensor.

    Each sample along dimension 0 is rescaled independently so that its
    minimum becomes 0 and its maximum becomes 1. The min/max are taken
    jointly over dimensions 1 and 2.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, num_samples, num_features).

    Returns:
        torch.Tensor: Normalized tensor of the same shape as the input tensor.
        A sample whose values are all equal (zero range) is mapped to all
        zeros instead of producing NaN from a 0/0 division.
    """
    # Compute the maximum and minimum values of each sample in the batch;
    # keepdim=True keeps them broadcastable against ``x``.
    maximum = torch.amax(x, dim=(1, 2), keepdim=True)
    minimum = torch.amin(x, dim=(1, 2), keepdim=True)
    value_range = maximum - minimum
    # A constant sample has a zero range; dividing by it would yield NaN.
    # Substitute 1 there so (x - minimum) / 1 == 0. Non-constant samples
    # keep their exact range, so their output is unchanged.
    value_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    # Apply sample-wise min-max normalization to the input tensor
    return (x - minimum) / value_range