File size: 4,328 Bytes
6f08eef 2633f6b 4a7e4e0 2633f6b 4a7e4e0 d7cb5e4 6f08eef 4a7e4e0 d7cb5e4 4a7e4e0 2633f6b 6f08eef d7cb5e4 4a7e4e0 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 4a7e4e0 6f08eef 2633f6b 6f08eef 2633f6b 6f08eef d7cb5e4 6f08eef 4a7e4e0 6f08eef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import numpy as np
def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
    '''Align verts_to_transform to target_verts by matching mean and spread.

    The input vertices are centered on their mean and divided by their
    standard deviation. They are then multiplied by the target's standard
    deviation and shifted to the target's mean. Statistics are taken over
    axis 0, i.e. one row per vertex.

    verts_to_transform: array-like of vertices to move (lists are accepted)
    target_verts: array-like of vertices whose mean/std are matched
    single_scale: if True, each per-axis std vector is collapsed to its
        Euclidean norm, so a single isotropic scale factor is applied;
        otherwise every axis is scaled independently
    Returns a new array with the transformed vertices; inputs are not modified.
    '''
    verts_to_transform = np.asarray(verts_to_transform)
    target_verts = np.asarray(target_verts)
    mu_target = target_verts.mean(axis=0)
    mu_in = verts_to_transform.mean(axis=0)
    std_target = np.std(target_verts, axis=0)
    std_in = np.std(verts_to_transform, axis=0)
    # A zero spread (all vertices share that coordinate) or a NaN spread would
    # break the division below, so such axes fall back to a scale of 1.
    # np.where also works when np.std returns a scalar.
    std_target = np.where((std_target == 0) | np.isnan(std_target), 1, std_target)
    std_in = np.where((std_in == 0) | np.isnan(std_in), 1, std_in)
    if single_scale:
        # Collapse the per-axis spreads into one isotropic scale factor.
        std_target = np.linalg.norm(std_target)
        std_in = np.linalg.norm(std_in)
    transformed_verts = (verts_to_transform - mu_in) / std_in
    return transformed_verts * std_target + mu_target
def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1, ce=1.0, normalized=True, preregister=True, single_scale=True):
    '''Compute the Weighted Edge Distance (WED) between two graphs.

    Predicted vertices are matched one-to-one to ground-truth vertices by a
    minimum-total-distance bipartite assignment (scipy linear_sum_assignment).
    The WED is the sum of:
      * translation cost: Euclidean distance of every matched vertex pair,
      * vertex deletion cost: cv per unmatched predicted vertex,
      * vertex insertion cost: cv per unmatched ground-truth vertex,
      * edge deletion cost: ce * length of every predicted edge (between
        matched vertices) absent from the ground truth, measured on the
        (optionally pre-registered) predicted vertices,
      * edge insertion cost: ce * length of every ground-truth edge absent
        from the prediction.
    Edges are undirected: (a, b) and (b, a) denote the same edge.

    pd_vertices: sequence of predicted vertex coordinates, one row per vertex
    pd_edges: sequence of predicted edges, each a pair of indices into pd_vertices
    gt_vertices: sequence of ground-truth vertex coordinates, one row per vertex
    gt_edges: sequence of ground-truth edges, each a pair of indices into gt_vertices
    cv: vertex insertion/deletion cost; if negative, it defaults to 1/4 of the
        ground-truth diameter (largest pairwise ground-truth vertex distance)
    ce: multiplier applied to edge lengths for edge insertion/deletion
    normalized: if True, the WED is divided by the total length of the
        ground-truth edges
    preregister: if True, the predicted vertices are first aligned to the
        ground-truth vertices with preregister_mean_std
    single_scale: forwarded to preregister_mean_std (one isotropic scale if
        True, per-axis scales otherwise)

    Returns the (optionally normalized) WED as a scalar.
    Raises ValueError if normalized is True and gt_edges is empty.
    '''
    pd_vertices = np.array(pd_vertices)
    gt_vertices = np.array(gt_vertices)
    if cv < 0:
        # Default vertex cost scales with the ground truth's size, so the
        # metric does not depend on the coordinate units.
        cv = cdist(gt_vertices, gt_vertices).max() / 4.0
    # Step 0: Prenormalize / preregister
    if preregister:
        pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)
    # Always (k, 2) integer arrays, so empty edge lists index cleanly.
    pd_edges = np.asarray(pd_edges, dtype=int).reshape(-1, 2)
    gt_edges = np.asarray(gt_edges, dtype=int).reshape(-1, 2)

    # Step 1: Bipartite matching
    distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
    row_ind, col_ind = linear_sum_assignment(distances)
    # Step 2: Vertex translation
    translation_costs = np.sum(distances[row_ind, col_ind])
    # Step 3: Vertex deletion / insertion (assignment indices are unique)
    deletion_costs = cv * (len(pd_vertices) - len(row_ind))
    insertion_costs = cv * (len(gt_vertices) - len(col_ind))

    # Step 4: Edge deletion and insertion
    pd_to_gt = dict(zip(row_ind.tolist(), col_ind.tolist()))
    gt_to_pd = {g: p for p, g in pd_to_gt.items()}
    # Predicted edges are re-expressed in ground-truth indices. Sorted tuples
    # give a canonical undirected key. Edges touching an unmatched predicted
    # vertex are dropped here: they incur no edge-deletion cost, only the
    # vertex deletion cost above.
    pd_edges_set = {
        tuple(sorted((pd_to_gt[a], pd_to_gt[b])))
        for a, b in pd_edges.tolist()
        if a in pd_to_gt and b in pd_to_gt
    }
    gt_edges_set = {tuple(sorted(edge)) for edge in gt_edges.tolist()}
    # Delete predicted edges not in the ground truth; their length is measured
    # between the predicted vertices matched to the edge's endpoints.
    edges_to_delete = pd_edges_set - gt_edges_set
    deletion_edge_costs = ce * sum(
        np.linalg.norm(pd_vertices[gt_to_pd[a]] - pd_vertices[gt_to_pd[b]])
        for a, b in edges_to_delete
    )
    # Insert ground-truth edges missing from the prediction.
    edges_to_insert = gt_edges_set - pd_edges_set
    insertion_edge_costs = ce * sum(
        np.linalg.norm(gt_vertices[a] - gt_vertices[b]) for a, b in edges_to_insert
    )

    # Step 5: Calculation of WED
    WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs
    if normalized:
        if len(gt_edges) == 0:
            raise ValueError('cannot normalize WED: gt_edges is empty')
        gt_edge_lengths = np.linalg.norm(gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]], axis=1)
        WED = WED / gt_edge_lengths.sum()
    return WED