# https://gist.github.com/xmodar/ae2d94681a6fda39f3c4f3ac91eef7b7
# %%
import torch


def sinusoidal(positions, features=16, periods=10000):
    """Encode `positions` using sinusoidal positional encoding

    Frequency `k` (for `k` in `0 .. features - 1`) is
    `periods ** (-k / features)`.

    Args:
        positions: tensor of positions
        features: half the number of features per position
        periods: base of the log-spaced frequencies of the sinusoids

    Returns:
        Positional encoding of shape `(*positions.shape, features, 2)`;
        the last axis holds the `(sin, cos)` pair of each frequency
    """
    # Integer positions have no usable floating dtype, so let torch fall
    # back to its default dtype for the frequencies in that case.
    dtype = positions.dtype if positions.is_floating_point() else None
    omega = torch.logspace(
        0,
        1 / features - 1,
        features,
        base=periods,
        device=positions.device,
        dtype=dtype,
    )
    fraction = omega * positions.unsqueeze(-1)
    return torch.stack((fraction.sin(), fraction.cos()), dim=-1)


def point_pe(points, low=0, high=1, steps=100, features=16, periods=10000):
    """Encode points in bounded space using sinusoidal positional encoding

    Each coordinate is mapped to `(point - low) * steps / (high - low)`
    before being passed to `sinusoidal`.

    Args:
        points: tensor of points; typically of shape (*, C)
        low: lower bound of the space; typically of shape (C,)
        high: upper bound of the space; typically of shape (C,)
        steps: number of cells that split the space; typically of shape (C,)
        features: half the number of features per position
        periods: base of the log-spaced frequencies of the sinusoids

    Returns:
        Positional encoded points of the following shape:
        `(*points.shape[:-1], points.shape[-1] * features * 2)`
    """
    extent = high - low
    if torch.is_tensor(extent):
        # A zero-extent axis would give a non-finite scale (0 / 0 = NaN).
        # Dividing by 1 instead keeps points lying on `low` at position 0.
        extent = torch.where(extent == 0, torch.ones_like(extent), extent)
    scale = steps / extent
    offsets = points - low
    if offsets.is_floating_point():
        # `offsets` is a fresh tensor, so scaling it in place is safe and
        # keeps its dtype.
        positions = offsets.mul_(scale)
    else:
        # An in-place multiply cannot store a float result in an integer
        # tensor, so integer points take the out-of-place path.
        positions = offsets * scale
    return sinusoidal(positions, features, periods).flatten(-3)


def _bounding_grid(points, max_steps):
    """Return `(low, high, steps)` of the grid spanning `points` along dim 0."""
    low = points.min(0).values
    high = points.max(0).values
    extent = high - low
    longest = extent.max()
    if longest > 0:
        # The same scale on every axis: the longest one gets `max_steps`.
        steps = extent * (max_steps / longest)
    else:
        # All points coincide; avoid 0 * (max_steps / 0) = NaN.
        steps = torch.zeros_like(extent)
    return low, high, steps


def point_position_encoding(points, max_steps=100, features=16, periods=10000):
    """Encode points relative to their own axis-aligned bounding box

    The bounds are the minimum and maximum of `points` over dim 0. The
    longest axis of that box is split into `max_steps` cells and the other
    axes use the same cell size. Axes with zero extent encode as position 0.

    Args:
        points: tensor of points; reduced over dim 0, typically (N, C)
        max_steps: number of cells along the longest axis of the box
        features: half the number of features per position
        periods: base of the log-spaced frequencies of the sinusoids

    Returns:
        Positional encoded points of the following shape:
        `(*points.shape[:-1], points.shape[-1] * features * 2)`
    """
    low, high, steps = _bounding_grid(points, max_steps)
    return point_pe(points, low, high, steps, features, periods)


def test(num_points=1000, max_steps=100, features=32, periods=10000):
    """Encode a random 3D point cloud of `num_points` points"""
    point_cloud = torch.rand(num_points, 3)
    return point_position_encoding(point_cloud, max_steps, features, periods)


# %%
if __name__ == "__main__":
    pe = test(20, 1000, periods=10000)
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(pe)


# %%
def pe_2d(num_points=14, max_steps=100, features=32, periods=10000):
    """Encode a regular `num_points` x `num_points` grid over the unit square

    Returns:
        Tensor of shape `(2 * features * 2, num_points, num_points)`:
        encoding channels first, then the two grid axes
    """
    x = torch.linspace(0, 1, num_points)
    y = torch.linspace(0, 1, num_points)
    # "ij" is the historical default; stating it avoids the deprecation
    # warning raised when `indexing` is omitted.
    grid = torch.meshgrid(x, y, indexing="ij")
    points = torch.stack(grid, dim=-1).reshape(-1, 2)
    pe = point_position_encoding(points, max_steps, features, periods)
    pe = pe.reshape(num_points, num_points, -1)
    return pe.permute(2, 0, 1)


# %%
if __name__ == "__main__":
    pe = pe_2d(3, max_steps=1000, periods=10000, features=32)
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(pe[64, :, :])

# %%