ZiyuG commited on
Commit
cd92698
1 Parent(s): f08d9ef

Update evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +42 -28
evaluate.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  from sklearn.preprocessing import Normalizer
5
  from align import align_filter
6
 
7
-
8
  def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
9
  print(f"时间区间: {time_intervals}")
10
  print(f"错误: {errors}")
@@ -45,6 +44,33 @@ def findcos_single(k1, k2):
45
  cosine_similarity = a / (np.sqrt(b) * np.sqrt(c))
46
  return 100 * (1 - (1 - cosine_similarity) / 2), 0
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def findCosineSimilarity_1(keypoints1, keypoints2):
50
  # transformer = Normalizer().fit(keypoints1)
@@ -91,7 +117,7 @@ def eval(test, standard, tmpdir):
91
  frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
92
  frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))
93
 
94
- out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'H264'), 5, (frame_width*2, frame_height*2))
95
 
96
  cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0) # 初始化视频从头开始读取
97
  cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
@@ -113,12 +139,15 @@ def eval(test, standard, tmpdir):
113
  elif not ret_00 and not ret_01:
114
  comments = 2
115
  break
116
- combined_frame_ori = np.hstack((frame_00, frame_01))
117
-
118
  # 获取视频当前的帧号
119
  frame_id_00 = int(cap_00.get(cv2.CAP_PROP_POS_FRAMES))
120
  frame_id_01 = int(cap_01.get(cv2.CAP_PROP_POS_FRAMES))
121
 
 
 
 
 
122
  # 处理标准视频中的关键点,并绘制关键点连接
123
  if frame_id_00 < min_length:
124
  keypoints_00 = data_00[frame_id_00]["instances"][0]["keypoints"]
@@ -151,7 +180,7 @@ def eval(test, standard, tmpdir):
151
  if frame_id_01 < min_length:
152
  error = []
153
  bigerror = []
154
- keypoints_01 = data_01[frame_id_01]["instances"][0]["keypoints"]
155
 
156
  for (start, end) in connections1:
157
  start = start - 1
@@ -196,7 +225,8 @@ def eval(test, standard, tmpdir):
196
  cv2.circle(frame_01, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)
197
 
198
  # Concatenate the images horizontally to display side by side
199
- combined_frame = np.hstack((frame_00, frame_01))
 
200
 
201
  if frame_id_00 < min_length and frame_id_01 < min_length:
202
  min_cos, min_idx = findCosineSimilarity_1(data_00[frame_id_00]["instances"][0]["keypoints"], data_01[frame_id_01]["instances"][0]["keypoints"])
@@ -213,7 +243,7 @@ def eval(test, standard, tmpdir):
213
  part = ""
214
 
215
  # 在视频帧上显示检测到的误差部位
216
- cv2.putText(combined_frame, "Please check: ", (int(frame_width*1.75), int(frame_height*0.2)), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
217
  start_x = int(frame_width*1.75) + 10 #435 # 起始的 x 坐标
218
  start_y = int(frame_height*0.2) + 50 # 45
219
  line_height = 50 # 每一行文字的高度
@@ -222,7 +252,7 @@ def eval(test, standard, tmpdir):
222
  for i, item in enumerate(list(set(content))):
223
  text = "- " + item
224
  y_position = start_y + i * line_height
225
- cv2.putText(combined_frame, text, (start_x, y_position), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
226
 
227
  # big
228
  if bigerror != []:
@@ -255,29 +285,12 @@ def eval(test, standard, tmpdir):
255
  return sum(scores) / len(scores), final_merged_intervals, comments
256
 
257
  def install():
258
- # if torch.cuda.is_available():
259
- # cu_version = torch.version.cuda
260
- # cu_version = f"cu{cu_version.replace('.', '')}" # Format it as 'cuXX' (e.g., 'cu113')
261
- # else:
262
- # cu_version = "cpu" # Fallback to CPU if no CUDA is available
263
-
264
- # torch_version = torch.__version__.split('+')[0] # Get PyTorch version without build info
265
-
266
- # pip_command = f'pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html'
267
-
268
-
269
- # os.system(pip_command)
270
  import subprocess
271
  subprocess.run(["pip", "uninstall", "-y", "numpy"], check=True)
272
- subprocess.run(["pip", "install", "numpy<2"], check=True)
273
 
274
  os.system('mim install mmengine')
275
- # os.system('mim install "mmcv"')
276
- # os.system('mim install "mmdet"')
277
- # os.system('mim install "mmpose"')
278
- # os.system('pip3 install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html"')
279
- # os.system('pip3 install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html')
280
-
281
  os.system('git clone https://github.com/open-mmlab/mmpose.git')
282
  os.chdir('mmpose')
283
  os.system('pip install -r requirements.txt')
@@ -289,4 +302,5 @@ def install():
289
  os.chdir('mmdetection')
290
  os.system('pip install -v -e .')
291
  os.chdir('../')
292
- # os.system('mim install "mmpose>=1.1.0"')
 
 
4
  from sklearn.preprocessing import Normalizer
5
  from align import align_filter
6
 
 
7
  def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
8
  print(f"时间区间: {time_intervals}")
9
  print(f"错误: {errors}")
 
44
  cosine_similarity = a / (np.sqrt(b) * np.sqrt(c))
45
  return 100 * (1 - (1 - cosine_similarity) / 2), 0
46
 
47
+ def align_hstack(frame_00, frame_01, keypoints_01=None):
48
+ height_00 = frame_00.shape[0]
49
+ height_01 = frame_01.shape[0]
50
+
51
+ if height_01 != height_00:
52
+ # 计算缩放比例,确保高度与 frame_00 一致
53
+ scale_factor = height_00 / height_01
54
+ new_width = int(frame_01.shape[1] * scale_factor)
55
+
56
+ # 使用 OpenCV 的 resize 函数按比例缩放 frame_01
57
+ frame_01_resized = cv2.resize(frame_01, (new_width, height_00))
58
+ else:
59
+ frame_01_resized = frame_01
60
+
61
+ # 现在可以水平拼接两个数组
62
+ combined_frame_ori = np.hstack((frame_00, frame_01_resized))
63
+ if keypoints_01 == None: return combined_frame_ori, None
64
+
65
+ scale_factor = frame_00.shape[0] / frame_01.shape[0] # 根据高度的缩放比例
66
+ # 对 frame_01 的关键点进行缩放
67
+ keypoints_01_scaled = []
68
+ for point in keypoints_01:
69
+ scaled_point = [point[0] * scale_factor, point[1] * scale_factor] # 仅对 x 和 y 坐标进行缩放
70
+ keypoints_01_scaled.append(scaled_point)
71
+
72
+ return combined_frame_ori, keypoints_01_scaled
73
+
74
 
75
  def findCosineSimilarity_1(keypoints1, keypoints2):
76
  # transformer = Normalizer().fit(keypoints1)
 
117
  frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
118
  frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))
119
 
120
+ out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'XVID'), 5, (frame_width*2, frame_height*2))
121
 
122
  cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0) # 初始化视频从头开始读取
123
  cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
 
139
  elif not ret_00 and not ret_01:
140
  comments = 2
141
  break
142
+ # combined_frame_ori = np.hstack((frame_00, frame_01))
 
143
  # 获取视频当前的帧号
144
  frame_id_00 = int(cap_00.get(cv2.CAP_PROP_POS_FRAMES))
145
  frame_id_01 = int(cap_01.get(cv2.CAP_PROP_POS_FRAMES))
146
 
147
+ if frame_id_01 < min_length:
148
+ combined_frame_ori, keypoints_01_scaled = align_hstack(frame_00, frame_01, data_01[frame_id_01]["instances"][0]["keypoints"])
149
+ else:
150
+ combined_frame_ori, _ = align_hstack(frame_00, frame_01)
151
  # 处理标准视频中的关键点,并绘制关键点连接
152
  if frame_id_00 < min_length:
153
  keypoints_00 = data_00[frame_id_00]["instances"][0]["keypoints"]
 
180
  if frame_id_01 < min_length:
181
  error = []
182
  bigerror = []
183
+ keypoints_01 = keypoints_01_scaled #data_01[frame_id_01]["instances"][0]["keypoints"]
184
 
185
  for (start, end) in connections1:
186
  start = start - 1
 
225
  cv2.circle(frame_01, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)
226
 
227
  # Concatenate the images horizontally to display side by side
228
+ # combined_frame = np.hstack((frame_00, frame_01))
229
+ combined_frame, _ = align_hstack(frame_00, frame_01)
230
 
231
  if frame_id_00 < min_length and frame_id_01 < min_length:
232
  min_cos, min_idx = findCosineSimilarity_1(data_00[frame_id_00]["instances"][0]["keypoints"], data_01[frame_id_01]["instances"][0]["keypoints"])
 
243
  part = ""
244
 
245
  # 在视频帧上显示检测到的误差部位
246
+ # cv2.putText(combined_frame, "Please check: ", (int(frame_width*1.75), int(frame_height*0.2)), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
247
  start_x = int(frame_width*1.75) + 10 #435 # 起始的 x 坐标
248
  start_y = int(frame_height*0.2) + 50 # 45
249
  line_height = 50 # 每一行文字的高度
 
252
  for i, item in enumerate(list(set(content))):
253
  text = "- " + item
254
  y_position = start_y + i * line_height
255
+ # cv2.putText(combined_frame, text, (start_x, y_position), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
256
 
257
  # big
258
  if bigerror != []:
 
285
  return sum(scores) / len(scores), final_merged_intervals, comments
286
 
287
  def install():
 
 
 
 
 
 
 
 
 
 
 
 
288
  import subprocess
289
  subprocess.run(["pip", "uninstall", "-y", "numpy"], check=True)
290
+ subprocess.run(["pip", "install", "numpy<2"]x, check=True)
291
 
292
  os.system('mim install mmengine')
293
+ os.system('mim install mmcv==2.2.0')
 
 
 
 
 
294
  os.system('git clone https://github.com/open-mmlab/mmpose.git')
295
  os.chdir('mmpose')
296
  os.system('pip install -r requirements.txt')
 
302
  os.chdir('mmdetection')
303
  os.system('pip install -v -e .')
304
  os.chdir('../')
305
+
306
+ os.system('apt-get install ffmpeg imagemagick')