from enum import Enum
from typing import Union

import torch


class Format(str, Enum):
    """Tensor axis layouts handled by the helpers in this module.

    Each letter names one tensor axis, in order. ``L`` stands for the ``H``
    and ``W`` axes flattened into a single axis (see ``nchw_to`` /
    ``nhwc_to``). Conventionally ``N`` is batch and ``C`` is channels.
    Because the enum mixes in ``str``, members compare equal to their plain
    string values.
    """
    NCHW = 'NCHW'
    NHWC = 'NHWC'
    NCL = 'NCL'
    NLC = 'NLC'


# Anything accepted where a layout is expected: a member or its string value.
FormatT = Union[str, Format]


def get_spatial_dim(fmt: FormatT):
    """Return the tuple of axis indices that hold spatial extent in ``fmt``.

    ``fmt`` is coerced with ``Format(fmt)``, so an unrecognised value raises
    ``ValueError``. Results: NLC -> ``(1,)``, NCL -> ``(2,)``,
    NHWC -> ``(1, 2)``, NCHW -> ``(2, 3)``.
    """
    layout = Format(fmt)
    if layout is Format.NLC:
        return (1,)
    if layout is Format.NCL:
        return (2,)
    if layout is Format.NHWC:
        return (1, 2)
    # The only member left is NCHW.
    return (2, 3)


def get_channel_dim(fmt: FormatT):
    """Return the axis index that holds channels in ``fmt``.

    ``fmt`` is coerced with ``Format(fmt)``, so an unrecognised value raises
    ``ValueError``. Results: NHWC -> 3, NLC -> 2, NCHW / NCL -> 1.
    """
    layout = Format(fmt)
    if layout is Format.NHWC:
        return 3
    # Channels-first layouts (NCHW, NCL) keep channels on axis 1.
    return 2 if layout is Format.NLC else 1


def nchw_to(x: torch.Tensor, fmt: Format):
    """Rearrange ``x`` from NCHW order into the layout ``fmt``.

    - NHWC: axes are permuted as ``(0, 2, 3, 1)``.
    - NLC: axes 2 onward are flattened into one, which is then swapped with
      axis 1.
    - NCL: axes 2 onward are flattened into one trailing axis.

    Any other ``fmt`` (including NCHW) returns ``x`` itself, untouched.
    """
    if fmt == Format.NHWC:
        return x.permute(0, 2, 3, 1)
    if fmt == Format.NLC:
        return x.flatten(2).transpose(1, 2)
    if fmt == Format.NCL:
        return x.flatten(2)
    return x


def nhwc_to(x: torch.Tensor, fmt: Format):
    """Rearrange ``x`` from NHWC order into the layout ``fmt``.

    - NCHW: axes are permuted as ``(0, 3, 1, 2)``.
    - NLC: axes 1 and 2 are flattened into one.
    - NCL: axes 1 and 2 are flattened, then swapped with the last axis.

    Any other ``fmt`` (including NHWC) returns ``x`` itself, untouched.
    """
    if fmt == Format.NCHW:
        return x.permute(0, 3, 1, 2)
    if fmt == Format.NLC:
        return x.flatten(1, 2)
    if fmt == Format.NCL:
        return x.flatten(1, 2).transpose(1, 2)
    return x