File size: 2,439 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
""" Interpolation helpers for timm layers
RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
Copyright Shane Barratt, Apache 2.0 license
"""
import torch
from itertools import product
class RegularGridInterpolator:
    """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
    Produces similar results to scipy RegularGridInterpolator or interp2d
    in 'linear' mode.

    Query points outside the grid are clamped to the nearest edge value
    (constant extrapolation) rather than raising or returning NaN.

    Taken from https://github.com/sbarratt/torch_interpolations
    """

    def __init__(self, points, values):
        """
        Args:
            points: tuple or list of n tensors giving the grid coordinates along each
                dimension. ``points[i]`` must have length ``values.shape[i]``. It is used
                as the boundaries argument of ``torch.bucketize``, so it must be 1-D
                and strictly increasing.
            values: n-dimensional tensor of samples on the grid.
        """
        self.points = points
        self.values = values

        assert isinstance(self.points, (tuple, list)), 'points must be a tuple or list of tensors'
        assert isinstance(self.values, torch.Tensor), 'values must be a torch.Tensor'

        self.ms = list(self.values.shape)  # grid size along each dimension
        self.n = len(self.points)  # number of grid dimensions

        assert len(self.ms) == self.n, 'need one coordinate tensor per dimension of values'

        for i, p in enumerate(self.points):
            assert isinstance(p, torch.Tensor), f'points[{i}] must be a torch.Tensor'
            assert p.shape[0] == self.values.shape[i], f'len(points[{i}]) != values.shape[{i}]'

    def __call__(self, points_to_interp):
        """ Linearly interpolate ``values`` at the given query coordinates.

        Args:
            points_to_interp: sequence of n tensors, one per grid dimension. All of them
                must share the same leading size K (the number of query points).

        Returns:
            Tensor of interpolated values, one per query point.
        """
        assert self.points is not None
        assert self.values is not None

        assert len(points_to_interp) == len(self.points), 'expected one query tensor per grid dimension'
        K = points_to_interp[0].shape[0]
        for x in points_to_interp:
            assert x.shape[0] == K, 'all query tensors must have the same number of points'

        idxs = []
        dists = []
        overalls = []
        for p, x in zip(self.points, points_to_interp):
            # Index of the first grid point >= x. Queries beyond the last grid point
            # are clamped onto it so the bracketing indices stay in range.
            idx_right = torch.bucketize(x, p).clamp(max=p.shape[0] - 1)
            idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
            dist_left = x - p[idx_left]
            dist_right = p[idx_right] - x
            # A negative distance means x lies outside the grid. Zeroing it gives the
            # in-range neighbour the full weight, i.e. constant extrapolation.
            dist_left[dist_left < 0] = 0.
            dist_right[dist_right < 0] = 0.
            # Both distances are zero only when left and right index coincide and x sits
            # on that grid point (e.g. x == p[0]). Use equal weights to avoid 0 / 0.
            both_zero = (dist_left == 0) & (dist_right == 0)
            dist_left[both_zero] = dist_right[both_zero] = 1.

            idxs.append((idx_left, idx_right))
            dists.append((dist_left, dist_right))
            overalls.append(dist_left + dist_right)

        # Sum over the 2**n corners of the enclosing cell. Each corner is weighted by the
        # product, over dimensions, of the distance to the opposite side of the cell.
        numerator = 0.
        for indexer in product([0, 1], repeat=self.n):
            # An explicit tuple of index tensors selects one element per query point via
            # advanced indexing. A list would rely on torch's legacy list-as-tuple heuristic.
            as_s = tuple(idx[onoff] for onoff, idx in zip(indexer, idxs))
            bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
            numerator += self.values[as_s] * torch.prod(torch.stack(bs_s), dim=0)
        denominator = torch.prod(torch.stack(overalls), dim=0)
        return numerator / denominator
|