Spaces:
Running
on
A10G
Running
on
A10G
File size: 14,186 Bytes
593b9ab 1b44373 593b9ab bf45168 593b9ab 186cb99 e64d541 593b9ab 186cb99 593b9ab e64d541 186cb99 593b9ab 186cb99 593b9ab e64d541 593b9ab e64d541 593b9ab 6560a99 5c6d6ba 67e7862 5c6d6ba 67e7862 5c6d6ba 67e7862 5c6d6ba 3d45521 542b688 42e737b 542b688 3d45521 7a00f61 752ae1d 7a00f61 3d45521 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
import cv2, math
import json, os, torch
import numpy as np
from sklearn.preprocessing import Normalizer
from align import align_filter
import sys
# Make the vendored mmpose checkout importable (it is cloned into ./mmpose
# by install() below); "./" also lets sibling scripts resolve.
sys.path.append("./")
sys.path.append("./mmpose")
sys.path.append("./mmpose/mmpose")
def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
    """Merge consecutive time intervals that carry the same error label.

    Intervals whose gap is at most ``max_break`` seconds and whose error
    payload compares equal are fused into one interval.

    Args:
        time_intervals: list of (start, end) pairs in seconds.
        errors: list parallel to ``time_intervals`` with the error payload
            (here: list of body-part names) for each interval.
        max_break: largest allowed silence between two intervals that may
            still be merged (seconds).

    Returns:
        List of ``((start, end), error)`` tuples with integer-rounded
        endpoints, ordered by start time. Empty input yields ``[]``.
    """
    print(f"时间区间: {time_intervals}")
    print(f"错误: {errors}")
    if not time_intervals:
        return []
    # Sort intervals by start time (input is usually sorted already, but be safe).
    sorted_intervals = sorted(zip(time_intervals, errors), key=lambda x: x[0][0])
    merged_intervals = []
    current_interval, current_error = sorted_intervals[0]
    for (start, end), error in sorted_intervals[1:]:
        # Same error payload and gap within tolerance -> extend current interval.
        if error == current_error and start - current_interval[1] <= max_break:
            current_interval = (round(current_interval[0]), round(max(current_interval[1], end)))
        else:
            # Close out the finished interval and start a new one.
            merged_intervals.append(((round(current_interval[0]), round(current_interval[1])), current_error))
            current_interval, current_error = (round(start), round(end)), error
    # BUG FIX: round the trailing interval as well — previously an interval
    # that was never merged (e.g. a single-element input) was emitted with
    # its raw float endpoints, inconsistent with every other emitted tuple.
    merged_intervals.append(((round(current_interval[0]), round(current_interval[1])), current_error))
    return merged_intervals
def findcos_single(k1, k2):
    """Score the cosine similarity of two small keypoint sets on a 0-100 scale.

    Args:
        k1, k2: nested sequences of coordinates; flattened to column vectors.

    Returns:
        Tuple ``(score, 0)`` where ``score`` maps cosine range [-1, 1] onto
        [0, 100] (a 1x1 ndarray). The trailing 0 is a placeholder index kept
        for interface parity with findCosineSimilarity_1.
    """
    vec_a = np.array(k1).reshape(-1, 1)
    vec_b = np.array(k2).reshape(-1, 1)
    dot = np.matmul(vec_a.T, vec_b)
    norm_sq_a = np.sum(vec_a * vec_a)
    norm_sq_b = np.sum(vec_b * vec_b)
    cosine_similarity = dot / (np.sqrt(norm_sq_a) * np.sqrt(norm_sq_b))
    # Linear remap: cos=1 -> 100, cos=-1 -> 0.
    return 100 * (1 - (1 - cosine_similarity) / 2), 0
def findCosineSimilarity_1(keypoints1, keypoints2):
    """Score whole-pose similarity between two keypoint lists on a 0-100 scale.

    Only the arm keypoints (indices 5:13) and hand keypoints (indices 91:133)
    participate; both subsets are flattened into column vectors before the
    cosine computation.

    Args:
        keypoints1, keypoints2: per-frame keypoint lists, indexable with the
            slices above (each entry a coordinate pair).

    Returns:
        Tuple ``(score, 0)`` — score is a 1x1 ndarray mapping cosine range
        [-1, 1] onto [0, 100]; the 0 is a placeholder index.
    """
    subset1 = np.concatenate((keypoints1[5:13], keypoints1[91:133]), axis=0).reshape(-1, 1)
    subset2 = np.concatenate((keypoints2[5:13], keypoints2[91:133]), axis=0).reshape(-1, 1)
    dot = np.matmul(subset1.T, subset2)
    norm_sq_1 = np.sum(subset1 * subset1)
    norm_sq_2 = np.sum(subset2 * subset2)
    cosine_similarity = dot / (np.sqrt(norm_sq_1) * np.sqrt(norm_sq_2))
    # Linear remap: cos=1 -> 100, cos=-1 -> 0.
    return 100 * (1 - (1 - cosine_similarity) / 2), 0
def load_json(path):
    """Load and return the parsed JSON content of *path*.

    FIX: the encoding is pinned to UTF-8 so reading the pose JSON no longer
    depends on the platform's default locale encoding (JSON files are
    UTF-8 per the spec, and this project handles non-ASCII text).
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
def eval(test, standard, tmpdir):
    """Score a user's gesture video against a standard reference, frame by frame.

    Reads ``tmpdir + "/user.mp4"`` and ``tmpdir + "/standard.mp4"``, runs the
    pose estimator on the user video, frame-aligns both videos, compares the
    per-frame keypoints, and writes an annotated side-by-side video to
    ``tmpdir + "/output.mp4"``.

    Args:
        test, standard: not read directly — the actual inputs come from tmpdir.
        tmpdir: working directory holding inputs and receiving all outputs.

    Returns:
        A 3-tuple:
        1. mean of the per-frame similarity scores,
        2. merged severe-error time intervals (see merge_intervals_with_breaks),
        3. ``comments`` pacing flag (0 = user too slow, 1 = user too fast,
           2 = both videos ended together, -1 = loop never set it).

    NOTE(review): the name ``eval`` shadows the builtin; kept for callers.
    NOTE(review): if no frame is scored, ``scores`` stays empty and the final
    division raises ZeroDivisionError.
    """
    test_p = tmpdir + "/user.mp4"
    standard_p = tmpdir + "/standard.mp4"
    # Run the pose estimator on the user video; produces tmpdir/user.json.
    os.system('python inferencer_demo.py ' + test_p + ' --pred-out-dir ' + tmpdir)  # produce user.json
    scores = []
    # Frame alignment; produces the aligned videos in tmpdir.
    align_filter(tmpdir + '/standard', tmpdir + '/user', tmpdir)
    data_00 = load_json(tmpdir + '/standard.json')
    data_01 = load_json(tmpdir + '/user.json')
    cap_00 = cv2.VideoCapture(standard_p)
    cap_01 = cv2.VideoCapture(test_p)
    # Keypoint connections for both videos: 1-based index pairs, shifted to
    # 0-based below. connections1 = arm/torso limbs, connections2 = hand chains.
    connections1 = [(9,11), (7,9), (6,7), (6,8), (8,10), (7,13), (6,12), (12,13)]
    connections2 = [(130,133), (126,129), (122,125), (118,121), (114,117), (93,96), (97,100), (101,104), (105,108), (109,112)]
    # Only frames covered by BOTH JSON files are compared.
    min_length = min(len(data_00), len(data_01))
    frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Output canvas is 2x wide (side-by-side) and 2x tall (raw row + annotated row).
    out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'H264'), 5, (frame_width*2, frame_height*2))
    cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind both videos to the first frame
    cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
    comments = -1
    error_dict = {}
    cnt = 0
    # Scale the drawn line width with the frame size (minimum 1 px).
    line_width = 1 if frame_width // 300 == 0 else frame_width // 300
    # Process the two videos frame by frame.
    while True:
        ret_00, frame_00 = cap_00.read()  # read the current standard / user frames
        ret_01, frame_01 = cap_01.read()
        if not ret_00 and ret_01:
            # Standard ended first -> advise speeding up, keeping motions clear.
            comments = 0
            break  # Stop if either video runs out of frames
        elif ret_00 and not ret_01:
            # User ended first -> advise slowing down, keeping motions clear.
            comments = 1
            break  # Stop if either video runs out of frames
        elif not ret_00 and not ret_01:
            comments = 2
            break
        combined_frame_ori = np.hstack((frame_00, frame_01))
        # Current frame numbers of both videos.
        frame_id_00 = int(cap_00.get(cv2.CAP_PROP_POS_FRAMES))
        frame_id_01 = int(cap_01.get(cv2.CAP_PROP_POS_FRAMES))
        # Draw the standard video's keypoint connections.
        if frame_id_00 < min_length:
            keypoints_00 = data_00[frame_id_00]["instances"][0]["keypoints"]
            for (start, end) in connections1:
                start = start - 1
                end = end - 1
                if start < len(keypoints_00) and end < len(keypoints_00):
                    start_point = (int(keypoints_00[start][0]), int(keypoints_00[start][1]))
                    end_point = (int(keypoints_00[end][0]), int(keypoints_00[end][1]))
                    cv2.line(frame_00, start_point, end_point, (255, 0, 0), line_width)  # (BGR) Blue line
            for (start, end) in connections2:
                start = start - 1
                end = end - 1
                # Hand chains are drawn segment by segment between consecutive points.
                for i in range(start, end):
                    if i < len(keypoints_00) and i + 1 < len(keypoints_00):
                        start_point = (int(keypoints_00[i][0]), int(keypoints_00[i][1]))
                        end_point = (int(keypoints_00[i + 1][0]), int(keypoints_00[i + 1][1]))
                        cv2.line(frame_00, start_point, end_point, (255, 0, 0), line_width)  # Blue line
            # Keep the scored subset (arms 5:13 + hands 91:133) and dot each point.
            keypoints_00_ori = keypoints_00
            keypoints_00 = keypoints_00[5:13] + keypoints_00[91:133]
            for point in keypoints_00:
                cv2.circle(frame_00, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)
        # Draw the user's keypoints and collect per-connection deviations.
        if frame_id_01 < min_length:
            error = []
            bigerror = []
            # NOTE(review): keypoints_00_ori is only assigned in the branch
            # above; this block assumes frame_id_00 < min_length held too.
            keypoints_01 = data_01[frame_id_01]["instances"][0]["keypoints"]
            for (start, end) in connections1:
                start = start - 1
                end = end - 1
                if start < len(keypoints_01) and end < len(keypoints_01):
                    start_point = (int(keypoints_01[start][0]), int(keypoints_01[start][1]))
                    end_point = (int(keypoints_01[start][0]), int(keypoints_01[start][1])) if False else (int(keypoints_01[end][0]), int(keypoints_01[end][1]))
                    cur_score = findcos_single([[int(keypoints_01[start][0]), int(keypoints_01[start][1])], [int(keypoints_01[end][0]), int(keypoints_01[end][1])]], [[int(keypoints_00_ori[start][0]), int(keypoints_00_ori[start][1])], [int(keypoints_00_ori[end][0]), int(keypoints_00_ori[end][1])]])
                    # Similarity below 98.8 counts as an error (index 5 excluded).
                    if float(cur_score[0]) < 98.8 and start != 5:
                        error.append(start)
                        cv2.line(frame_01, start_point, end_point, (0, 0, 255), 2)  # Red line
                        # Below 97.8 it is additionally recorded as a severe error.
                        if float(cur_score[0]) < 97.8:
                            bigerror.append(start)
                    else:
                        cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width)  # Blue line
            for (start, end) in connections2:
                start = start - 1
                end = end - 1
                for i in range(start, end):
                    if i < len(keypoints_01) and i + 1 < len(keypoints_01):
                        start_point = (int(keypoints_01[i][0]), int(keypoints_01[i][1]))
                        end_point = (int(keypoints_01[i + 1][0]), int(keypoints_01[i + 1][1]))
                        cur_score = findcos_single([[int(keypoints_01[i][0]), int(keypoints_01[i][1])], [int(keypoints_01[i + 1][0]), int(keypoints_01[i + 1][1])]], [[int(keypoints_00_ori[i][0]), int(keypoints_00_ori[i][1])], [int(keypoints_00_ori[i + 1][0]), int(keypoints_00_ori[i + 1][1])]])
                        if float(cur_score[0]) < 98.8:
                            error.append(start)
                            cv2.line(frame_01, start_point, end_point, (0, 0, 255), 2)  # Red line
                            if float(cur_score[0]) < 97.8:
                                bigerror.append(start)
                        else:
                            cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width)  # Blue line
            # Dot the user's scored keypoint subset.
            keypoints_01 = keypoints_01[5:13] + keypoints_01[91:133]
            for point in keypoints_01:
                cv2.circle(frame_01, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)
        # Concatenate the images horizontally to display side by side
        combined_frame = np.hstack((frame_00, frame_01))
        if frame_id_00 < min_length and frame_id_01 < min_length:
            min_cos, min_idx = findCosineSimilarity_1(data_00[frame_id_00]["instances"][0]["keypoints"], data_01[frame_id_01]["instances"][0]["keypoints"])
            # If any connection deviated, list the offending body parts on the frame.
            if error != []:
                content = []
                for i in error:
                    if i in [5,7]: content.append('Left Arm')
                    if i in [6,8]: content.append('Right Arm')
                    if i > 90 and i < 112: content.append('Left Hand')
                    if i >= 112: content.append('Right Hand')
                part = ""
                # Overlay the detected problem areas on the frame.
                cv2.putText(combined_frame, "Please check: ", (int(frame_width*1.75), int(frame_height*0.2)), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
                start_x = int(frame_width*1.75) + 10  # starting x coordinate of the list
                start_y = int(frame_height*0.2) + 50
                line_height = 50  # vertical spacing per text line
                # Render each (deduplicated) body part on its own line.
                for i, item in enumerate(list(set(content))):
                    text = "- " + item
                    y_position = start_y + i * line_height
                    cv2.putText(combined_frame, text, (start_x, y_position), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
            # Severe errors are additionally tracked per frame for interval feedback.
            if bigerror != []:
                bigcontent = []
                for i in bigerror:
                    if i in [5,7]: bigcontent.append('Left Arm')
                    if i in [6,8]: bigcontent.append('Right Arm')
                    if i > 90 and i < 112: bigcontent.append('Left Hand')
                    if i >= 112: bigcontent.append('Right Hand')
                # Record this frame's severe-error parts, keyed by output frame index.
                error_dict[cnt] = list(set(bigcontent))
            cnt += 1
            combined_frame = np.vstack((combined_frame_ori, combined_frame))
            out.write(combined_frame)
            scores.append(float(min_cos))  # record the per-frame similarity score
    fps = 5  # Frames per second
    frame_numbers = list(error_dict.keys())  # frames that contained severe errors
    time_intervals = [(frame / fps, (frame + 1) / fps) for frame in frame_numbers]  # frame index -> (start, end) in seconds
    errors = [error_dict[frame] for frame in frame_numbers]  # severe-error parts per frame
    final_merged_intervals = merge_intervals_with_breaks(time_intervals, errors)  # merge adjacent/near intervals with identical parts
    out.release()
    # Returns:
    # 1. mean of scores — the overall gesture-similarity rating
    # 2. final_merged_intervals — merged error time intervals + their parts
    # 3. comments — pacing advice flag (speed up / slow down)
    return sum(scores) / len(scores), final_merged_intervals, comments
def install():
    """One-time environment setup: pin numpy below 2.x and install mmpose from source.

    FIXES over the original:
    - pip is invoked as ``sys.executable -m pip`` so packages land in the
      interpreter actually running this code, not whatever ``pip`` happens
      to be first on PATH.
    - shell-string ``os.system`` calls are replaced with ``subprocess.run``
      argument lists (no shell quoting pitfalls).
    - dead commented-out mmcv/mim variants removed.

    Side effects: uninstalls/installs packages, clones mmpose into ./mmpose,
    and temporarily changes the working directory.
    """
    import subprocess
    import sys
    pip = [sys.executable, "-m", "pip"]
    # mmpose currently breaks against numpy 2.x, so force a 1.x install.
    subprocess.run(pip + ["uninstall", "-y", "numpy"], check=True)
    subprocess.run(pip + ["install", "numpy<2"], check=True)
    # Clone mmpose; no check=True here so a pre-existing ./mmpose checkout
    # (clone fails, exit code != 0) does not abort the install, matching the
    # original os.system behaviour.
    subprocess.run(["git", "clone", "https://github.com/open-mmlab/mmpose.git"])
    os.chdir('mmpose')
    try:
        # Editable install from source so mmpose's demo scripts are importable.
        subprocess.run(pip + ["install", "-r", "requirements.txt"])
        subprocess.run(pip + ["install", "-v", "-e", "."])
    finally:
        # Always restore the working directory, even if an install step raises.
        os.chdir('../')
|