# SignLanguage / evaluate.py
# Source page metadata: last commit 6f0f040 "Update evaluate.py" by ZiyuG (verified); file size 14 kB.
import cv2, math
import json, os, torch
import numpy as np
from sklearn.preprocessing import Normalizer
from align import align_filter
def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
    """Merge time intervals that carry the same error label and are close together.

    Intervals are sorted by start time. Two neighbouring intervals are merged when
    their error labels compare equal and the gap between the end of the running
    interval and the start of the next one is at most ``max_break`` seconds.
    Merging is done on the unrounded times; every emitted interval is then
    rounded with ``round()`` so all results have the same (integer) form.

    Args:
        time_intervals: list of ``(start, end)`` times in seconds.
        errors: list of error labels, parallel to ``time_intervals``.
        max_break: largest gap (seconds) that is still bridged by a merge.

    Returns:
        list of ``((start, end), error)`` with ``start``/``end`` rounded;
        ``[]`` when ``time_intervals`` is empty.

    Side effect: prints both inputs (debug output).
    """
    print(f"时间区间: {time_intervals}")
    print(f"错误: {errors}")
    if not time_intervals:
        return []
    ordered = sorted(zip(time_intervals, errors), key=lambda item: item[0][0])
    merged = []
    (cur_start, cur_end), cur_error = ordered[0]
    for (start, end), error in ordered[1:]:
        if error == cur_error and start - cur_end <= max_break:
            # Extend the running interval; keep the raw end so later gap tests are exact.
            cur_end = max(cur_end, end)
        else:
            merged.append(((round(cur_start), round(cur_end)), cur_error))
            (cur_start, cur_end), cur_error = (start, end), error
    # Emit the final interval, rounded the same way as all the others.
    merged.append(((round(cur_start), round(cur_end)), cur_error))
    return merged
def findcos_single(k1, k2):
    """Cosine-similarity score between two keypoint sets on a 0-100 scale.

    Both inputs are flattened into vectors; their cosine similarity (in [-1, 1])
    is mapped linearly to [0, 100]: 100 = same direction, 50 = orthogonal,
    0 = opposite.

    Args:
        k1, k2: array-likes with the same number of elements (in this file, two
            ``[x, y]`` points each).

    Returns:
        tuple: ``(score, 0)``. ``score`` is a NumPy float scalar (previously a
        (1, 1) array, whose ``float()`` conversion NumPy deprecates). It is NaN
        if either input is all zeros. The second element is a constant
        placeholder kept for callers that unpack two values.
    """
    u1 = np.asarray(k1, dtype=float).ravel()
    u2 = np.asarray(k2, dtype=float).ravel()
    cosine_similarity = np.dot(u1, u2) / (np.sqrt(np.dot(u1, u1)) * np.sqrt(np.dot(u2, u2)))
    return 100 * (1 - (1 - cosine_similarity) / 2), 0
def findCosineSimilarity_1(keypoints1, keypoints2):
    """Whole-pose similarity between two keypoint lists on a 0-100 scale.

    Keypoints 5-12 and 91-132 (0-based) are taken from each input, flattened
    into one vector, and compared by cosine similarity mapped from [-1, 1] to
    [0, 100] (100 = same direction). ``eval`` labels indices 91-111 / 112-132 as
    left / right hand. The meaning of 5-12 presumably follows the pose model's
    whole-body layout; confirm this.

    Args:
        keypoints1, keypoints2: sequences of keypoints, presumably ``[x, y]``
            pairs with at least 133 entries so both slices are full.

    Returns:
        tuple: ``(score, 0)``. ``score`` is a NumPy float scalar (NaN if either
        selected vector is all zeros). The second element is a constant
        placeholder that callers unpack and ignore.
    """
    def _selected(keypoints):
        # Upper-body slice plus both hands, flattened to a single vector.
        points = np.asarray(keypoints, dtype=float)
        return np.concatenate((points[5:13], points[91:133]), axis=0).ravel()

    user1 = _selected(keypoints1)
    user2 = _selected(keypoints2)
    cosine_similarity = np.dot(user1, user2) / (
        np.sqrt(np.dot(user1, user1)) * np.sqrt(np.dot(user2, user2))
    )
    return 100 * (1 - (1 - cosine_similarity) / 2), 0
def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly (the JSON standard encoding) instead
    of relying on the platform's locale default.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)
# BGR colours used for the overlay.
_BLUE = (255, 0, 0)
_RED = (0, 0, 255)
_GREEN = (0, 210, 0)
# A user segment scoring below this is drawn red and listed under "Please check".
_ERROR_THRESHOLD = 98.8
# A user segment scoring below this is also recorded as a severe, time-stamped error.
_BIG_ERROR_THRESHOLD = 97.8
# Body connections as 1-based keypoint index pairs; each pair is one drawn segment.
_BODY_CONNECTIONS = [(9, 11), (7, 9), (6, 7), (6, 8), (8, 10), (7, 13), (6, 12), (12, 13)]
# Finger chains as 1-based (first, last) indices; every consecutive pair inside is a segment.
_HAND_CHAINS = [(130, 133), (126, 129), (122, 125), (118, 121), (114, 117),
                (93, 96), (97, 100), (101, 104), (105, 108), (109, 112)]


def _point(keypoint):
    """Convert a keypoint's first two coordinates to the integer (x, y) tuple OpenCV expects."""
    return int(keypoint[0]), int(keypoint[1])


def _iter_segments(n_points):
    """Yield ``(a, b, part, scored)`` for each skeleton segment whose 0-based endpoints are < n_points.

    ``part`` is the index reported when the segment is wrong (the connection's
    start, or the chain's first index). ``scored`` is False for body connections
    starting at index 5; those are drawn but never flagged.
    """
    for first, second in _BODY_CONNECTIONS:
        a, b = first - 1, second - 1
        if a < n_points and b < n_points:
            yield a, b, a, a != 5
    for first, last in _HAND_CHAINS:
        for a in range(first - 1, last - 1):
            if a + 1 < n_points:
                yield a, a + 1, first - 1, True


def _draw_dots(frame, keypoints):
    """Mark keypoints 5-12 and 91-132 with small filled green dots."""
    for kp in list(keypoints[5:13]) + list(keypoints[91:133]):
        cv2.circle(frame, _point(kp), 1, _GREEN, -1)


def _draw_plain(frame, keypoints, line_width):
    """Draw the whole skeleton in blue plus the green keypoint dots."""
    for a, b, _, _ in _iter_segments(len(keypoints)):
        cv2.line(frame, _point(keypoints[a]), _point(keypoints[b]), _BLUE, line_width)
    _draw_dots(frame, keypoints)


def _draw_user(frame, keypoints, reference, line_width):
    """Draw the user skeleton, colouring segments red where they deviate from ``reference``.

    Returns ``(errors, big_errors)``: the ``part`` indices of segments below the
    error and severe-error thresholds respectively.
    """
    errors, big_errors = [], []
    for a, b, part, scored in _iter_segments(len(keypoints)):
        p_a, p_b = _point(keypoints[a]), _point(keypoints[b])
        score = float(findcos_single([p_a, p_b], [_point(reference[a]), _point(reference[b])])[0])
        if scored and score < _ERROR_THRESHOLD:
            errors.append(part)
            cv2.line(frame, p_a, p_b, _RED, 2)
            if score < _BIG_ERROR_THRESHOLD:
                big_errors.append(part)
        else:
            cv2.line(frame, p_a, p_b, _BLUE, line_width)
    _draw_dots(frame, keypoints)
    return errors, big_errors


def _part_names(indices):
    """Map flagged segment indices to sorted, de-duplicated body-part labels."""
    names = set()
    for i in indices:
        if i in (5, 7):
            names.add('Left Arm')
        if i in (6, 8):
            names.add('Right Arm')
        if 90 < i < 112:
            names.add('Left Hand')
        if i >= 112:
            names.add('Right Hand')
    # Sorted so labels render in a stable order and compare equal across frames.
    return sorted(names)


def _draw_check_list(canvas, parts, frame_width, frame_height):
    """Overlay 'Please check:' followed by one '- <part>' line per part (red text)."""
    x = int(frame_width * 1.75)
    y = int(frame_height * 0.2)
    cv2.putText(canvas, "Please check: ", (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1.2, _RED, 2)
    for row, part in enumerate(parts):
        cv2.putText(canvas, "- " + part, (x + 10, y + 50 + row * 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, _RED, 2)


def _keypoints_at(data, index, limit):
    """Return the first detected instance's keypoints for frame ``index``, or None if unavailable."""
    if index >= limit:
        return None
    instances = data[index].get("instances") or []
    return instances[0]["keypoints"] if instances else None


def eval(test, standard, tmpdir):
    """Compare the user's signing video against the standard video.

    ``test`` and ``standard`` are not read. All paths come from ``tmpdir``, which
    must contain ``standard.mp4``, ``standard.json`` and ``user.mp4``. The name
    shadows the builtin ``eval`` and is kept for existing callers.

    Steps:
      1. Run ``inferencer_demo.py`` on ``user.mp4`` with ``--pred-out-dir tmpdir``
         (expected to produce ``user.json``).
      2. Frame-align the recordings with ``align_filter``.
      3. Walk both videos in lockstep. Draw both skeletons, colour user segments
         that deviate from the standard, and write a 5 fps ``output.mp4``. Its
         top row shows the raw frames side by side; its bottom row shows the
         annotated frames. Both videos are assumed to share a frame size.

    Returns:
        tuple ``(mean_score, merged_intervals, comments)``:
          - mean_score: mean ``findCosineSimilarity_1`` score (0-100) over frames
            where both videos have keypoints, or 0.0 if there are none.
          - merged_intervals: ``merge_intervals_with_breaks`` output, i.e.
            ``((start_s, end_s), [part, ...])`` for frames with severe errors.
          - comments: pacing hint code. 0 = the standard video ended first
            (suggest speeding up). 1 = the user video ended first (suggest
            slowing down). 2 = both ended together.
    """
    import subprocess

    fps = 5  # output video rate; also converts frame counts to seconds below
    test_p = tmpdir + "/user.mp4"
    standard_p = tmpdir + "/standard.mp4"
    # Argument list instead of a shell string, so paths with spaces work.
    subprocess.run(['python', 'inferencer_demo.py', test_p, '--pred-out-dir', tmpdir], check=False)
    align_filter(tmpdir + '/standard', tmpdir + '/user', tmpdir)  # frame alignment
    data_00 = load_json(tmpdir + '/standard.json')
    data_01 = load_json(tmpdir + '/user.json')
    min_length = min(len(data_00), len(data_01))

    cap_00 = cv2.VideoCapture(standard_p)
    cap_01 = cv2.VideoCapture(test_p)
    frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'H264'), fps,
                          (frame_width * 2, frame_height * 2))
    line_width = max(1, frame_width // 300)

    scores = []
    error_dict = {}  # output frame index -> severe-error part labels
    comments = -1
    frame_idx = 0  # 0-based index of the frame just read from both videos
    try:
        cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0)
        cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
        while True:
            ret_00, frame_00 = cap_00.read()
            ret_01, frame_01 = cap_01.read()
            if not ret_00 and ret_01:
                comments = 0  # standard ended first: user should speed up
                break
            if ret_00 and not ret_01:
                comments = 1  # user ended first: user should slow down
                break
            if not ret_00 and not ret_01:
                comments = 2
                break

            combined_frame_ori = np.hstack((frame_00, frame_01))
            # Index keypoints with our own counter. CAP_PROP_POS_FRAMES after read()
            # is the *next* frame's index, which paired frame k with keypoints k+1.
            kps_std = _keypoints_at(data_00, frame_idx, min_length)
            kps_usr = _keypoints_at(data_01, frame_idx, min_length)

            errors, big_errors = [], []  # reset every frame so nothing goes stale
            if kps_std is not None:
                _draw_plain(frame_00, kps_std, line_width)
            if kps_usr is not None:
                if kps_std is not None:
                    errors, big_errors = _draw_user(frame_01, kps_usr, kps_std, line_width)
                else:
                    _draw_plain(frame_01, kps_usr, line_width)

            combined_frame = np.hstack((frame_00, frame_01))
            if kps_std is not None and kps_usr is not None:
                scores.append(float(findCosineSimilarity_1(kps_std, kps_usr)[0]))

            check_parts = _part_names(errors)
            if check_parts:
                _draw_check_list(combined_frame, check_parts, frame_width, frame_height)
            big_parts = _part_names(big_errors)
            if big_parts:
                error_dict[frame_idx] = big_parts

            out.write(np.vstack((combined_frame_ori, combined_frame)))
            frame_idx += 1
    finally:
        cap_00.release()
        cap_01.release()
        out.release()

    # Convert severe-error frame numbers to [t, t + 1/fps) second intervals and merge them.
    time_intervals = [(frame / fps, (frame + 1) / fps) for frame in error_dict]
    interval_errors = list(error_dict.values())
    final_merged_intervals = merge_intervals_with_breaks(time_intervals, interval_errors)
    mean_score = sum(scores) / len(scores) if scores else 0.0
    return mean_score, final_merged_intervals, comments
def install():
    """Install the OpenMMLab packages this module depends on.

    The ``mim``/``pip3`` steps run through ``os.system``, so their exit codes are
    ignored (best effort). The NumPy reinstall uses ``subprocess.run(...,
    check=True)`` and raises ``CalledProcessError`` if it fails.
    """
    import subprocess

    os.system('mim install mmengine')
    # mmcv 2.2.0 from OpenMMLab's index for CUDA 12.1 / torch 2.4. The stray
    # trailing double quote that made the shell reject this command is removed.
    os.system('pip3 install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html')
    os.system('mim install "mmdet"')
    # Force NumPy below 2.0; presumably required by the OpenMMLab stack (confirm).
    subprocess.run(["pip", "uninstall", "-y", "numpy"], check=True)
    subprocess.run(["pip", "install", "numpy<2"], check=True)
    os.system('mim install "mmpose>=1.1.0"')