jihyeonj commited on
Commit
99ebece
·
verified ·
1 Parent(s): e214652

gradio updated

Browse files
Files changed (2) hide show
  1. app.py +258 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ import matplotlib.pyplot as plt
5
+ from functools import partial
6
+ import torch
7
+ from torch_geometric.nn import knn
8
+ import gradio as gr
9
+ import os
10
+ import glob
11
+ import pickle
12
+ import time
13
+ import cv2
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ import plotly.graph_objects as go
17
+ import open3d as o3d
18
+ import trimesh
19
+ import gradio
20
+
21
+
22
+
23
+ def createPlane(normal, point_on_plane):
24
+ normal = normal / np.linalg.norm(normal)
25
+
26
+ # Find a vector in the plane
27
+ if np.allclose(normal, [1, 0, 0]):
28
+ v1 = np.cross(normal, [0, 1, 0])
29
+ else:
30
+ v1 = np.cross(normal, [1, 0, 0])
31
+
32
+ v1 = v1 / np.linalg.norm(v1)
33
+ v2 = np.cross(normal, v1)
34
+ v2 = v2 / np.linalg.norm(v2)
35
+
36
+ half_width = 1
37
+ half_height = 1
38
+
39
+ # Calculate the corners
40
+ corner1 = point_on_plane + half_width * v1 + half_height * v2
41
+ corner2 = point_on_plane - half_width * v1 + half_height * v2
42
+ corner3 = point_on_plane - half_width * v1 - half_height * v2
43
+ corner4 = point_on_plane + half_width * v1 - half_height * v2
44
+
45
+ vertices = np.array([corner1, corner2, corner3, corner4])
46
+
47
+ faces = np.array([
48
+ [0, 1, 2],
49
+ [0, 2, 3],
50
+ [2, 1, 0],
51
+ [3, 2, 0]
52
+ ])
53
+ # Define the color (sky blue) with opacity (alpha)
54
+ # Define the color (sky blue) with transparency
55
+ sky_blue_with_alpha = [255, 255, 255, 128] # RGBA format, with 128 alpha for half opacity
56
+
57
+ # Set the vertex colors with transparency
58
+ vertex_colors = np.tile(sky_blue_with_alpha, (vertices.shape[0], 1))
59
+
60
+ # Create a mesh for the rectangle
61
+ plane_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=vertex_colors)
62
+ return plane_mesh
63
+
64
+ def reflect_points_multiple_3d(d, n, p):
65
+ points_expanded = p.unsqueeze(0).expand(n.size(0), -1, -1)
66
+ normals_expanded = n.unsqueeze(1).expand(-1, p.size(0), -1)
67
+ distances_expanded = d.unsqueeze(1).expand(-1, p.size(0))
68
+ dot_products = torch.sum(points_expanded * normals_expanded, dim=2)
69
+ # reflections: (m, n, 3)
70
+ reflections = points_expanded - 2 * (dot_products - distances_expanded).unsqueeze(2) * normals_expanded
71
+
72
+ return reflections
73
+
74
+ def reflection_point_association_3d(d, n, q, threshold):
75
+ # d in shape (m, ), n in shape (m, 3), q in shape (n, 3)
76
+ reflections = reflect_points_multiple_3d(d, n, q) # shape: (m, n, 3)
77
+ reflections = reflections.view(-1, 3) # Flatten to (m*n, 3)
78
+
79
+ # Using knn to find the closest points in q for each point in reflections
80
+ # knn finds indices of the nearest neighbors
81
+ _, indices = knn(q, reflections, k=1, batch_x=None, batch_y=None) # indices shape: (m*n, 1)
82
+
83
+ # Gather nearest points based on indices from q
84
+ nearest_points = q[indices.squeeze()] # shape: (m*n, 3)
85
+
86
+ # Calculate distances for the nearest neighbors
87
+ distances = (nearest_points - reflections).norm(dim=1) # shape: (m*n,)
88
+
89
+ # Reshape distances back to (m, n) and check threshold
90
+ distances = distances.view(d.size(0), -1) # shape: (m, n)
91
+ within_threshold = distances <= threshold # shape: (m, n), sum this along axis 1 to get the number of associated points
92
+ return within_threshold
93
+
94
+ def get_patches(points, centroids):
95
+ norm = np.linalg.norm(centroids, axis=1)
96
+ n = centroids / norm[:, None]
97
+ d = norm - 1
98
+
99
+ association = reflection_point_association_3d(torch.tensor(np.array(d)), torch.tensor(np.array(n)),
100
+ torch.tensor(np.array(points)), 0.03)
101
+
102
+ return np.array(association)
103
+
104
+ def left_right(allpoints, patchbool, planepoints):
105
+ """
106
+ inputs: patchpoints: (n,3)
107
+ planepoints: (4,3)
108
+ outputs: leftpoints, rightpoints: (k,3), (k',3)
109
+ """
110
+ def signed_distance(point, plane_point, normal):
111
+ return np.dot(normal, point - plane_point) / (np.linalg.norm(normal)+1e-6)
112
+ patchpoints = allpoints[patchbool]
113
+ p1 = planepoints[0]
114
+ p2 = planepoints[1]
115
+ p3 = planepoints[2]
116
+ v1 = p2 - p1
117
+ v2 = p3 - p1
118
+ normal = np.cross(v1, v2)
119
+
120
+ distances = np.array([signed_distance(point, p1, normal) for point in patchpoints])
121
+ contains_nan = np.isnan(distances).any()
122
+ l_idx = distances<0
123
+ allidcs = np.arange(len(allpoints))[patchbool]
124
+ left_idx = allidcs[l_idx]
125
+ right_idx = allidcs[~l_idx]
126
+
127
+ #left_points = patchpoints[l_idx]
128
+ #right_points = patchpoints[~l_idx]
129
+
130
+
131
+ return left_idx, right_idx#left_points, right_points
132
+
133
+ def dbscan(D, eps, MinPts):
134
+ labels = [0]*len(D)
135
+ C = 0
136
+ for P in range(0, len(D)):
137
+ if not (labels[P] == 0):
138
+ continue
139
+ NeighborPts = region_query(D, P, eps)
140
+ if len(NeighborPts) < MinPts:
141
+ labels[P] = -1
142
+ else:
143
+ C += 1
144
+ grow_cluster(D, labels, P, NeighborPts, C, eps, MinPts)
145
+ return labels
146
+
147
+ def grow_cluster(D, labels, P, NeighborPts, C, eps, MinPts):
148
+ labels[P] = C
149
+ i = 0
150
+ while i < len(NeighborPts):
151
+ Pn = NeighborPts[i]
152
+ if labels[Pn] == -1:
153
+ labels[Pn] = C
154
+ elif labels[Pn] == 0:
155
+ labels[Pn] = C
156
+ PnNeighborPts = region_query(D, Pn, eps)
157
+ if len(PnNeighborPts) >= MinPts:
158
+ NeighborPts = NeighborPts + PnNeighborPts
159
+
160
+ i += 1
161
+
162
+ def region_query(D, P, eps):
163
+ neighbors = []
164
+ for Pn in range(0, len(D)):
165
+ #if geodesic_dist(D[P], D[Pn])<eps:
166
+ if np.linalg.norm(D[P] - D[Pn]) < eps:
167
+ neighbors.append(Pn)
168
+
169
+ return neighbors
170
+
171
+ def compute_centroids(data, labels):
172
+ unique = np.unique(labels)
173
+ unique_labels = unique[unique!=-1]
174
+ centroids = []
175
+ for label in unique_labels:
176
+ mask = labels == label
177
+ points = data[mask]
178
+ centroid = jnp.mean(points, axis=0)
179
+ centroids.append(centroid)
180
+ return np.stack(centroids)
181
+
182
+
183
+ def proc_all(mesh_path, mode_path):
184
+ with open(mode_path, 'rb') as f:
185
+ fin = pickle.load(f)
186
+ name = mesh_path.split('/')[-1].split('.')[0]
187
+ mesh = trimesh.load(mesh_path)
188
+ verts = np.array(mesh.vertices)
189
+
190
+ my_labels = dbscan(fin, eps=0.1, MinPts=1)
191
+ centroid = np.array(compute_centroids(fin, my_labels))
192
+ print("total number of modes found: " + str(len(centroid)))
193
+
194
+ norm = torch.norm(torch.tensor(np.array(centroid)), dim=-1)
195
+ n = centroid / norm[:,None]
196
+ d = norm - 1
197
+ point = n * d[...,None]
198
+
199
+ #pts = create_points_3d(fin, alpha = 1, markersize = 4, label = 'final timestep')
200
+ #pts2 = create_points_3d(centroid, alpha = 1, markersize = 10, label = 'final timestep')
201
+ #plot_all_3d([pts, pts2])
202
+
203
+ pats = get_patches(torch.tensor(np.array(verts)), torch.tensor(centroid))
204
+ alldicts = []
205
+ pntdict = {}
206
+ for i in range(len(n)):
207
+ plane = createPlane(n[i], point[i])
208
+ l,r = left_right(verts, pats[i], plane.vertices)
209
+ single = {'plane': plane, 'left': l, 'right': r}
210
+ alldicts.append(single)
211
+ return alldicts
212
+
213
+ def create_scene(mesh_path, dicts, plane_idx):
214
+ mesh = trimesh.load(mesh_path)
215
+ verts = np.array(mesh.vertices)
216
+ colors = []
217
+ for i in range(len(verts)):
218
+ if i in dicts[plane_idx]['left']:
219
+ color = [8, 136, 255]
220
+ elif i in dicts[plane_idx]['right']:
221
+ color = [204, 20, 245]
222
+ else:
223
+ color = [225, 225, 225]
224
+ colors.append(color)
225
+
226
+ mesh.visual.vertex_colors = colors
227
+ newplane = dicts[plane_idx]['plane']
228
+ newplane.visual.vertex_colors = np.tile([100, 100, 100, 100], (4, 1))
229
+ scene = trimesh.Scene([mesh, newplane])
230
+ temp_file = f"/tmp/scene_{plane_idx}.obj"
231
+ scene.export(temp_file)
232
+ return temp_file
233
+
234
+ def load_mesh(mesh_file_name, mode_file_name):
235
+
236
+ dicts = proc_all(mesh_file_name, mode_file_name) # Assuming proc_all is defined and returns a list of dictionaries
237
+ allmesh = []
238
+ for i in range(min(5, len(dicts))): # Limiting to maximum 5 outputs
239
+ temp_file = create_scene(mesh_file_name, dicts, i)
240
+ allmesh.append(temp_file)
241
+ return allmesh + [None] * (5 - len(allmesh)) # Fill the rest with None if less than 5
242
+
243
+ examples = [["mesh_rescaled/" + mesh, "out_3d/" + mesh.split('.')[0] + "_50000_0.1_modes.pickle"] for mesh in os.listdir("mesh_rescaled/")]
244
+
245
+ outputs = [gr.Model3D(label=f"3D Model {i+1}") for i in range(10)]
246
+
247
+ demo = gr.Interface(
248
+ fn=load_mesh,
249
+ inputs=[gr.File(label="Upload Mesh"), gr.File(label="Upload Mode File")],
250
+ outputs=outputs,
251
+ examples=examples,
252
+ cache_examples=False # Set to False to ensure it doesn't cache examples, requires actual file uploads
253
+ )
254
+
255
+ if __name__ == "__main__":
256
+ demo.launch()
257
+
258
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ trimesh
2
+ open3d
3
+ torch_geometric
4
+ pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html