Spaces:
Sleeping
Sleeping
File size: 1,733 Bytes
b2ffc9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# @title Load functions for working with image coordinates and labels
# @title Load utility functions for data loading and preprocessing
from typing import Tuple, Union
import torch
import warnings
# Suppress every warning issued from modules whose name matches the regex
# "torchvision.datasets" (matched from the start of the module name, so
# submodules are included). NOTE(review): nothing visible in this chunk uses
# torchvision; presumably dataset loading elsewhere is noisy — confirm.
warnings.filterwarnings("ignore", module="torchvision.datasets")
def to_onehot(idx: torch.Tensor, n: int,
              device: Union[str, torch.device, None] = None) -> torch.Tensor:
    """
    One-hot encode integer class labels.

    Args:
        idx: Integer labels, either a 1-D tensor of shape (N,) or a column
            tensor of shape (N, 1). A 0-d tensor is treated as one label.
        n: Total number of classes; every label must lie in [0, n).
        device: Device of the returned tensor. When None (the default) it is
            'cuda' if CUDA is available and 'cpu' otherwise.

    Returns:
        Float tensor of shape (N, n) where row i has a 1 at column idx[i]
        and zeros elsewhere. An empty ``idx`` yields an empty (0, n) tensor.

    Raises:
        AssertionError: if any label is negative or not less than ``n``.
    """
    # Validate both ends of the label range; torch.min/max are undefined on
    # empty tensors, so only check when there is something to check.
    if idx.numel() > 0 and (torch.min(idx).item() < 0
                            or torch.max(idx).item() >= n):
        raise AssertionError(
            "Labelling must start from 0 and "
            "maximum label value must be less than total number of classes")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # scatter_ needs a 2-D index aligned with the (N, n) output.
    if idx.dim() < 2:
        idx = idx.reshape(-1, 1)
    onehot = torch.zeros(idx.size(0), n, device=device)
    # scatter_ only accepts int64 indices, so cast (e.g. from int32).
    return onehot.scatter_(1, idx.to(device=device, dtype=torch.long), 1)
def grid2xy(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
    """
    Flatten two same-shaped 2-D grids into a list of coordinate pairs.

    Args:
        X1: First coordinate grid, shape (H, W).
        X2: Second coordinate grid, same shape as X1.

    Returns:
        Tensor of shape (H * W, 2); row k is (X1 flattened[k], X2 flattened[k])
        in row-major order.
    """
    stacked = torch.stack((X1, X2), dim=0)          # (2, H, W)
    num_points = stacked.shape[1] * stacked.shape[2]
    flat = stacked.reshape(2, num_points)           # (2, H * W)
    return flat.t()
def imcoordgrid(im_dim: Tuple) -> torch.Tensor:
    """
    Build a normalized coordinate grid for an image.

    Args:
        im_dim: Image dimensions; only ``im_dim[0]`` and ``im_dim[1]`` are used.

    Returns:
        Tensor of shape (im_dim[0] * im_dim[1], 2). Column 0 takes values
        evenly spaced from -1 to 1 along the first axis; column 1 takes values
        evenly spaced from 1 down to -1 along the second axis. Points are
        ordered row-major over (first axis, second axis).
    """
    xx = torch.linspace(-1, 1, im_dim[0])
    yy = torch.linspace(1, -1, im_dim[1])
    # Explicit "ij" indexing keeps the historical default (output shape
    # (len(xx), len(yy))) and avoids torch's deprecation warning for calling
    # meshgrid without the indexing argument.
    x0, x1 = torch.meshgrid(xx, yy, indexing="ij")
    return grid2xy(x0, x1)
def transform_coordinates(coord: torch.Tensor,
                          phi: Union[torch.Tensor, float] = 0,
                          coord_dx: Union[torch.Tensor, float] = 0,
                          ) -> torch.Tensor:
    """
    Rotate batched 2-D coordinates by per-sample angles, then translate them.

    Args:
        coord: Coordinates of shape (B, N, 2) (required by ``torch.bmm``).
        phi: Rotation angle(s) in radians. Either a scalar (Python number or
            single-element tensor) applied to the whole batch, or one angle per
            batch element (B values). Each point [x, y] is mapped to
            [x*cos(phi) - y*sin(phi), x*sin(phi) + y*cos(phi)].
        coord_dx: Offset added after rotation; must broadcast against
            (B, N, 2), e.g. a scalar or a tensor of shape (B, 1, 2).

    Returns:
        Transformed coordinates of shape (B, N, 2).
    """
    batch_size = coord.shape[0]
    # Accept Python scalars as well as tensors, and match coord's dtype/device
    # so torch.bmm below does not fail on a mismatch. For a tensor that
    # already matches, as_tensor returns it unchanged (autograd preserved).
    phi = torch.as_tensor(phi, dtype=coord.dtype, device=coord.device)
    # One angle per batch element; a single angle is shared by the batch.
    # (Previously, any phi whose angles merely *summed* to zero was replaced
    # by zeros, discarding real rotations.)
    phi = phi.reshape(-1).expand(batch_size)
    cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
    rotmat_r1 = torch.stack([cos_phi, sin_phi], 1)
    rotmat_r2 = torch.stack([-sin_phi, cos_phi], 1)
    rotmat = torch.stack([rotmat_r1, rotmat_r2], dim=1)  # (B, 2, 2)
    coord = torch.bmm(coord, rotmat)
    return coord + coord_dx
|