Spaces:
Runtime error
Runtime error
File size: 7,785 Bytes
753fd9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import torch
class LossGConMesh(torch.nn.Module):
    """Per-vertex ground-contact classification loss.

    Each mesh vertex is classified as no-contact (0) or contact (1); the loss
    is the mean cross entropy over all vertices of the samples that carry
    ground-contact labels.
    """

    def __init__(self, n_verts=3889):
        """Create the loss.

        Args:
            n_verts: number of mesh vertices. Kept as an attribute for
                callers; ``forward`` infers the vertex count from ``pred_gc``.
        """
        super().__init__()
        self.n_verts = n_verts
        self.criterion_class = torch.nn.CrossEntropyLoss(reduction='mean')

    def forward(self, pred_gc, target_gc, has_gc, loss_type_gcmesh='ce'):
        """Compute the ground-contact loss for a batch.

        Args:
            pred_gc: (bs, n_verts, n_classes) logits; n_classes is 2
                (no-contact / contact) in this setup.
            target_gc: (bs, n_verts, 3) targets. Channel 0 is the class
                (no-contact=0, contact=1), channel 1 the index of the closest
                vertex with the opposite label, channel 2 the distance to that
                vertex. Only channel 0 is used here.
            has_gc: (bs,) per-sample flags; samples with ``has_gc == 1`` (or
                True) contribute to the loss.
            loss_type_gcmesh: loss variant; only ``'ce'`` (cross entropy) is
                supported.

        Returns:
            0-dim tensor: mean cross entropy over all vertices of the labelled
            samples, or a zero still attached to ``pred_gc``'s autograd graph
            when no sample is labelled.

        Raises:
            ValueError: if ``loss_type_gcmesh`` is not ``'ce'``.
        """
        if loss_type_gcmesh != 'ce':
            raise ValueError(
                f"unsupported loss_type_gcmesh {loss_type_gcmesh!r}; expected 'ce'")
        labelled = has_gc == 1
        if not bool(labelled.any()):
            # CrossEntropyLoss(reduction='mean') over zero elements is NaN,
            # which would poison the total loss; return a graph-connected zero.
            return pred_gc.sum() * 0.0
        n_classes = pred_gc.shape[-1]
        logits = pred_gc[labelled].reshape(-1, n_classes)
        # CrossEntropyLoss needs int64 class indices; target_gc may be a
        # float tensor because it also stores distances.
        target_class = target_gc[labelled][..., 0].reshape(-1).long()
        return self.criterion_class(logits, target_class)
def _plane_errors_single(verts, labels, with_under_plane):
    """Fit a plane to one sample's contact vertices and score it.

    Args:
        verts: (n_verts, 3) vertex positions of one sample.
        labels: (n_verts,) per-vertex classes; 1 = contact, 0 = no contact.
        with_under_plane: also compute the under-plane penalty; when False,
            that value is returned as zero.

    Returns:
        ``(plane_error, error_under_plane)`` as 0-dim tensors, or ``None``
        when there are fewer than 3 contact vertices (no plane can be fit).
    """
    contact = verts[labels == 1, :].transpose(0, 1)  # (3, n_contact)
    if contact.shape[1] < 3:
        return None
    centroid = contact.mean(dim=1)
    # torch.svd is deprecated but still supported; it is kept for
    # compatibility with older torch versions.
    U_svd, sigma_svd, _ = torch.svd(contact - centroid[:, None], compute_uv=True)
    # Smallest singular value of the centred points = square root of the sum
    # of their squared orthogonal distances to the best-fit plane.
    plane_error = sigma_svd[2]
    if not with_under_plane:
        return plane_error, verts.new_zeros(())
    # Left singular vector of the smallest singular value: the unit plane
    # normal. Its sign is arbitrary.
    normal = U_svd[:, 2]
    others = verts[labels == 0, :].transpose(0, 1) - centroid[:, None]
    signed = torch.matmul(normal[None, :], others)  # (1, n_other) signed distances
    # Treat the side the non-contact vertices lie on (by summed signed
    # distance) as "above" and penalise the points on the opposite side.
    # This makes the result independent of the normal's sign. The /100 is a
    # fixed scale factor.
    if signed.sum() > 0:
        error_under_plane = -signed[signed < 0].sum() / 100
    else:
        error_under_plane = signed[signed > 0].sum() / 100
    return plane_error, error_under_plane


def calculate_plane_errors_batch(vertices, target_gc_class, target_has_gc, has_gc_is_touching, return_error_under_plane=True):
    """Per-sample planarity error of the ground-contact vertices.

    For every processed sample, a least-squares plane is fit to the vertices
    labelled as contact, via SVD of the centred points. The theory is in
    https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf and a visualisation of
    the plane is in debug_code/curve_fitting_v2.py.

    Plane form: a(x - x0) + b(y - y0) + c(z - z0) = 0, i.e. ax + by + cz = d
    with d = a*x0 + b*y0 + c*z0. Here (a, b, c) is the unit normal and
    (x0, y0, z0) is the centroid of the contact points.

    Args:
        vertices: (bs, n_verts, 3) vertex positions (n_verts is 3889 in the
            original setup).
        target_gc_class: (bs, n_verts) per-vertex classes (1 = contact,
            0 = no contact).
        target_has_gc: (bs,) per-sample flag.
        has_gc_is_touching: (bs,) per-sample flag. A sample is processed only
            when both flags equal 1; otherwise its errors are 0.
        return_error_under_plane: if True, also return the under-plane
            penalty.

    Returns:
        ``errors`` of shape (bs,): the smallest singular value of the centred
        contact points, i.e. the root of their summed squared distances to the
        fitted plane. It is 0 for skipped samples, samples with fewer than 3
        contact vertices, and samples where the fit failed.

        If ``return_error_under_plane`` is True, returns
        ``(errors, errors_under_plane)`` instead. ``errors_under_plane``
        (bs,) sums the distances of non-contact vertices that lie on the
        minority side of the plane, divided by 100.
    """
    bs = vertices.shape[0]
    if bs == 0:
        # torch.stack on an empty list raises; return empty results instead.
        empty = vertices.new_zeros((0,))
        return (empty, empty.clone()) if return_error_under_plane else empty
    error_list = []
    error_under_plane_list = []
    for ind_b in range(bs):
        result = None
        if target_has_gc[ind_b] == 1 and has_gc_is_touching[ind_b] == 1:
            try:
                result = _plane_errors_single(
                    vertices[ind_b], target_gc_class[ind_b, :], return_error_under_plane)
            except Exception:  # best effort: e.g. SVD failing to converge
                result = None
            if result is None:
                print('was not able to calculate plane error for this image')
        if result is None:
            result = (vertices.new_zeros(()), vertices.new_zeros(()))
        # Append both values together so the two lists stay aligned per sample.
        error_list.append(result[0])
        error_under_plane_list.append(result[1])
    errors = torch.stack(error_list, dim=0)
    if not return_error_under_plane:
        return errors
    return errors, torch.stack(error_under_plane_list, dim=0)
# def calculate_vertex_wise_labeling_error():
# vertexwise_ground_contact
# NOTE(review): the triple-quoted string below is disabled code kept for
# reference. As a bare string literal it is never executed.
# It would not run as-is:
#   - both functions stop at a `pdb.set_trace()` breakpoint;
#   - they call `fit_plane_batch`, which is not defined in the visible part
#     of this file;
#   - `paws_to_groundplane_error_batch_new` uses `points_npx3` before
#     assigning it.
# The paw vertex index lists assume a mesh with 3889 vertices (see the
# asserts). The ` |` after the closing quotes looks like copy/paste residue
# and is a syntax error; TODO confirm and remove.
'''
def paws_to_groundplane_error_batch(vertices, return_details=False):
# list of feet vertices (some of them)
# remark: we did annotate left indices and find the right insices using sym_ids_dict
# REMARK: this loss is not yet for batches!
import pdb; pdb.set_trace()
print('this loss is not yet for batches!')
list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
assert vertices.shape[0] == 3889
assert vertices.shape[1] == 3
all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
verts_paws = vertices[all_paw_vert_idxs, :]
plane_centroid, plane_normal, error = fit_plane_batch(verts_paws)
if return_details:
return plane_centroid, plane_normal, error
else:
return error
def paws_to_groundplane_error_batch_new(vertices, return_details=False):
# list of feet vertices (some of them)
# remark: we did annotate left indices and find the right insices using sym_ids_dict
# REMARK: this loss is not yet for batches!
import pdb; pdb.set_trace()
print('this loss is not yet for batches!')
list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
assert vertices.shape[0] == 3889
assert vertices.shape[1] == 3
all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
verts_paws = vertices[all_paw_vert_idxs, :]
plane_centroid, plane_normal, error = fit_plane_batch(verts_paws)
print('this loss is not yet for batches!')
points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
points_centroid = torch.mean(points, dim=1)
input_svd = points - points_centroid[:, None]
U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
plane_normal = U_svd[:, 2]
plane_squaredsumofdists = sigma_svd[2]
error = plane_squaredsumofdists
print('error: ' + str(error.item()))
return error
''' |