File size: 1,109 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from enum import Enum
from typing import Union
import torch
class Format(str, Enum):
    """Names of the tensor dimension layouts understood by this module.

    Each member is also a ``str`` equal to its own name, so a member
    compares equal to the matching plain string (``Format.NCHW == "NCHW"``)
    and ``Format("NCHW")`` looks a member up from that string.
    """

    # Four-letter layouts.
    NCHW = "NCHW"
    NHWC = "NHWC"
    # Three-letter layouts.
    NCL = "NCL"
    NLC = "NLC"
# Annotation alias for a layout argument given either as a ``Format`` member
# or as a plain string (the helpers below coerce it with ``Format(fmt)``).
FormatT = Union[str, Format]
def get_spatial_dim(fmt: FormatT):
    """Return the axis indices of the spatial dimensions for a layout.

    The indices are the positions of ``H``/``W`` (or ``L``) in the layout
    name.

    Args:
        fmt: A ``Format`` member or its string value.

    Returns:
        A tuple of ints: ``(1,)`` for NLC, ``(2,)`` for NCL, ``(1, 2)`` for
        NHWC and ``(2, 3)`` otherwise (i.e. NCHW).

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    layout = Format(fmt)
    spatial_axes = {
        Format.NLC: (1,),
        Format.NCL: (2,),
        Format.NHWC: (1, 2),
    }
    # Anything not listed above (only NCHW remains) uses the last two axes.
    return spatial_axes.get(layout, (2, 3))
def get_channel_dim(fmt: FormatT):
    """Return the axis index of the channel dimension for a layout.

    The index is the position of ``C`` in the layout name.

    Args:
        fmt: A ``Format`` member or its string value.

    Returns:
        ``3`` for NHWC, ``2`` for NLC and ``1`` otherwise (NCHW and NCL).

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    layout = Format(fmt)
    if layout is Format.NHWC:
        return 3
    if layout is Format.NLC:
        return 2
    # Remaining layouts (NCHW, NCL) are channels-first.
    return 1
def nchw_to(x: torch.Tensor, fmt: Format):
    """Rearrange a tensor from NCHW layout into the requested layout.

    NOTE(review): assumes ``x`` is a 4-D tensor ordered (N, C, H, W), as the
    function name suggests; nothing here checks it.

    Args:
        x: Input tensor.
        fmt: Target layout.

    Returns:
        * NHWC: ``x`` permuted with axis order ``(0, 2, 3, 1)``.
        * NCL: ``x`` with every axis from index 2 onward flattened into one.
        * NLC: the NCL result with axes 1 and 2 swapped.
        * Any other value (including NCHW): ``x`` itself, unchanged.
    """
    if fmt == Format.NHWC:
        return x.permute(0, 2, 3, 1)
    if fmt == Format.NCL:
        return x.flatten(2)
    if fmt == Format.NLC:
        channels_first_seq = x.flatten(2)
        return channels_first_seq.transpose(1, 2)
    # No conversion matched: hand the input back untouched.
    return x
def nhwc_to(x: torch.Tensor, fmt: Format):
    """Rearrange a tensor from NHWC layout into the requested layout.

    NOTE(review): assumes ``x`` is a 4-D tensor ordered (N, H, W, C), as the
    function name suggests; nothing here checks it.

    Args:
        x: Input tensor.
        fmt: Target layout.

    Returns:
        * NCHW: ``x`` permuted with axis order ``(0, 3, 1, 2)``.
        * NLC: ``x`` with axes 1 and 2 flattened into one.
        * NCL: the NLC result with axes 1 and 2 swapped.
        * Any other value (including NHWC): ``x`` itself, unchanged.
    """
    if fmt == Format.NCHW:
        return x.permute(0, 3, 1, 2)
    if fmt == Format.NLC:
        return x.flatten(1, 2)
    if fmt == Format.NCL:
        channels_last_seq = x.flatten(1, 2)
        return channels_last_seq.transpose(1, 2)
    # No conversion matched: hand the input back untouched.
    return x
|