File size: 645 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch import Tensor


def sample_wise_min_max(x: Tensor) -> Tensor:
    r"""Applies sample-wise min-max normalization to a tensor.

    Each entry along the first (batch) dimension is rescaled independently
    so that its values span ``[0, 1]``. The minimum and maximum are taken
    over all remaining dimensions of that entry.

    Args:
        x (torch.Tensor): Input tensor with at least 2 dimensions, e.g. of
            shape (batch_size, num_samples, num_features). Dimension 0 is
            treated as the batch dimension.

    Returns:
        torch.Tensor: Normalized tensor of the same shape as the input
        tensor. A sample whose values are all equal (zero range) is mapped
        to all zeros instead of NaN.

    Raises:
        ValueError: If ``x`` has fewer than 2 dimensions.
    """
    if x.dim() < 2:
        raise ValueError(
            f"sample_wise_min_max expects a tensor with at least 2 dimensions, "
            f"got shape {tuple(x.shape)}"
        )

    # Reduce over every non-batch dimension (for 3-D input this is (1, 2)).
    reduce_dims = tuple(range(1, x.dim()))
    maximum = torch.amax(x, dim=reduce_dims, keepdim=True)
    minimum = torch.amin(x, dim=reduce_dims, keepdim=True)

    # A constant sample has zero range; dividing by it would give 0/0 = NaN.
    # Substituting 1 there makes such samples normalize to 0 instead.
    span = maximum - minimum
    span = span.masked_fill(span == 0, 1)

    return (x - minimum) / span