ZiyuG committed
Commit e64d541
1 Parent(s): b5b50d2

Update evaluate.py

Files changed (1):
  1. evaluate.py +5 -26

evaluate.py CHANGED
@@ -70,14 +70,13 @@ def eval(test, standard, tmpdir):
 def eval(test, standard, tmpdir):
     test_p = tmpdir + "/user.mp4"
     standard_p = tmpdir + "/standard.mp4"
-    os.system('python inferencer_demo.py ' + test_p + ' --pred-out-dir', tmpdir) # produce user.json
+    os.system('python inferencer_demo.py ' + test_p + ' --pred-out-dir ' + tmpdir) # produce user.json

     scores = []

     align_filter(tmpdir + '/standard', tmpdir + '/user', tmpdir) # frame alignment; produces the aligned videos

-    # return None
-    data_00 = load_json(tmpdir + '/standard.json') # aligned json
+    data_00 = load_json(tmpdir + '/standard.json')
     data_01 = load_json(tmpdir + '/user.json')
     cap_00 = cv2.VideoCapture(standard_p)
     cap_01 = cv2.VideoCapture(test_p)
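Note on the os.system fix above: the old call passed tmpdir as a second positional argument, which raises a TypeError because os.system takes a single command string, so the output directory never reached the inferencer. The new line concatenates it into the command. As a sketch only (not part of this commit), the same invocation could be written with subprocess.run, which avoids manual string concatenation and quoting; the script name and flag come from the diff, the helper name is illustrative:

    import subprocess

    def run_pose_inference(test_p: str, tmpdir: str) -> None:
        # Sketch: argument-list form of the command built above; writes user.json into tmpdir.
        subprocess.run(
            ["python", "inferencer_demo.py", test_p, "--pred-out-dir", tmpdir],
            check=True,  # fail loudly if the inferencer exits with a non-zero status
        )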
@@ -160,7 +159,6 @@ def eval(test, standard, tmpdir):
         start_point = (int(keypoints_01[start][0]), int(keypoints_01[start][1]))
         end_point = (int(keypoints_01[end][0]), int(keypoints_01[end][1]))
         cur_score = findcos_single([[int(keypoints_01[start][0]), int(keypoints_01[start][1])], [int(keypoints_01[end][0]), int(keypoints_01[end][1])]], [[int(keypoints_00_ori[start][0]), int(keypoints_00_ori[start][1])], [int(keypoints_00_ori[end][0]), int(keypoints_00_ori[end][1])]])
-        # print(cur_score[0])

         # If the current similarity is below 98.8, treat it as an error and record it
         if float(cur_score[0]) < 98.8 and start != 5:
@@ -171,7 +169,6 @@ def eval(test, standard, tmpdir):
                 bigerror.append(start)
         else:
             cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width) # Blue line
-            # cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width) # Blue line

     for (start, end) in connections2:
         start = start - 1
@@ -180,9 +177,9 @@ def eval(test, standard, tmpdir):
         if i < len(keypoints_01) and i + 1 < len(keypoints_01):
             start_point = (int(keypoints_01[i][0]), int(keypoints_01[i][1]))
             end_point = (int(keypoints_01[i + 1][0]), int(keypoints_01[i + 1][1]))
-            # cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width) # Blue line
+
             cur_score = findcos_single([[int(keypoints_01[i][0]), int(keypoints_01[i][1])], [int(keypoints_01[i + 1][0]), int(keypoints_01[i + 1][1])]], [[int(keypoints_00_ori[i][0]), int(keypoints_00_ori[i][1])], [int(keypoints_00_ori[i + 1][0]), int(keypoints_00_ori[i + 1][1])]])
-            # print(cur_score[0])
+
             if float(cur_score[0]) < 98.8:
                 error.append(start)
                 cv2.line(frame_01, start_point, end_point, (0, 0, 255), 2) # Red line
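Both loops above score one user bone segment against the corresponding segment from the standard video and flag an error when the score drops below 98.8; a further branch, partly outside these hunks, records the worst cases in bigerror. findcos_single is defined elsewhere in this file; assuming it returns a cosine similarity scaled to 0-100 between the two segment direction vectors (an assumption, the real helper may differ), the comparison is roughly:

    import numpy as np

    def segment_similarity(seg_user, seg_standard):
        # Sketch only: cosine similarity of two 2-point segments, scaled to 0-100.
        v1 = np.array(seg_user[1], dtype=float) - np.array(seg_user[0], dtype=float)
        v2 = np.array(seg_standard[1], dtype=float) - np.array(seg_standard[0], dtype=float)
        cos = float(np.dot(v1, v2)) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8)
        return 100.0 * cos  # close to 100 when the user's segment points the same way as the standard's

    # segment_similarity([[0, 0], [1, 0]], [[0, 0], [1, 0.05]]) is about 99.9, i.e. above the 98.8 threshold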
@@ -202,23 +199,12 @@ def eval(test, standard, tmpdir):

     if frame_id_00 < min_length and frame_id_01 < min_length:
         min_cos, min_idx = findCosineSimilarity_1(data_00[frame_id_00]["instances"][0]["keypoints"], data_01[frame_id_01]["instances"][0]["keypoints"])
-        # print(min_cos)
-        # if min_cos < 99:
-        #     cv2.putText(combined_frame, "Incorrect Gesture", (120, 220),
-        #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
-        # else:
-        #     cv2.putText(combined_frame, "Correct Gesture", (120, 220),
-        #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

         # If errors were found, add the body part corresponding to each error to the content list
         if error != []:
             # print(error)
             content = []
             for i in error:
-                # if i in [5,7]: content.append('Right Arm')
-                # if i in [6,8]: content.append('Left Arm')
-                # if i > 90 and i < 112: content.append('Right Hand')
-                # if i >= 112: content.append('Left Hand')
                 if i in [5,7]: content.append('Left Arm')
                 if i in [6,8]: content.append('Right Arm')
                 if i > 90 and i < 112: content.append('Left Hand')
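The keypoint-index to body-part mapping above replaces the previously commented-out, left/right-swapped version; the same mapping is repeated for bigerror further down. A sketch (not in this commit) of factoring it into one helper; the i >= 112 branch for 'Right Hand' is cut off in this hunk and is assumed by symmetry with the removed comments:

    def body_parts(indices):
        # Sketch only: map flagged keypoint indices to human-readable body parts.
        parts = []
        for i in indices:
            if i in (5, 7):
                parts.append('Left Arm')
            elif i in (6, 8):
                parts.append('Right Arm')
            elif 90 < i < 112:
                parts.append('Left Hand')
            elif i >= 112:
                parts.append('Right Hand')  # assumed: this branch is not visible in the hunk
        return parts

    # content = body_parts(error); bigcontent = body_parts(bigerror)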
@@ -226,7 +212,6 @@ def eval(test, standard, tmpdir):
             part = ""

             # Overlay the detected error body parts on the video frame
-            # cv2.putText(combined_frame, "Please check: ", (430, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 255), line_width)
             cv2.putText(combined_frame, "Please check: ", (int(frame_width*1.75), int(frame_height*0.2)), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
             start_x = int(frame_width*1.75) + 10 # 435; starting x coordinate
             start_y = int(frame_height*0.2) + 50 # 45
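The "Please check: " label is now anchored relative to the combined frame's size instead of the removed hard-coded (430, 30) position, so the overlay scales with the input resolution; the 1.75 factor suggests the label sits in the right-hand half of a side-by-side canvas (an assumption, the stacking code is outside this hunk). A minimal sketch of the placement:

    import cv2

    def put_check_label(frame, frame_width, frame_height, text="Please check: "):
        # Sketch only: cv2.putText takes the bottom-left origin of the text string.
        org = (int(frame_width * 1.75), int(frame_height * 0.2))
        cv2.putText(frame, text, org, cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)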
@@ -242,10 +227,6 @@ def eval(test, standard, tmpdir):
         if bigerror != []:
             bigcontent = []
             for i in bigerror:
-                # if i in [5,7]: bigcontent.append('Right Arm')
-                # if i in [6,8]: bigcontent.append('Left Arm')
-                # if i > 90 and i < 112: bigcontent.append('Right Hand')
-                # if i >= 112: bigcontent.append('Left Hand')
                 if i in [5,7]: bigcontent.append('Left Arm')
                 if i in [6,8]: bigcontent.append('Right Arm')
                 if i > 90 and i < 112: bigcontent.append('Left Hand')
@@ -257,10 +238,8 @@ def eval(test, standard, tmpdir):
         cnt += 1
         combined_frame = np.vstack((combined_frame_ori, combined_frame))
         out.write(combined_frame)
-        # print(f"min_cos: {float(min_cos)}")
         scores.append(float(min_cos)) # record the similarity score for every frame

-    # print(f"scores: {scores}")
     fps = 5 # Frames per second
     frame_numbers = list(error_dict.keys()) # List of frame numbers that contain serious errors
     time_intervals = [(frame / fps, (frame + 1) / fps) for frame in frame_numbers] # Convert frame numbers to time intervals (seconds)
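time_intervals turns each flagged frame into a one-frame window at 5 fps; final_merged_intervals, returned at the end of the function, is presumably those windows merged (the merging code sits outside the hunks shown in this diff). A hedged sketch of such a merge step:

    def merge_intervals(intervals):
        # Sketch only: merge overlapping or touching (start, end) windows, in seconds.
        merged = []
        for start, end in sorted(intervals):
            if merged and start <= merged[-1][1]:
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        return merged

    # merge_intervals([(1.0, 1.2), (1.2, 1.4), (3.0, 3.2)]) -> [(1.0, 1.4), (3.0, 3.2)]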
@@ -272,4 +251,4 @@ def eval(test, standard, tmpdir):
     # 1. the mean of scores, used as the overall gesture-similarity rating
     # 2. final_merged_intervals, the merged error time intervals with their error details
     # 3. comments, speed advice for the user (speed the gesture up or slow it down)
-    return sum(scores) / len(scores), final_merged_intervals, comments
+    return sum(scores) / len(scores), final_merged_intervals, comments
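The overall rating returned here is the plain mean of the per-frame scores, so sum(scores) / len(scores) raises ZeroDivisionError if no frames were compared (for example, empty or unreadable input). A guard along these lines is a sketch, not part of this commit:

    # Sketch only: defensive aggregation when no frames were scored.
    overall = sum(scores) / len(scores) if scores else 0.0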
 