|
""" Interpolation helpers for timm layers |
|
|
|
RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations |
|
Copyright Shane Barratt, Apache 2.0 license |
|
""" |
|
import torch |
|
from itertools import product |
|
|
|
|
|
class RegularGridInterpolator:
    """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
    Produces similar results to scipy RegularGridInterpolator or interp2d
    in 'linear' mode.

    Query coordinates outside the grid are clamped to the nearest edge value
    along that dimension (no extrapolation).

    Taken from https://github.com/sbarratt/torch_interpolations

    Args:
        points: tuple or list of ``n`` 1-D tensors; ``points[i]`` holds the grid
            coordinates along dimension ``i``. They are passed to
            ``torch.bucketize`` as boundaries, so they must be sorted ascending.
        values: ``n``-dimensional tensor with ``values.shape[i] == points[i].shape[0]``.
    """

    def __init__(self, points, values):
        self.points = points
        self.values = values

        assert isinstance(self.points, (tuple, list))
        assert isinstance(self.values, torch.Tensor)

        self.ms = list(self.values.shape)
        self.n = len(self.points)

        # One coordinate vector per dimension of `values`.
        assert len(self.ms) == self.n

        for i, p in enumerate(self.points):
            assert isinstance(p, torch.Tensor)
            assert p.shape[0] == self.values.shape[i]

    def __call__(self, points_to_interp):
        """ Evaluate the multilinear interpolant at ``K`` query points.

        Args:
            points_to_interp: sequence of ``n`` tensors, one per grid dimension,
                each with leading dimension ``K``. Element ``k`` of tensor ``i``
                is the ``i``-th coordinate of query point ``k``.

        Returns:
            Tensor holding the interpolated value for each query point.
        """
        assert self.points is not None
        assert self.values is not None

        assert len(points_to_interp) == len(self.points)
        K = points_to_interp[0].shape[0]
        for x in points_to_interp:
            assert x.shape[0] == K

        idxs = []
        dists = []
        overalls = []
        for p, x in zip(self.points, points_to_interp):
            # bucketize returns the first index with p[idx] >= x; clamp so that
            # queries past the last grid point reuse the final cell.
            idx_right = torch.bucketize(x, p).clamp(max=p.shape[0] - 1)
            idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
            dist_left = x - p[idx_left]
            dist_right = p[idx_right] - x
            # Negative distances only arise for queries outside the grid;
            # zeroing them puts all the weight on the nearest edge value.
            dist_left[dist_left < 0] = 0.
            dist_right[dist_right < 0] = 0.
            # A query exactly on a point where idx_left == idx_right (e.g. x == p[0])
            # gives two zero distances; equal weights avoid a 0 / 0 below.
            both_zero = (dist_left == 0) & (dist_right == 0)
            dist_left[both_zero] = dist_right[both_zero] = 1.

            idxs.append((idx_left, idx_right))
            dists.append((dist_left, dist_right))
            overalls.append(dist_left + dist_right)

        numerator = 0.
        for indexer in product([0, 1], repeat=self.n):
            # Each bit picks the left (0) or right (1) neighbour in one dimension;
            # that corner is weighted by the distance to the *opposite* neighbour.
            # A tuple (not a list) of index tensors is the explicit form of
            # multi-dimensional advanced indexing.
            as_s = tuple(idx[onoff] for onoff, idx in zip(indexer, idxs))
            bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
            numerator += self.values[as_s] * torch.prod(torch.stack(bs_s), dim=0)
        denominator = torch.prod(torch.stack(overalls), dim=0)
        return numerator / denominator
|
|