diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..9b8829a9e7bc9ce8d844bf8b664644d44ade0d91
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
index bf07816c74bac9b682df196e02c6482e474e9b52..25c40ae1b353ff70100a9929e6e1b388254871a1 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.t7 filter=lfs diff=lfs merge=lfs -text
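The new `*.t7` rule makes Git LFS track the PyTorch checkpoints added later in this diff (`GDANet_WOLFMix.t7`, `dgcnn.t7`), so a plain checkout contains only small pointer stubs. A minimal sketch of what such a stub holds, assuming LFS objects have not been fetched; the helper below is illustrative, not part of the repo:

```python
# Illustrative helper (not in this repo): parse a Git LFS pointer stub.
# Without `git lfs pull`, an LFS-tracked file contains only "key value" lines.
def read_lfs_pointer(path):
    with open(path) as f:
        return dict(line.strip().split(' ', 1) for line in f if line.strip())

# e.g. read_lfs_pointer('GDANet_WOLFMix.t7') would yield something like
# {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'sha256:ef1f05...', 'size': '3796397'}
```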
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b5c73459097a0576b3460d075d7973b78934eb1d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.DS_Store
+*.pyc
+flagged
+*.png
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..13566b81b018ad684f3a35fee301741b2734c8f4
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..4c30cfa0bdc2dab927736e4f6233197d1d86a430
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..7f9fc7525ea767213d302133c9e5d4e4cae0714d
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/pointcloud-c.iml" filepath="$PROJECT_DIR$/.idea/pointcloud-c.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/pointcloud-c.iml b/.idea/pointcloud-c.iml
new file mode 100644
index 0000000000000000000000000000000000000000..d6c15b62fd5c869153d436f9d521cda16baeeeaa
--- /dev/null
+++ b/.idea/pointcloud-c.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/DGCNN.py b/DGCNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..46a0d46ff0fa2be2b57d165dfe5049b4db36d8c7
--- /dev/null
+++ b/DGCNN.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Author: Yue Wang
+@Contact: yuewangx@mit.edu
+@File: model.py
+@Time: 2018/10/13 6:35 PM
+"""
+
+import os
+import sys
+import copy
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def knn(x, k):
+ inner = -2 * torch.matmul(x.transpose(2, 1), x)
+ xx = torch.sum(x ** 2, dim=1, keepdim=True)
+ pairwise_distance = -xx - inner - xx.transpose(2, 1)
+
+ idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
+ return idx
+
+
+def get_graph_feature(x, k=20, idx=None):
+ batch_size = x.size(0)
+ num_points = x.size(2)
+ x = x.view(batch_size, -1, num_points)
+ if idx is None:
+ idx = knn(x, k=k) # (batch_size, num_points, k)
+ device = torch.device('cpu')
+
+ idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
+
+ idx = idx + idx_base
+
+ idx = idx.view(-1)
+
+ _, num_dims, _ = x.size()
+
+    # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims)
+    x = x.transpose(2, 1).contiguous()
+    feature = x.view(batch_size * num_points, -1)[idx, :]  # gather the k neighbors of every point via the flattened indices
+ feature = feature.view(batch_size, num_points, k, num_dims)
+ x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+
+ feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
+
+ return feature
+
+class DGCNN(nn.Module):
+ def __init__(self, output_channels=40):
+ super(DGCNN, self).__init__()
+ self.k = 20
+ emb_dims = 1024
+ dropout = 0.5
+
+ self.bn1 = nn.BatchNorm2d(64)
+ self.bn2 = nn.BatchNorm2d(64)
+ self.bn3 = nn.BatchNorm2d(128)
+ self.bn4 = nn.BatchNorm2d(256)
+ self.bn5 = nn.BatchNorm1d(emb_dims)
+
+ self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
+ self.bn1,
+ nn.LeakyReLU(negative_slope=0.2))
+ self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
+ self.bn2,
+ nn.LeakyReLU(negative_slope=0.2))
+ self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
+ self.bn3,
+ nn.LeakyReLU(negative_slope=0.2))
+ self.conv4 = nn.Sequential(nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
+ self.bn4,
+ nn.LeakyReLU(negative_slope=0.2))
+ self.conv5 = nn.Sequential(nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
+ self.bn5,
+ nn.LeakyReLU(negative_slope=0.2))
+ self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False)
+ self.bn6 = nn.BatchNorm1d(512)
+ self.dp1 = nn.Dropout(p=dropout)
+ self.linear2 = nn.Linear(512, 256)
+ self.bn7 = nn.BatchNorm1d(256)
+ self.dp2 = nn.Dropout(p=dropout)
+ self.linear3 = nn.Linear(256, output_channels)
+
+ def forward(self, x):
+ batch_size = x.size(0)
+ x = get_graph_feature(x, k=self.k)
+ x = self.conv1(x)
+ x1 = x.max(dim=-1, keepdim=False)[0]
+
+ x = get_graph_feature(x1, k=self.k)
+ x = self.conv2(x)
+ x2 = x.max(dim=-1, keepdim=False)[0]
+
+ x = get_graph_feature(x2, k=self.k)
+ x = self.conv3(x)
+ x3 = x.max(dim=-1, keepdim=False)[0]
+
+ x = get_graph_feature(x3, k=self.k)
+ x = self.conv4(x)
+ x4 = x.max(dim=-1, keepdim=False)[0]
+
+ x = torch.cat((x1, x2, x3, x4), dim=1)
+
+ x = self.conv5(x)
+ x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
+ x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
+ x = torch.cat((x1, x2), 1)
+
+ x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
+ x = self.dp1(x)
+ x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
+ x = self.dp2(x)
+ x = self.linear3(x)
+ return x
\ No newline at end of file
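A quick shape check for the edge-feature construction in `DGCNN.py` above; a hedged sketch with arbitrary toy sizes, assuming it is run from the repo root so `DGCNN` is importable:

```python
# Hedged sketch: verify the k-NN graph feature shapes used by DGCNN (toy sizes).
import torch
from DGCNN import DGCNN, get_graph_feature, knn

x = torch.randn(2, 3, 128)            # (batch, xyz channels, num_points)
print(knn(x, k=20).shape)             # torch.Size([2, 128, 20]) neighbor indices

feat = get_graph_feature(x, k=20)     # concatenates [neighbor - x, x] per edge
print(feat.shape)                     # torch.Size([2, 6, 128, 20])

model = DGCNN().eval()                # untrained weights, shape check only
with torch.no_grad():
    print(model(x).shape)             # torch.Size([2, 40]) ModelNet40 logits
```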
diff --git a/GDANet_WOLFMix.t7 b/GDANet_WOLFMix.t7
new file mode 100644
index 0000000000000000000000000000000000000000..c2e1db5d2a854175680e7aca194130ae4f6a26f1
--- /dev/null
+++ b/GDANet_WOLFMix.t7
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef1f05156c6ace4f72e9e70ac373dd7f5d8ece8fe2af15a1099c56b8e13431dd
+size 3796397
diff --git a/GDANet_cls.py b/GDANet_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e70725ccd0d3c8c62f98b8c5be5f55bde4168cb
--- /dev/null
+++ b/GDANet_cls.py
@@ -0,0 +1,113 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from util.GDANet_util import local_operator, GDM, SGCAM
+
+
+class GDANET(nn.Module):
+ def __init__(self):
+ super(GDANET, self).__init__()
+
+ self.bn1 = nn.BatchNorm2d(64, momentum=0.1)
+ self.bn11 = nn.BatchNorm2d(64, momentum=0.1)
+ self.bn12 = nn.BatchNorm1d(64, momentum=0.1)
+
+ self.bn2 = nn.BatchNorm2d(64, momentum=0.1)
+ self.bn21 = nn.BatchNorm2d(64, momentum=0.1)
+ self.bn22 = nn.BatchNorm1d(64, momentum=0.1)
+
+ self.bn3 = nn.BatchNorm2d(128, momentum=0.1)
+ self.bn31 = nn.BatchNorm2d(128, momentum=0.1)
+ self.bn32 = nn.BatchNorm1d(128, momentum=0.1)
+
+ self.bn4 = nn.BatchNorm1d(512, momentum=0.1)
+
+ self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=True),
+ self.bn1)
+ self.conv11 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=True),
+ self.bn11)
+ self.conv12 = nn.Sequential(nn.Conv1d(64 * 2, 64, kernel_size=1, bias=True),
+ self.bn12)
+
+ self.conv2 = nn.Sequential(nn.Conv2d(67 * 2, 64, kernel_size=1, bias=True),
+ self.bn2)
+ self.conv21 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=True),
+ self.bn21)
+ self.conv22 = nn.Sequential(nn.Conv1d(64 * 2, 64, kernel_size=1, bias=True),
+ self.bn22)
+
+ self.conv3 = nn.Sequential(nn.Conv2d(131 * 2, 128, kernel_size=1, bias=True),
+ self.bn3)
+ self.conv31 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=1, bias=True),
+ self.bn31)
+ self.conv32 = nn.Sequential(nn.Conv1d(128, 128, kernel_size=1, bias=True),
+ self.bn32)
+
+ self.conv4 = nn.Sequential(nn.Conv1d(256, 512, kernel_size=1, bias=True),
+ self.bn4)
+
+ self.SGCAM_1s = SGCAM(64)
+ self.SGCAM_1g = SGCAM(64)
+ self.SGCAM_2s = SGCAM(64)
+ self.SGCAM_2g = SGCAM(64)
+
+ self.linear1 = nn.Linear(1024, 512, bias=True)
+ self.bn6 = nn.BatchNorm1d(512)
+ self.dp1 = nn.Dropout(p=0.4)
+ self.linear2 = nn.Linear(512, 256, bias=True)
+ self.bn7 = nn.BatchNorm1d(256)
+ self.dp2 = nn.Dropout(p=0.4)
+ self.linear3 = nn.Linear(256, 40, bias=True)
+
+ def forward(self, x):
+ B, C, N = x.size()
+ ###############
+ """block 1"""
+ # Local operator:
+ x1 = local_operator(x, k=30)
+ x1 = F.relu(self.conv1(x1))
+ x1 = F.relu(self.conv11(x1))
+ x1 = x1.max(dim=-1, keepdim=False)[0]
+
+ # Geometry-Disentangle Module:
+ x1s, x1g = GDM(x1, M=256)
+
+ # Sharp-Gentle Complementary Attention Module:
+ y1s = self.SGCAM_1s(x1, x1s.transpose(2, 1))
+ y1g = self.SGCAM_1g(x1, x1g.transpose(2, 1))
+ z1 = torch.cat([y1s, y1g], 1)
+ z1 = F.relu(self.conv12(z1))
+ ###############
+ """block 2"""
+ x1t = torch.cat((x, z1), dim=1)
+ x2 = local_operator(x1t, k=30)
+ x2 = F.relu(self.conv2(x2))
+ x2 = F.relu(self.conv21(x2))
+ x2 = x2.max(dim=-1, keepdim=False)[0]
+
+ x2s, x2g = GDM(x2, M=256)
+
+ y2s = self.SGCAM_2s(x2, x2s.transpose(2, 1))
+ y2g = self.SGCAM_2g(x2, x2g.transpose(2, 1))
+ z2 = torch.cat([y2s, y2g], 1)
+ z2 = F.relu(self.conv22(z2))
+ ###############
+ x2t = torch.cat((x1t, z2), dim=1)
+ x3 = local_operator(x2t, k=30)
+ x3 = F.relu(self.conv3(x3))
+ x3 = F.relu(self.conv31(x3))
+ x3 = x3.max(dim=-1, keepdim=False)[0]
+ z3 = F.relu(self.conv32(x3))
+ ###############
+ x = torch.cat((z1, z2, z3), dim=1)
+ x = F.relu(self.conv4(x))
+ x11 = F.adaptive_max_pool1d(x, 1).view(B, -1)
+ x22 = F.adaptive_avg_pool1d(x, 1).view(B, -1)
+ x = torch.cat((x11, x22), 1)
+
+ x = F.relu(self.bn6(self.linear1(x)))
+ x = self.dp1(x)
+ x = F.relu(self.bn7(self.linear2(x)))
+ x = self.dp2(x)
+ x = self.linear3(x)
+ return x
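The input widths in `GDANET.__init__` follow from the cascade in `forward`: block 2 sees the raw xyz (3) concatenated with the 64-d block-1 feature, i.e. 67 channels, and `local_operator` doubles that to 67*2; block 3 sees 3 + 64 + 64 = 131 channels, doubled to 131*2; `conv4` takes the 64 + 64 + 128 = 256-d concatenation of z1, z2, z3. A hedged usage sketch on random input (CPU, untrained weights), assuming the repo root is on the import path:

```python
# Hedged forward-pass sketch for the GDANET classifier above (random input, CPU).
import torch
from GDANet_cls import GDANET

model = GDANET().eval()
points = torch.randn(2, 3, 1024)      # (batch, xyz channels, num_points)
with torch.no_grad():
    logits = model(points)
print(logits.shape)                   # torch.Size([2, 40])
```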
diff --git a/__pycache__/DGCNN.cpython-38.pyc b/__pycache__/DGCNN.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2f4d9fa9eace01bd14971abaaf07f550f6a7c1e
Binary files /dev/null and b/__pycache__/DGCNN.cpython-38.pyc differ
diff --git a/__pycache__/GDANet_cls.cpython-38.pyc b/__pycache__/GDANet_cls.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3e238be932a056b22801eb2a80a18e67033b089
Binary files /dev/null and b/__pycache__/GDANet_cls.cpython-38.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..25f074f9ee9a3c5e7e593584eb448d4e7109be17
--- /dev/null
+++ b/app.py
@@ -0,0 +1,119 @@
+import gradio as gr
+import mathutils
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+import matplotlib.cm as cmx
+import os.path as osp
+import h5py
+import random
+import torch
+import torch.nn as nn
+
+from GDANet_cls import GDANET
+from DGCNN import DGCNN
+
+with open('shape_names.txt') as f:
+ CLASS_NAME = f.read().splitlines()
+
+model_gda = GDANET()
+model_gda = nn.DataParallel(model_gda)
+model_gda.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
+model_gda.eval()
+
+model_dgcnn = DGCNN()
+model_dgcnn = nn.DataParallel(model_dgcnn)
+model_dgcnn.load_state_dict(torch.load('./dgcnn.t7', map_location=torch.device('cpu')))
+model_dgcnn.eval()
+
+def pyplot_draw_point_cloud(points, corruption):
+ rot1 = mathutils.Euler([-math.pi / 2, 0, 0]).to_matrix().to_3x3()
+ rot2 = mathutils.Euler([0, 0, math.pi]).to_matrix().to_3x3()
+ points = np.dot(points, rot1)
+ points = np.dot(points, rot2)
+ x, y, z = points[:, 0], points[:, 1], points[:, 2]
+ colorsMap = 'winter'
+ cs = y
+ cm = plt.get_cmap(colorsMap)
+ cNorm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
+ scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cm)
+ fig = plt.figure(figsize=(5, 5))
+ ax = fig.add_subplot(111, projection='3d')
+ ax.scatter(x, y, z, c=scalarMap.to_rgba(cs))
+ scalarMap.set_array(cs)
+ ax.set_xlim(-1, 1)
+ ax.set_ylim(-1, 1)
+ ax.set_zlim(-1, 1)
+ plt.axis('off')
+ plt.title(corruption, fontsize=30)
+ plt.tight_layout()
+ plt.savefig('visualization.png', bbox_inches='tight', dpi=200)
+ plt.close()
+
+
+
+def load_dataset(corruption_idx, severity):
+ corruptions = [
+ 'clean',
+ 'scale',
+ 'jitter',
+ 'rotate',
+ 'dropout_global',
+ 'dropout_local',
+ 'add_global',
+ 'add_local',
+ ]
+ corruption_type = corruptions[corruption_idx]
+ if corruption_type == 'clean':
+        f = h5py.File(osp.join('modelnet_c', corruption_type + '.h5'), 'r')
+    else:
+        f = h5py.File(osp.join('modelnet_c', corruption_type + '_{}'.format(severity-1) + '.h5'), 'r')
+ data = f['data'][:].astype('float32')
+ label = f['label'][:].astype('int64')
+ f.close()
+ return data, label
+
+def recognize_pcd(model, pcd):
+ pcd = torch.tensor(pcd).unsqueeze(0)
+ pcd = pcd.permute(0, 2, 1)
+ output = model(pcd)
+ prediction = output.softmax(-1).flatten()
+ _, top5_idx = torch.topk(prediction, 5)
+ return {CLASS_NAME[i]: float(prediction[i]) for i in top5_idx.tolist()}
+
+def run(seed, corruption_idx, severity):
+ data, label = load_dataset(corruption_idx, severity)
+ sample_indx = int(seed)
+ pcd, cls = data[sample_indx], label[sample_indx]
+ pyplot_draw_point_cloud(pcd, CLASS_NAME[cls[0]])
+ output = 'visualization.png'
+ return output, recognize_pcd(model_dgcnn, pcd), recognize_pcd(model_gda, pcd)
+
+if __name__ == '__main__':
+ iface = gr.Interface(
+ fn=run,
+ inputs=[
+ gr.components.Number(label='Sample Seed', precision=0),
+ gr.components.Radio(
+ ['Clean', 'Scale', 'Jitter', 'Rotate', 'Drop Global', 'Drop Local', 'Add Global', 'Add Local'],
+ value='Clean', type="index", label='Corruption Type'),
+ gr.components.Slider(1, 5, step=1, label='Corruption severity'),
+ ],
+ outputs=[
+ gr.components.Image(type="file", label="Visualization"),
+ gr.components.Label(num_top_classes=5, label="Baseline (DGCNN) Prediction"),
+ gr.components.Label(num_top_classes=5, label="Ours (GDANet+WolfMix) Prediction")
+ ],
+ live=False,
+ allow_flagging='never',
+ title="Benchmarking and Analyzing Point Cloud Classification under Corruptions [ICML 2022]",
+        description="Welcome to the demo of ModelNet-C! You can visualize various types of corrupted point clouds in ModelNet-C and see how our proposed techniques contribute to robust predictions compared to baseline methods.",
+ examples=[
+ [0, 'Jitter', 5],
+ [999, 'Drop Local', 5],
+ ],
+ # css=".output-image, .image-preview {height: 500px !important}",
+        article="<a href='https://github.com/jiawei-ren/ModelNet-C' target='_blank'>ModelNet-C @ GitHub</a>"
+ )
+ iface.launch()
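The same inference path that `app.py` wires into Gradio can be exercised standalone; a hedged sketch, assuming `git lfs pull` has fetched `GDANet_WOLFMix.t7` and `modelnet_c/clean.h5`:

```python
# Hedged sketch: top-5 prediction on one clean sample, mirroring recognize_pcd above.
import h5py
import torch
import torch.nn as nn
from GDANet_cls import GDANET

with open('shape_names.txt') as f:
    class_names = f.read().splitlines()

model = nn.DataParallel(GDANET())
model.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
model.eval()

with h5py.File('modelnet_c/clean.h5', 'r') as f:
    pcd = torch.tensor(f['data'][0:1].astype('float32')).permute(0, 2, 1)   # (1, 3, N)

with torch.no_grad():
    probs = model(pcd).softmax(-1).flatten()
values, indices = torch.topk(probs, 5)
print({class_names[i]: round(v, 3) for v, i in zip(values.tolist(), indices.tolist())})
```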
diff --git a/dgcnn.t7 b/dgcnn.t7
new file mode 100644
index 0000000000000000000000000000000000000000..a6f538c13bda5097e2447776cec8e3e6a7e5575d
--- /dev/null
+++ b/dgcnn.t7
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f366f60ca9dacf42cff5b747ba86020f10a6480ab31bb8122a8a609152ce4baa
+size 7268024
diff --git a/downsample.py b/downsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..7163168dba82fb7cc448d642f5d55339d869bc05
--- /dev/null
+++ b/downsample.py
@@ -0,0 +1,12 @@
+import glob
+import h5py
+
+for fpath in glob.glob('modelnet_c/*.h5'):
+    f = h5py.File(fpath, 'r')
+ data = f['data'][:].astype('float32')
+ label = f['label'][:].astype('int64')
+ f.close()
+ f = h5py.File(fpath, 'w')
+ f.create_dataset('data', data=data[:100])
+ f.create_dataset('label', data=label[:100])
+ f.close()
\ No newline at end of file
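`downsample.py` rewrites every `modelnet_c/*.h5` split in place, keeping only the first 100 samples so the demo stays lightweight; the `Sample Seed` index in `app.py` must then stay below the number of retained samples. A hedged check after running it:

```python
# Hedged check: after downsample.py, each demo split should hold 100 samples.
import glob
import h5py

for fpath in sorted(glob.glob('modelnet_c/*.h5')):
    with h5py.File(fpath, 'r') as f:
        print(fpath, f['data'].shape, f['label'].shape)   # e.g. (100, N, 3), (100, 1)
```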
diff --git a/modelnet_c/.DS_Store b/modelnet_c/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/modelnet_c/.DS_Store differ
diff --git a/modelnet_c/add_global_0.h5 b/modelnet_c/add_global_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..09a5d6fa8ba7493fc47698d6bd733ec49c19438b
--- /dev/null
+++ b/modelnet_c/add_global_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0588ab009c0d0b4f4f8598e2fd7a6df0df14937b30a42ecdfc6c2811e70d494e
+size 61267680
diff --git a/modelnet_c/add_global_1.h5 b/modelnet_c/add_global_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..c7db180b541c79f555974def9e1029004bd19911
--- /dev/null
+++ b/modelnet_c/add_global_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e18f34cf1e30b58758edfbe20f4e0db0a53e067bc87520e20ee8eb9fc85035d
+size 61860000
diff --git a/modelnet_c/add_global_2.h5 b/modelnet_c/add_global_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..da1e67e50ce88fddecf861916a7c903b7dc4f1bd
--- /dev/null
+++ b/modelnet_c/add_global_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c01495ed18ede2c5d9bab1cfc29555b9153e862c827909295504f76909addef
+size 62452320
diff --git a/modelnet_c/add_global_3.h5 b/modelnet_c/add_global_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..ef6d6e6ac8c7b5044425f9d19b10d453192c410f
--- /dev/null
+++ b/modelnet_c/add_global_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8d5e4594ca8aeb84e10397e0189b45a594f47009a4669d4efe070ffe72d4c4c
+size 63044640
diff --git a/modelnet_c/add_global_4.h5 b/modelnet_c/add_global_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..e9df45a414ff82c09a19f6ba2791c2e39e1028ed
--- /dev/null
+++ b/modelnet_c/add_global_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:989951ee7a7056f70c03f671a61a7ac55ca3d2b78eeccfeb855fcc033b276797
+size 63636960
diff --git a/modelnet_c/add_local_0.h5 b/modelnet_c/add_local_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..2441b2e59d7ee08b27bafeb9d525f862901b2535
--- /dev/null
+++ b/modelnet_c/add_local_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd0e3adfcc8c98136b33f0a712033340e594cfcd7dafa3f76e6a8f517a59a429
+size 33310176
diff --git a/modelnet_c/add_local_1.h5 b/modelnet_c/add_local_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..15591771aef11fe056a73ea39bee52e551725c0e
--- /dev/null
+++ b/modelnet_c/add_local_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43986bd483fdcc5f4a8779c31f3b45d45401b82d18443161872d0862106a0fa6
+size 36271776
diff --git a/modelnet_c/add_local_2.h5 b/modelnet_c/add_local_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..fad1b487dbb35a029a55cf8354070f263eba3680
--- /dev/null
+++ b/modelnet_c/add_local_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d69968a6f58f25687058eaaca1a67a6890396cac812529e9529b978054e94c36
+size 39233376
diff --git a/modelnet_c/add_local_3.h5 b/modelnet_c/add_local_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..00c7616fd4d2a1b2491e3427ef28d975e00724e3
--- /dev/null
+++ b/modelnet_c/add_local_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:029c794a19cffe80797950d2aac661fe2cc35b978d486e3684024a0bcf42c10c
+size 42194976
diff --git a/modelnet_c/add_local_4.h5 b/modelnet_c/add_local_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..83e677122926b6f8127e2c9d70ed91c22d444df8
--- /dev/null
+++ b/modelnet_c/add_local_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd8597a0a95f358cacb9f5ad8c1d76b74179988b4037b277dff7c58cd95c8146
+size 45156576
diff --git a/modelnet_c/clean.h5 b/modelnet_c/clean.h5
new file mode 100644
index 0000000000000000000000000000000000000000..21a97167e3c9140d8c9350b999a22d499d67e094
--- /dev/null
+++ b/modelnet_c/clean.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:203ec0037ebdee84f13df5abe12e3cf0e1047192832ff6149ac54b4deaf37931
+size 30348576
diff --git a/modelnet_c/dropout_global_0.h5 b/modelnet_c/dropout_global_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..131d5a22a10d278954cb34972a6b1f2d26d23d07
--- /dev/null
+++ b/modelnet_c/dropout_global_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8b7527cc2512fd83e074b3f814cb88253781b2331d9f3c79c54acc3a0f6e64e
+size 22766880
diff --git a/modelnet_c/dropout_global_1.h5 b/modelnet_c/dropout_global_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..c5e65246cdaad76292c49513523284b981a042e7
--- /dev/null
+++ b/modelnet_c/dropout_global_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00c769d0f5f3bd7f9a4bf261d23b567fe75861b54c83671f5c46fe7ea288c79c
+size 18976032
diff --git a/modelnet_c/dropout_global_2.h5 b/modelnet_c/dropout_global_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..058426f99404ecb1e7cf5d4f77fb369679428a0b
--- /dev/null
+++ b/modelnet_c/dropout_global_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff6fe9780aee5fb4995e81313a3341df7b6c23d710e2bb0cd1bc46c9cb6e6815
+size 15185184
diff --git a/modelnet_c/dropout_global_3.h5 b/modelnet_c/dropout_global_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..56860aefd48be40805639044c23b28acba8a951e
--- /dev/null
+++ b/modelnet_c/dropout_global_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8d9c46fe1d02d779895135d42861c34988355893c4d81ba54586ebb7571b4cd
+size 11394336
diff --git a/modelnet_c/dropout_global_4.h5 b/modelnet_c/dropout_global_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..69fddc494b2fbde471b840fa02aed2d6bc91e6b5
--- /dev/null
+++ b/modelnet_c/dropout_global_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0342a91f9e16e29751e90296171c593e1eb35da0ffe3579a6664bb18c2f28d9b
+size 7603488
diff --git a/modelnet_c/dropout_local_0.h5 b/modelnet_c/dropout_local_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..67a06482a4824ce6fb7035ad53c8e5d53d4603d2
--- /dev/null
+++ b/modelnet_c/dropout_local_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8751d701b16e7f37065a1322c4008bff807db5b2f2ab5ce9254fa941a519742
+size 27386976
diff --git a/modelnet_c/dropout_local_1.h5 b/modelnet_c/dropout_local_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..7a323e142dbfc4e6a2c5ef3dbb2d6164e922ee8a
--- /dev/null
+++ b/modelnet_c/dropout_local_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2441fae91bc29fba8953799e11f604ec5beeb20913380ca225e1b9a8fe3b500
+size 24425376
diff --git a/modelnet_c/dropout_local_2.h5 b/modelnet_c/dropout_local_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..cad09b8bbc3218539fb3087bc1640577aa2c0bf8
--- /dev/null
+++ b/modelnet_c/dropout_local_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fa9a2ba4146bc522d7731a495134babd19d8b212105530abbcd40256c7eaaed
+size 21463776
diff --git a/modelnet_c/dropout_local_3.h5 b/modelnet_c/dropout_local_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..b461ff53b382096542f1a77688bb13bdfb98890d
--- /dev/null
+++ b/modelnet_c/dropout_local_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afc6742314f168e7a53209e81ce4298485fb665578953e9f231eb80fc7dd9940
+size 18502176
diff --git a/modelnet_c/dropout_local_4.h5 b/modelnet_c/dropout_local_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..d919fe155fc9e9fe333a1f7e761a59303671fedd
--- /dev/null
+++ b/modelnet_c/dropout_local_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fc12bb058cb674274a0e01e79b3810cf693e85aaa0d8a21832a3f3108d8c69f
+size 15540576
diff --git a/modelnet_c/jitter_0.h5 b/modelnet_c/jitter_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..9765adefad5921e45a531e5ddb7603e57b99c3a0
--- /dev/null
+++ b/modelnet_c/jitter_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36b0e5c417d801b2b12d9a97cb4a998eb3d639d3a9bc7a714ee70dbda25d2284
+size 60675360
diff --git a/modelnet_c/jitter_1.h5 b/modelnet_c/jitter_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..8b0af7b5a2e232436a6126bfdca58d083b9f925f
--- /dev/null
+++ b/modelnet_c/jitter_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92b770e594a0c72fbd4b7394e02d2a6bf0490c44acb7474fc2e3837de0827c5d
+size 60675360
diff --git a/modelnet_c/jitter_2.h5 b/modelnet_c/jitter_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..c88e908606d68e9750ad3adddaad9d891f0db395
--- /dev/null
+++ b/modelnet_c/jitter_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d95bbc72d191f632a368873552bcaa15068082f424f184485b997d4e1ba6f05
+size 60675360
diff --git a/modelnet_c/jitter_3.h5 b/modelnet_c/jitter_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..dce56a9390536b9d0784eb9287041b9da2a60636
--- /dev/null
+++ b/modelnet_c/jitter_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed401170657a4326638787c71004960db5cc5bb03de56df64ef335d73c975032
+size 60675360
diff --git a/modelnet_c/jitter_4.h5 b/modelnet_c/jitter_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..4966a1a95d3cf0fae660a8313828efaba51b21b1
--- /dev/null
+++ b/modelnet_c/jitter_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4547327b47fa368814145f8397bf87061a86281afbf253f85c0a6d1abbd8251b
+size 60675360
diff --git a/modelnet_c/rotate_0.h5 b/modelnet_c/rotate_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..b351a0acabd7cd6a349a40f9ebdb8fe9a6abfeb4
--- /dev/null
+++ b/modelnet_c/rotate_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:95246d8c9506578039c4ab8f0cb1b8794d3d32c533e906aea724a04de8c3a1c7
+size 60675360
diff --git a/modelnet_c/rotate_1.h5 b/modelnet_c/rotate_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..f88f874d039863b0c1d30bfe9103f5bf9e56481c
--- /dev/null
+++ b/modelnet_c/rotate_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70147de9808e5ed0cdcb0184a1ade44241497c3d69d74576e5dc1b333cfc87aa
+size 60675360
diff --git a/modelnet_c/rotate_2.h5 b/modelnet_c/rotate_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..be4454532c005be360aa3c382ba839ae8ace9d2a
--- /dev/null
+++ b/modelnet_c/rotate_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6ae9db2079eb728eb2f8cdc7dfdcb7ff5456575a986c7e10777cbcd29fb4fbb
+size 60675360
diff --git a/modelnet_c/rotate_3.h5 b/modelnet_c/rotate_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..92f8a6c291b322f01b4c8dca8826f7fdd2dacca1
--- /dev/null
+++ b/modelnet_c/rotate_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c13ab0be340e82c63ef325619759c71aae63a622a1846c95f0697152c3e7b1c
+size 60675360
diff --git a/modelnet_c/rotate_4.h5 b/modelnet_c/rotate_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..5951cb11081e6e827ed0ca083ddf1ebf424211de
--- /dev/null
+++ b/modelnet_c/rotate_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1d70d56407cb9a02308a34806bb908a1bba8db2779881a7a1a49ba7005119fb
+size 60675360
diff --git a/modelnet_c/scale_0.h5 b/modelnet_c/scale_0.h5
new file mode 100644
index 0000000000000000000000000000000000000000..dcb95d5fc16fc2d6f8bb874074c0bdde87d3f8a2
--- /dev/null
+++ b/modelnet_c/scale_0.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4bb6638a9330960261aadda500d34060ad00143322cc2a40f3b5dea55ad7a40
+size 30348576
diff --git a/modelnet_c/scale_1.h5 b/modelnet_c/scale_1.h5
new file mode 100644
index 0000000000000000000000000000000000000000..f870c1669c0228bc0b7936f5ca26a5f904a86d73
--- /dev/null
+++ b/modelnet_c/scale_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5fa9823bcba93775fc681daad8625e2478c444a322ae689bb9d071e577e4b1d
+size 30348576
diff --git a/modelnet_c/scale_2.h5 b/modelnet_c/scale_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..a4daee3e25f6533a1773fd89acbf8127b7e63956
--- /dev/null
+++ b/modelnet_c/scale_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:127371bd0f05b89592886b7ce5defb357a27aafd630271ff8e456e63fd8890cb
+size 30348576
diff --git a/modelnet_c/scale_3.h5 b/modelnet_c/scale_3.h5
new file mode 100644
index 0000000000000000000000000000000000000000..40f216e0385b8f081f64992ee96ffbdaa1ec888c
--- /dev/null
+++ b/modelnet_c/scale_3.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dc4b8ca33d05531aa97b9fe20cf1139716e7960de6415a202f7feea08bff2b6
+size 30348576
diff --git a/modelnet_c/scale_4.h5 b/modelnet_c/scale_4.h5
new file mode 100644
index 0000000000000000000000000000000000000000..2dadee178d99d735ae9486bc3f26a90894ffd24c
--- /dev/null
+++ b/modelnet_c/scale_4.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca5396a60c178e54f13794780ba24609eba1eb5cdcf80bee3243a00c785d1bfd
+size 30348576
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9991a4b854534bc0087f99e683f255078ed2d436
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+mathutils
+matplotlib
+h5py
+torch
\ No newline at end of file
diff --git a/shape_names.txt b/shape_names.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1b2a39713de9f339485bda00ba7977aa55a640c7
--- /dev/null
+++ b/shape_names.txt
@@ -0,0 +1,40 @@
+airplane
+bathtub
+bed
+bench
+bookshelf
+bottle
+bowl
+car
+chair
+cone
+cup
+curtain
+desk
+door
+dresser
+flower_pot
+glass_box
+guitar
+keyboard
+lamp
+laptop
+mantel
+monitor
+night_stand
+person
+piano
+plant
+radio
+range_hood
+sink
+sofa
+stairs
+stool
+table
+tent
+toilet
+tv_stand
+vase
+wardrobe
+xbox
diff --git a/util/GDANet_util.py b/util/GDANet_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..74c7b131e211e40fcc3d4901c0607f54b04aa2dc
--- /dev/null
+++ b/util/GDANet_util.py
@@ -0,0 +1,211 @@
+import torch
+from torch import nn
+
+
+def knn(x, k):
+ inner = -2*torch.matmul(x.transpose(2, 1), x)
+ xx = torch.sum(x**2, dim=1, keepdim=True)
+ pairwise_distance = -xx - inner - xx.transpose(2, 1)
+
+ idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
+ return idx, pairwise_distance
+
+
+def local_operator(x, k):
+ batch_size = x.size(0)
+ num_points = x.size(2)
+ x = x.view(batch_size, -1, num_points)
+ idx, _ = knn(x, k=k)
+ device = torch.device('cpu')
+ idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
+
+ idx = idx + idx_base
+
+ idx = idx.view(-1)
+
+ _, num_dims, _ = x.size()
+
+ x = x.transpose(2, 1).contiguous()
+
+ neighbor = x.view(batch_size * num_points, -1)[idx, :]
+
+ neighbor = neighbor.view(batch_size, num_points, k, num_dims)
+
+ x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+
+    feature = torch.cat((neighbor-x, neighbor), dim=3).permute(0, 3, 1, 2)  # concatenate relative (neighbor - x) and absolute (neighbor) features
+
+ return feature
+
+
+def local_operator_withnorm(x, norm_plt, k):
+ batch_size = x.size(0)
+ num_points = x.size(2)
+ x = x.view(batch_size, -1, num_points)
+ norm_plt = norm_plt.view(batch_size, -1, num_points)
+ idx, _ = knn(x, k=k) # (batch_size, num_points, k)
+ device = torch.device('cpu')
+
+ idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
+
+ idx = idx + idx_base
+
+ idx = idx.view(-1)
+
+ _, num_dims, _ = x.size()
+
+ x = x.transpose(2, 1).contiguous()
+ norm_plt = norm_plt.transpose(2, 1).contiguous()
+
+ neighbor = x.view(batch_size * num_points, -1)[idx, :]
+ neighbor_norm = norm_plt.view(batch_size * num_points, -1)[idx, :]
+
+ neighbor = neighbor.view(batch_size, num_points, k, num_dims)
+ neighbor_norm = neighbor_norm.view(batch_size, num_points, k, num_dims)
+
+ x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+
+    feature = torch.cat((neighbor-x, neighbor, neighbor_norm), dim=3).permute(0, 3, 1, 2)  # 3*C channels: relative offsets, neighbor coordinates, neighbor normals
+
+ return feature
+
+
+def GDM(x, M):
+ """
+ Geometry-Disentangle Module
+ M: number of disentangled points in both sharp and gentle variation components
+ """
+ k = 64 # number of neighbors to decide the range of j in Eq.(5)
+ tau = 0.2 # threshold in Eq.(2)
+ sigma = 2 # parameters of f (Gaussian function in Eq.(2))
+ ###############
+ """Graph Construction:"""
+ device = torch.device('cpu')
+ batch_size = x.size(0)
+ num_points = x.size(2)
+ x = x.view(batch_size, -1, num_points)
+
+ idx, p = knn(x, k=k) # p: -[(x1-x2)^2+...]
+
+ # here we add a tau
+ p1 = torch.abs(p)
+ p1 = torch.sqrt(p1)
+ mask = p1 < tau
+
+ # here we add a sigma
+ p = p / (sigma * sigma)
+ w = torch.exp(p) # b,n,n
+ w = torch.mul(mask.float(), w)
+
+ b = 1/torch.sum(w, dim=1)
+ b = b.reshape(batch_size, num_points, 1).repeat(1, 1, num_points)
+ c = torch.eye(num_points, num_points, device=device)
+ c = c.expand(batch_size, num_points, num_points)
+ D = b * c # b,n,n
+
+ A = torch.matmul(D, w) # normalized adjacency matrix A_hat
+
+ # Get Aij in a local area:
+ idx2 = idx.view(batch_size * num_points, -1)
+ idx_base2 = torch.arange(0, batch_size * num_points, device=device).view(-1, 1) * num_points
+ idx2 = idx2 + idx_base2
+
+ idx2 = idx2.reshape(batch_size * num_points, k)[:, 1:k]
+ idx2 = idx2.reshape(batch_size * num_points * (k - 1))
+ idx2 = idx2.view(-1)
+
+ A = A.view(-1)
+ A = A[idx2].reshape(batch_size, num_points, k - 1) # Aij: b,n,k
+ ###############
+ """Disentangling Point Clouds into Sharp(xs) and Gentle(xg) Variation Components:"""
+ idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
+ idx = idx + idx_base
+ idx = idx.reshape(batch_size * num_points, k)[:, 1:k]
+ idx = idx.reshape(batch_size * num_points * (k - 1))
+
+ _, num_dims, _ = x.size()
+
+ x = x.transpose(2, 1).contiguous() # b,n,c
+ neighbor = x.view(batch_size * num_points, -1)[idx, :]
+ neighbor = neighbor.view(batch_size, num_points, k - 1, num_dims) # b,n,k,c
+ A = A.reshape(batch_size, num_points, k - 1, 1) # b,n,k,1
+ n = A.mul(neighbor) # b,n,k,c
+ n = torch.sum(n, dim=2) # b,n,c
+
+ pai = torch.norm(x - n, dim=-1).pow(2) # Eq.(5)
+ pais = pai.topk(k=M, dim=-1)[1] # first M points as the sharp variation component
+ paig = (-pai).topk(k=M, dim=-1)[1] # last M points as the gentle variation component
+
+ pai_base = torch.arange(0, batch_size, device=device).view(-1, 1) * num_points
+ indices = (pais + pai_base).view(-1)
+ indiceg = (paig + pai_base).view(-1)
+
+ xs = x.view(batch_size * num_points, -1)[indices, :]
+ xg = x.view(batch_size * num_points, -1)[indiceg, :]
+
+ xs = xs.view(batch_size, M, -1) # b,M,c
+ xg = xg.view(batch_size, M, -1) # b,M,c
+
+ return xs, xg
+
+
+class SGCAM(nn.Module):
+ """Sharp-Gentle Complementary Attention Module:"""
+ def __init__(self, in_channels, inter_channels=None, bn_layer=True):
+ super(SGCAM, self).__init__()
+
+ self.in_channels = in_channels
+ self.inter_channels = inter_channels
+
+ if self.inter_channels is None:
+ self.inter_channels = in_channels // 2
+ if self.inter_channels == 0:
+ self.inter_channels = 1
+
+ conv_nd = nn.Conv1d
+ bn = nn.BatchNorm1d
+
+ self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+ kernel_size=1, stride=1, padding=0)
+
+ if bn_layer:
+ self.W = nn.Sequential(
+ conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
+ kernel_size=1, stride=1, padding=0),
+ bn(self.in_channels)
+ )
+            nn.init.constant_(self.W[1].weight, 0)
+            nn.init.constant_(self.W[1].bias, 0)
+ else:
+ self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
+ kernel_size=1, stride=1, padding=0)
+            nn.init.constant_(self.W.weight, 0)
+            nn.init.constant_(self.W.bias, 0)
+
+ self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+ kernel_size=1, stride=1, padding=0)
+
+ self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
+ kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, x_2):
+ batch_size = x.size(0)
+
+ g_x = self.g(x_2).view(batch_size, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x_2).view(batch_size, self.inter_channels, -1)
+ W = torch.matmul(theta_x, phi_x) # Attention Matrix
+ N = W.size(-1)
+ W_div_C = W / N
+
+ y = torch.matmul(W_div_C, g_x)
+ y = y.permute(0, 2, 1).contiguous()
+ y = y.view(batch_size, self.inter_channels, *x.size()[2:])
+ W_y = self.W(y)
+ y = W_y + x
+
+ return y
+
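A hedged shape walk-through for the utilities above (toy sizes, CPU, untrained weights), assuming the repo root is on the import path:

```python
# Hedged sketch: shapes produced by local_operator, GDM and SGCAM (toy sizes).
import torch
from util.GDANet_util import local_operator, GDM, SGCAM

x = torch.randn(2, 3, 512)            # raw point cloud: (batch, channels, points)
feat = local_operator(x, k=30)        # (2, 6, 512, 30): [neighbor - x, neighbor]

h = torch.randn(2, 64, 512)           # a per-point feature map
xs, xg = GDM(h, M=256)                # sharp / gentle components, each (2, 256, 64)

att = SGCAM(64).eval()
y = att(h, xs.transpose(2, 1))        # fuse one component back into h: (2, 64, 512)
print(feat.shape, xs.shape, xg.shape, y.shape)
```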
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/util/__pycache__/GDANet_util.cpython-38.pyc b/util/__pycache__/GDANet_util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df02a6a97b3ad2bdbebcec46accd24605161ae20
Binary files /dev/null and b/util/__pycache__/GDANet_util.cpython-38.pyc differ
diff --git a/util/__pycache__/__init__.cpython-38.pyc b/util/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cfb1078b356b40ee96e649f2f39f60d6b65c645
Binary files /dev/null and b/util/__pycache__/__init__.cpython-38.pyc differ
diff --git a/util/data_util.py b/util/data_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..24734cfc8419ed3f0db313d910854e259f09b1d0
--- /dev/null
+++ b/util/data_util.py
@@ -0,0 +1,165 @@
+import glob
+import h5py
+import numpy as np
+from torch.utils.data import Dataset
+import os
+import json
+from PointWOLF import PointWOLF
+
+
+def load_data(partition):
+ all_data = []
+ all_label = []
+ for h5_name in glob.glob('./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5' % partition):
+        f = h5py.File(h5_name, 'r')
+ data = f['data'][:].astype('float32')
+ label = f['label'][:].astype('int64')
+ f.close()
+ all_data.append(data)
+ all_label.append(label)
+ all_data = np.concatenate(all_data, axis=0)
+ all_label = np.concatenate(all_label, axis=0)
+ return all_data, all_label
+
+
+def pc_normalize(pc):
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
+ pc = pc / m
+ return pc
+
+
+def translate_pointcloud(pointcloud):
+ xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
+ xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
+
+ translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
+ return translated_pointcloud
+
+
+def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
+ N, C = pointcloud.shape
+ pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip)
+ return pointcloud
+
+
+# =========== ModelNet40 =================
+class ModelNet40(Dataset):
+ def __init__(self, num_points, partition='train', args=None):
+ self.data, self.label = load_data(partition)
+ self.num_points = num_points
+ self.partition = partition
+ self.PointWOLF = PointWOLF(args) if args is not None else None
+
+
+ def __getitem__(self, item):
+ pointcloud = self.data[item][:self.num_points]
+ label = self.label[item]
+ if self.partition == 'train':
+ np.random.shuffle(pointcloud)
+ if self.PointWOLF is not None:
+ _, pointcloud = self.PointWOLF(pointcloud)
+ return pointcloud, label
+
+ def __len__(self):
+ return self.data.shape[0]
+
+# =========== ShapeNet Part =================
+class PartNormalDataset(Dataset):
+ def __init__(self, npoints=2500, split='train', normalize=False):
+ self.npoints = npoints
+ self.root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal'
+ self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
+ self.cat = {}
+ self.normalize = normalize
+
+ with open(self.catfile, 'r') as f:
+ for line in f:
+ ls = line.strip().split()
+ self.cat[ls[0]] = ls[1]
+ self.cat = {k: v for k, v in self.cat.items()}
+
+ self.meta = {}
+ with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
+ train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+ with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
+ val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+ with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
+ test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+ for item in self.cat:
+ self.meta[item] = []
+ dir_point = os.path.join(self.root, self.cat[item])
+ fns = sorted(os.listdir(dir_point))
+
+ if split == 'trainval':
+ fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
+ elif split == 'train':
+ fns = [fn for fn in fns if fn[0:-4] in train_ids]
+ elif split == 'val':
+ fns = [fn for fn in fns if fn[0:-4] in val_ids]
+ elif split == 'test':
+ fns = [fn for fn in fns if fn[0:-4] in test_ids]
+ else:
+ print('Unknown split: %s. Exiting..' % (split))
+ exit(-1)
+
+ for fn in fns:
+ token = (os.path.splitext(os.path.basename(fn))[0])
+ self.meta[item].append(os.path.join(dir_point, token + '.txt'))
+
+ self.datapath = []
+ for item in self.cat:
+ for fn in self.meta[item]:
+ self.datapath.append((item, fn))
+
+ self.classes = dict(zip(self.cat, range(len(self.cat))))
+ # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
+ self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
+ 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
+ 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
+ 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
+ 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
+
+ self.cache = {} # from index to (point_set, cls, seg) tuple
+ self.cache_size = 20000
+
+ def __getitem__(self, index):
+ if index in self.cache:
+ point_set, normal, seg, cls = self.cache[index]
+ else:
+ fn = self.datapath[index]
+ cat = self.datapath[index][0]
+ cls = self.classes[cat]
+ cls = np.array([cls]).astype(np.int32)
+ data = np.loadtxt(fn[1]).astype(np.float32)
+ point_set = data[:, 0:3]
+ normal = data[:, 3:6]
+ seg = data[:, -1].astype(np.int32)
+ if len(self.cache) < self.cache_size:
+ self.cache[index] = (point_set, normal, seg, cls)
+
+ if self.normalize:
+ point_set = pc_normalize(point_set)
+
+ choice = np.random.choice(len(seg), self.npoints, replace=True)
+
+ # resample
+        # note: some point clouds have fewer than 2048 points, so resample with np.random.choice (replace=True)
+        # use the same random seed during train and test to get stable results
+ point_set = point_set[choice, :]
+ seg = seg[choice]
+ normal = normal[choice, :]
+
+ return point_set, cls, seg, normal
+
+ def __len__(self):
+ return len(self.datapath)
+
+
+if __name__ == '__main__':
+ train = ModelNet40(1024)
+ test = ModelNet40(1024, 'test')
+ for data, label in train:
+ print(data.shape)
+ print(label.shape)
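`util/data_util.py` appears to be training-side code carried over from the GDANet / PointWOLF pipelines: it expects a `PointWOLF` module and the ModelNet40 / ShapeNetPart downloads under `./data`, none of which ship with this diff, and it is not imported by `app.py`. A self-contained illustration of the two lightweight augmentations, with the logic mirrored from `translate_pointcloud` and `jitter_pointcloud` rather than imported (importing the module would also require the absent `PointWOLF`):

```python
# Self-contained sketch mirroring translate_pointcloud / jitter_pointcloud above.
import numpy as np

pc = np.random.randn(1024, 3).astype('float32')

# translate_pointcloud: random per-axis scale in [2/3, 3/2] plus a shift in [-0.2, 0.2]
scale = np.random.uniform(2. / 3., 3. / 2., size=[3])
shift = np.random.uniform(-0.2, 0.2, size=[3])
translated = np.add(np.multiply(pc, scale), shift).astype('float32')

# jitter_pointcloud: per-coordinate Gaussian noise clipped to [-0.02, 0.02]
jittered = pc + np.clip(0.01 * np.random.randn(*pc.shape), -0.02, 0.02)

print(translated.shape, jittered.shape)   # (1024, 3) (1024, 3)
```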
diff --git a/util/util.py b/util/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..00afdd8435f42ad52480945285459c5cd3576d73
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,69 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def cal_loss(pred, gold, smoothing=True):
+ ''' Calculate cross entropy loss, apply label smoothing if needed. '''
+
+    gold = gold.contiguous().view(-1)  # gold is the ground-truth label from the dataloader
+
+ if smoothing:
+ eps = 0.2
+        n_class = pred.size(1)  # number of output channels, i.e. the number of classes
+
+ one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(pred, dim=1)
+
+ loss = -(one_hot * log_prb).sum(dim=1).mean()
+ else:
+ loss = F.cross_entropy(pred, gold, reduction='mean')
+
+ return loss
+
+
+# create a file and write the text into it:
+class IOStream():
+ def __init__(self, path):
+ self.f = open(path, 'a')
+
+ def cprint(self, text):
+ print(text)
+ self.f.write(text+'\n')
+ self.f.flush()
+
+ def close(self):
+ self.f.close()
+
+
+def to_categorical(y, num_classes):
+ """ 1-hot encodes a tensor """
+ new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
+ if (y.is_cuda):
+ return new_y.cuda(non_blocking=True)
+ return new_y
+
+
+def compute_overall_iou(pred, target, num_classes):
+ shape_ious = []
+ pred = pred.max(dim=2)[1] # (batch_size, num_points) the pred_class_idx of each point in each sample
+ pred_np = pred.cpu().data.numpy()
+
+ target_np = target.cpu().data.numpy()
+ for shape_idx in range(pred.size(0)): # sample_idx
+ part_ious = []
+        for part in range(num_classes):  # iterate over all 50 part classes, regardless of object category
+            # every target point carries a part label from the same global 50-class set
+            # intersection: points predicted as this part AND labeled as this part
+            I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
+            # union: points predicted as this part OR labeled as this part
+            U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
+
+ F = np.sum(target_np[shape_idx] == part)
+
+ if F != 0:
+ iou = I / float(U) # iou across all points for this class
+ part_ious.append(iou) # append the iou of this class
+        shape_ious.append(np.mean(part_ious))  # per-sample IoU: average over the part classes present in this sample
+ return shape_ious # [batch_size]
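A hedged sanity check for `cal_loss` above, comparing the smoothed objective with plain cross-entropy on random logits:

```python
# Hedged sketch: label-smoothed vs. plain cross-entropy from cal_loss (random data).
import torch
from util.util import cal_loss

logits = torch.randn(8, 40)           # (batch, num_classes)
labels = torch.randint(0, 40, (8,))
print('smoothed:', cal_loss(logits, labels, smoothing=True).item())
print('plain CE:', cal_loss(logits, labels, smoothing=False).item())
```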