Spaces:
Build error
Build error
import torch | |
from torch import nn | |
def knn(x, k):
    """Find the k nearest neighbours of every point.

    Args:
        x: point features of shape (batch_size, num_dims, num_points).
        k: number of neighbours to return per point.

    Returns:
        (idx, pairwise_distance):
        ``pairwise_distance`` is a (batch_size, num_points, num_points) tensor
        of *negative* squared Euclidean distances, so larger means closer.
        ``idx`` is a (batch_size, num_points, k) tensor holding the indices of
        the k largest entries in each row of ``pairwise_distance``, ordered
        from nearest to farthest. Each point normally finds itself first,
        at distance 0.
    """
    x_t = x.transpose(2, 1)
    sq_norm = x.pow(2).sum(dim=1, keepdim=True)  # (b, 1, n)
    inner = -2 * (x_t @ x)  # (b, n, n)
    # -||x_j||^2 + 2<x_i, x_j> - ||x_i||^2 == -||x_i - x_j||^2
    neg_sq_dist = -sq_norm - inner - sq_norm.transpose(2, 1)
    nearest = neg_sq_dist.topk(k=k, dim=-1).indices
    return nearest, neg_sq_dist
def local_operator(x, k):
    """Build edge features from each point's k nearest neighbours.

    Args:
        x: point features, reshaped to (batch_size, num_dims, num_points)
            using ``x.size(0)`` and ``x.size(2)``.
        k: number of neighbours per point, found by ``knn`` on ``x``.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims, num_points, k). For every
        point/neighbour pair the channels are ``neighbour - point`` (local)
        followed by ``neighbour`` (global).
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)
    # Offset each batch's indices so they address rows of the flattened
    # (batch_size * num_points, num_dims) tensor below. The offsets live on
    # x's device (not a hard-coded CPU) so the addition also works on GPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    # Broadcast the centre point over its k neighbours instead of copying it k times.
    x = x.view(batch_size, num_points, 1, num_dims)
    feature = torch.cat((neighbor - x, neighbor), dim=3).permute(0, 3, 1, 2)
    return feature
def local_operator_withnorm(x, norm_plt, k):
    """Build edge features plus each neighbour's extra per-point features.

    Neighbours are chosen by ``knn`` on ``x`` only. The same neighbour
    indices are then used to gather rows of ``norm_plt``.

    Args:
        x: point features, reshaped to (batch_size, num_dims, num_points)
            using ``x.size(0)`` and ``x.size(2)``.
        norm_plt: per-point features with the same batch size and point count
            as ``x``, reshaped to (batch_size, norm_dims, num_points).
            ``norm_dims`` may differ from ``num_dims``.
        k: number of neighbours per point.

    Returns:
        Tensor of shape (batch_size, 2 * num_dims + norm_dims, num_points, k).
        The channels are ``neighbour - point``, then ``neighbour``, then the
        neighbour's ``norm_plt`` features.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    norm_plt = norm_plt.view(batch_size, -1, num_points)
    idx, _ = knn(x, k=k)  # (batch_size, num_points, k)
    # Offset per-batch indices into flattened row indices, on x's device so
    # CUDA inputs work too.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    norm_plt = norm_plt.transpose(2, 1).contiguous()  # (batch_size, num_points, norm_dims)
    neighbor = x.view(batch_size * num_points, -1)[idx, :]
    neighbor = neighbor.view(batch_size, num_points, k, num_dims)
    neighbor_norm = norm_plt.view(batch_size * num_points, -1)[idx, :]
    # -1: norm_plt's channel count is its own, not necessarily num_dims.
    neighbor_norm = neighbor_norm.view(batch_size, num_points, k, -1)
    # Broadcast the centre point over its k neighbours instead of copying it.
    x = x.view(batch_size, num_points, 1, num_dims)
    feature = torch.cat((neighbor - x, neighbor, neighbor_norm), dim=3).permute(0, 3, 1, 2)
    return feature
def GDM(x, M, *, k=64, tau=0.2, sigma=2):
    """Geometry-Disentangle Module.

    Each point is scored by the squared distance between it and the
    adjacency-weighted sum of its neighbours (Eq. (5)). The M highest-scoring
    points form the sharp variation component; the M lowest-scoring points
    form the gentle variation component.

    Args:
        x: point features, reshaped to (batch_size, num_dims, num_points)
            using ``x.size(0)`` and ``x.size(2)``.
        M: number of disentangled points in both the sharp and gentle
            variation components. ``topk`` requires M <= num_points.
        k: number of neighbours deciding the range of j in Eq. (5). It is
            clamped to num_points. The first neighbour (normally the point
            itself) is excluded, leaving k - 1.
        tau: threshold in Eq. (2). Pairs whose feature distance is not below
            tau get zero edge weight.
        sigma: parameter of f, the Gaussian function in Eq. (2).

    Returns:
        (xs, xg): tensors of shape (batch_size, M, num_dims) holding the
        sharp and gentle variation points respectively.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)
    # topk inside knn requires k <= num_points.
    k = min(k, num_points)

    ############### Graph construction
    idx, p = knn(x, k=k)  # idx: (b, n, k); p = -||x_i - x_j||^2: (b, n, n)
    # Keep only edges shorter than tau; abs() absorbs tiny positive round-off in p.
    mask = torch.sqrt(torch.abs(p)) < tau
    # Gaussian edge weights f, zeroed outside the tau-neighbourhood.
    w = torch.exp(p / (sigma * sigma)) * mask.to(p.dtype)  # (b, n, n)
    # Inverse degree, summed over dim=1. w is built from pairwise distances,
    # so it is symmetric up to round-off.
    deg_inv = 1 / torch.sum(w, dim=1)  # (b, n)
    # A_hat = diag(deg_inv) @ w, computed as a row scaling. This avoids building
    # and multiplying an (n, n) diagonal matrix.
    A = w * deg_inv.unsqueeze(-1)  # normalized adjacency matrix A_hat: (b, n, n)
    # A_ij restricted to each point's neighbours, dropping the first one.
    nbr_idx = idx[:, :, 1:]  # (b, n, k - 1)
    A = torch.gather(A, 2, nbr_idx)  # Aij: (b, n, k - 1)

    ############### Disentangling into sharp (xs) and gentle (xg) variation components
    x = x.transpose(2, 1).contiguous()  # b, n, c
    flat_x = x.view(batch_size * num_points, num_dims)
    # Offset per-batch indices so they address rows of flat_x.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    neighbor = flat_x[(nbr_idx + idx_base).reshape(-1), :]
    neighbor = neighbor.view(batch_size, num_points, k - 1, num_dims)  # b, n, k - 1, c
    n = torch.sum(A.unsqueeze(-1) * neighbor, dim=2)  # weighted neighbour sum: b, n, c
    pai = torch.norm(x - n, dim=-1).pow(2)  # Eq.(5): b, n
    pais = pai.topk(k=M, dim=-1)[1]  # M largest -> sharp variation component
    paig = (-pai).topk(k=M, dim=-1)[1]  # M smallest -> gentle variation component
    pai_base = torch.arange(0, batch_size, device=x.device).view(-1, 1) * num_points
    xs = flat_x[(pais + pai_base).view(-1), :].view(batch_size, M, -1)  # b, M, c
    xg = flat_x[(paig + pai_base).view(-1), :].view(batch_size, M, -1)  # b, M, c
    return xs, xg
class SGCAM(nn.Module):
    """Sharp-Gentle Complementary Attention Module.

    A non-local attention block. Every position of ``x`` attends over the
    positions of ``x_2``. The aggregated response is projected back to
    ``in_channels`` by ``W`` and added to ``x`` as a residual.

    Args:
        in_channels: channel count C of both inputs.
        inter_channels: channel count of the query/key/value embeddings.
            Defaults to ``in_channels // 2``; a value of 0 is raised to 1.
        bn_layer: if True, ``W`` is a 1x1 Conv1d followed by BatchNorm1d,
            with the BN weight and bias initialised to zero. Otherwise ``W``
            is a single 1x1 Conv1d with zero weight and bias. Either way
            ``W`` outputs zeros at initialisation, so the module initially
            returns ``x`` unchanged.
    """

    def __init__(self, in_channels, inter_channels=None, bn_layer=True):
        super().__init__()
        self.in_channels = in_channels
        if inter_channels is None:
            inter_channels = in_channels // 2
        if inter_channels == 0:
            # e.g. in_channels == 1: avoid a zero-width embedding.
            inter_channels = 1
        self.inter_channels = inter_channels

        # Submodules are created in the order g, W, theta, phi. This keeps the
        # default random initialisation and the state_dict keys unchanged.
        self.g = nn.Conv1d(self.in_channels, self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv1d(self.inter_channels, self.in_channels,
                          kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(self.in_channels),
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv1d(self.inter_channels, self.in_channels,
                               kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv1d(self.in_channels, self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv1d(self.in_channels, self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

    def forward(self, x, x_2):
        """Attend from ``x`` (queries) to ``x_2`` (keys/values).

        Args:
            x: tensor of shape (batch, in_channels, n1).
            x_2: tensor of shape (batch, in_channels, n2).

        Returns:
            Tensor with the same shape as ``x``: ``W(attention output) + x``.
        """
        batch_size = x.size(0)
        g_x = self.g(x_2).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)  # values: (b, n2, inter)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)  # queries: (b, n1, inter)
        phi_x = self.phi(x_2).view(batch_size, self.inter_channels, -1)  # keys: (b, inter, n2)
        attn = torch.matmul(theta_x, phi_x)  # attention matrix: (b, n1, n2)
        # Normalise by the number of key positions (no softmax).
        attn = attn / attn.size(-1)
        y = torch.matmul(attn, g_x)  # (b, n1, inter)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        return self.W(y) + x