File size: 4,950 Bytes
6f08eef
 
 
 
a842b44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91b350c
a842b44
 
 
 
 
91b350c
a842b44
 
 
 
 
 
 
 
91b350c
a842b44
6f08eef
 
a842b44
 
 
 
91b350c
 
 
 
 
 
 
 
 
a842b44
6f08eef
a842b44
6f08eef
 
a842b44
6f08eef
a842b44
6f08eef
 
a842b44
6f08eef
 
 
a842b44
6f08eef
 
 
a842b44
6f08eef
 
a842b44
 
 
 
6f08eef
 
 
a842b44
 
 
 
6f08eef
 
 
 
 
 
 
a842b44
6f08eef
 
 
 
a842b44
 
6f08eef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import numpy as np 


def preregister_mean_std(verts_to_transform, target_verts, single_scale=True):
    '''Shift and scale one point set so its mean and spread match another's.

    Statistics are taken over axis 0 (one point per row). Each set is centered
    on its per-axis mean and divided by its per-axis standard deviation. The
    result is then rescaled by the target's standard deviation and shifted to
    the target's mean.

    verts_to_transform: NumPy array of points to be transformed
    target_verts: NumPy array of points whose mean/spread are matched
    single_scale: if True, each per-axis std vector is collapsed to its
        Euclidean norm, so a single isotropic scale factor is used; otherwise
        each axis is scaled independently
    Returns a new array; the inputs are not modified.
    '''
    def _safe_scale(per_axis_std):
        # Axes with zero or NaN spread are treated as unit scale, so dividing
        # by them never produces inf/NaN.
        degenerate = (per_axis_std == 0) | np.isnan(per_axis_std)
        if np.any(degenerate):
            per_axis_std[degenerate] = 1
        return per_axis_std

    src_center = verts_to_transform.mean(axis=0)
    dst_center = target_verts.mean(axis=0)
    src_scale = _safe_scale(np.std(verts_to_transform, axis=0))
    dst_scale = _safe_scale(np.std(target_verts, axis=0))

    if single_scale:
        src_scale, dst_scale = np.linalg.norm(src_scale), np.linalg.norm(dst_scale)

    unit_verts = (verts_to_transform - src_center) / src_scale
    return unit_verts * dst_scale + dst_center


def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=-1/4, ce=1.0, normalized=True, preregister=True, single_scale=True):
    '''Compute the Wireframe Edge Distance (WED) between a predicted and a ground-truth wireframe.

    Predicted vertices are matched one-to-one to ground-truth vertices by
    minimum total Euclidean distance (linear sum assignment). The WED is the
    sum of:
      - the distances between matched vertex pairs (vertex translation),
      - cv for every unmatched predicted vertex (deletion) and every unmatched
        ground-truth vertex (insertion),
      - ce times the length of every predicted edge absent from the ground
        truth (edge deletion, measured on the predicted vertices),
      - ce times the length of every ground-truth edge absent from the
        prediction (edge insertion, measured on the ground-truth vertices).
    Edges are undirected. Predicted edges touching an unmatched predicted
    vertex are dropped without cost.

    pd_vertices: predicted vertex coordinates, one vertex per row
    pd_edges: predicted edges as pairs of indices into pd_vertices
    gt_vertices: ground-truth vertex coordinates, one vertex per row
    gt_edges: ground-truth edges as pairs of indices into gt_vertices
    cv: vertex cost. If positive, the cost (in vertex-coordinate units) of
        adding or deleting a vertex. If negative, -cv times the diameter
        (largest pairwise vertex distance) of the ground-truth wireframe; the
        default -1/4 gives a quarter of the diameter. If zero, the mean
        distance of the ground-truth vertices from their centroid.
    ce: edge cost, a multiplier of the edge length for edge deletion and
        insertion (default 1.0)
    normalized: if True, the WED is divided by the total length of the
        ground-truth edges
    preregister: if True, the predicted vertices first have their mean and
        scale matched to the ground-truth vertices (preregister_mean_std)
    single_scale: forwarded to preregister_mean_std; if True one isotropic
        scale is used instead of per-axis scales
    Returns the WED as a float.
    Raises ValueError if normalized is True and the ground-truth edges have
    zero total length (e.g. no ground-truth edges).
    '''

    # Vertex coordinates are assumed to be in centimeters. With cv=100.0 and
    # ce=1.0, missing a vertex costs the same as predicting it 1 meter away
    # from the ground-truth vertex. This is equivalent to cv=1 and ce=1 when
    # the coordinates are in meters.

    def _edge_key(a, b):
        # Canonical key for an undirected edge. Sorting is used rather than
        # tuple(set(edge)), whose element order depends on set hash
        # collisions and insertion order: tuple(set((0, 8))) != tuple(set((8, 0))).
        a, b = int(a), int(b)
        return (a, b) if a <= b else (b, a)

    pd_vertices = np.array(pd_vertices)
    gt_vertices = np.array(gt_vertices)
    # Edges are index pairs; reshape so that empty edge lists become (0, 2) int arrays.
    pd_edges = np.asarray(pd_edges, dtype=int).reshape(-1, 2)
    gt_edges = np.asarray(gt_edges, dtype=int).reshape(-1, 2)

    # Step 0: optionally match the predicted vertices' mean and scale to the ground truth.
    if preregister:
        pd_vertices = preregister_mean_std(pd_vertices, gt_vertices, single_scale=single_scale)

    if cv < 0:
        # -cv times the diameter (largest pairwise distance) of the ground-truth vertices.
        cv = -cv * cdist(gt_vertices, gt_vertices).max()
    elif cv == 0:
        # Mean distance of the ground-truth vertices from their centroid.
        cv = np.linalg.norm(gt_vertices - gt_vertices.mean(axis=0), axis=1).mean()

    # Step 1: one-to-one vertex matching minimizing the total Euclidean distance.
    distances = cdist(pd_vertices, gt_vertices, metric='euclidean')
    row_ind, col_ind = linear_sum_assignment(distances)
    pd_to_gt = {int(p): int(g) for p, g in zip(row_ind, col_ind)}
    gt_to_pd = {g: p for p, g in pd_to_gt.items()}

    # Step 2: vertex translation cost of the matched pairs.
    translation_costs = distances[row_ind, col_ind].sum()

    # Step 3: vertex deletion (unmatched predictions) and insertion (unmatched ground truth).
    deletion_costs = cv * (len(pd_vertices) - len(row_ind))
    insertion_costs = cv * (len(gt_vertices) - len(col_ind))

    # Step 4: edge deletion and insertion, comparing edges in ground-truth index space.
    # Predicted edges touching an unmatched predicted vertex are dropped here without cost.
    pd_edges_set = {
        _edge_key(pd_to_gt[a], pd_to_gt[b])
        for a, b in pd_edges.tolist()
        if a in pd_to_gt and b in pd_to_gt
    }
    gt_edges_set = {_edge_key(a, b) for a, b in gt_edges.tolist()}

    # Spurious predicted edges are measured on the (possibly preregistered)
    # predicted vertices, so map their ground-truth endpoints back to the
    # matched predicted indices.
    edges_to_delete = pd_edges_set - gt_edges_set
    deletion_edge_costs = ce * sum(
        np.linalg.norm(pd_vertices[gt_to_pd[a]] - pd_vertices[gt_to_pd[b]])
        for a, b in edges_to_delete
    )

    # Missing ground-truth edges are measured on the ground-truth vertices.
    edges_to_insert = gt_edges_set - pd_edges_set
    insertion_edge_costs = ce * sum(
        np.linalg.norm(gt_vertices[a] - gt_vertices[b]) for a, b in edges_to_insert
    )

    # Step 5: total WED.
    WED = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs

    if normalized:
        total_length_of_gt_edges = np.linalg.norm(
            gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]], axis=1
        ).sum()
        if total_length_of_gt_edges == 0:
            raise ValueError("cannot normalize WED: ground-truth edges have zero total length")
        WED = WED / total_length_of_gt_edges

    return WED