File size: 14,203 Bytes
593b9ab
1b44373
593b9ab
 
 
2ebedbd
593b9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186cb99
 
 
e64d541
593b9ab
 
 
186cb99
593b9ab
e64d541
186cb99
593b9ab
 
 
 
 
 
 
 
 
 
 
 
186cb99
593b9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e64d541
593b9ab
e64d541
593b9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6560a99
 
 
5c6d6ba
 
 
 
 
67e7862
5c6d6ba
67e7862
5c6d6ba
67e7862
 
5c6d6ba
3d45521
 
 
 
49377fa
42e737b
542b688
 
 
f14c1b3
3d45521
7a00f61
 
 
752ae1d
7a00f61
3d45521
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import cv2, math
import json, os, torch
import numpy as np
from sklearn.preprocessing import Normalizer
from align import align_filter


def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
    """Merge time intervals that carry the same error and lie close together.

    Args:
        time_intervals: list of ``(start, end)`` times in seconds.
        errors: error value for each interval (paired with ``time_intervals``
            via ``zip``); only intervals whose errors compare equal are merged.
        max_break: largest gap, in seconds, between the end of the current
            merged interval and the start of the next one that still allows
            them to merge.

    Returns:
        List of ``((start, end), error)`` tuples ordered by start time, with
        ``start``/``end`` rounded to whole seconds using ``round``. Rounding
        is applied only to the output, so every gap check uses the exact
        times and every returned interval is rounded the same way.
    """
    print(f"时间区间: {time_intervals}")
    print(f"错误: {errors}")

    if not time_intervals:
        return []

    # Sort by start time only; the error values (e.g. lists) need not be orderable.
    ordered = sorted(zip(time_intervals, errors), key=lambda item: item[0][0])

    merged_intervals = []
    (cur_start, cur_end), cur_error = ordered[0]
    for (start, end), error in ordered[1:]:
        if error == cur_error and start - cur_end <= max_break:
            # Same error and a small enough gap: extend the current interval.
            cur_end = max(cur_end, end)
        else:
            merged_intervals.append(((round(cur_start), round(cur_end)), cur_error))
            (cur_start, cur_end), cur_error = (start, end), error

    merged_intervals.append(((round(cur_start), round(cur_end)), cur_error))
    return merged_intervals


def findcos_single(k1, k2):
    """Score how similar two keypoint sets are on a 0-100 scale.

    Both inputs are flattened into column vectors and compared with cosine
    similarity ``cos``; the score is ``100 * (1 - (1 - cos) / 2)``, i.e.
    100 for identical direction, 50 for orthogonal, 0 for opposite.

    Returns:
        ``(score, 0)`` where ``score`` is a 1x1 NumPy array and the second
        element is a constant placeholder. An all-zero input makes the
        denominator 0, yielding NaN.
    """
    vec_a = np.array(k1).reshape(-1, 1)
    vec_b = np.array(k2).reshape(-1, 1)

    dot = vec_a.T @ vec_b  # shape (1, 1)
    norm_a = np.sqrt((vec_a * vec_a).sum())
    norm_b = np.sqrt((vec_b * vec_b).sum())
    cosine_similarity = dot / (norm_a * norm_b)

    score = 100 * (1 - (1 - cosine_similarity) / 2)
    return score, 0


def findCosineSimilarity_1(keypoints1, keypoints2):
    """Score two whole-body poses on a 0-100 scale using a subset of keypoints.

    Only keypoints 5-12 and 91-132 (0-based slices ``[5:13]`` and
    ``[91:133]``) of each pose are used. They are flattened into column
    vectors and compared with cosine similarity ``cos``; the score is
    ``100 * (1 - (1 - cos) / 2)``.

    Returns:
        ``(score, 0)`` where ``score`` is a 1x1 NumPy array and the second
        element is a constant placeholder.
    """
    def _select(keypoints):
        # Keep only the two keypoint ranges used for comparison, as one column vector.
        return np.concatenate((keypoints[5:13], keypoints[91:133]), axis=0).reshape(-1, 1)

    src = _select(keypoints1)
    tst = _select(keypoints2)

    numerator = src.T @ tst  # shape (1, 1)
    denominator = np.sqrt((src * src).sum()) * np.sqrt((tst * tst).sum())
    cosine_similarity = numerator / denominator

    return 100 * (1 - (1 - cosine_similarity) / 2), 0

def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON files are
    required to use) instead of the platform's locale default, which differs
    on some systems (e.g. Windows).
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)

# Similarity thresholds (0-100 scale from findcos_single) below which a user
# skeleton segment is flagged as an error / as a serious error.
_ERROR_THRESHOLD = 98.8
_BIG_ERROR_THRESHOLD = 97.8

# BGR colours used for drawing.
_BLUE = (255, 0, 0)
_RED = (0, 0, 255)
_GREEN = (0, 210, 0)


def _to_float(value):
    """Convert a scalar or single-element array (e.g. a 1x1 score) to a Python float."""
    return float(np.asarray(value).reshape(-1)[0])


def _pt(keypoint):
    """Truncate a keypoint's first two coordinates to an integer (x, y) pixel tuple."""
    return int(keypoint[0]), int(keypoint[1])


def _frame_keypoints(data, frame_idx, limit):
    """Return the first instance's keypoints of ``data[frame_idx]``, or None.

    None is returned when ``frame_idx`` is outside ``[0, limit)`` or the frame
    has no instance, so callers never fall back to another frame's keypoints.
    """
    if not 0 <= frame_idx < limit:
        return None
    instances = data[frame_idx]["instances"]
    if not instances:
        return None
    return instances[0]["keypoints"]


def _skeleton_segments(num_points, connections1, connections2):
    """Yield ``(a, b, tag, scored)`` for every skeleton segment to draw.

    ``connections1`` holds 1-based ``(start, end)`` pairs joined directly;
    ``connections2`` holds 1-based ``(first, last)`` ranges joined as a chain
    of consecutive points. ``a``/``b`` are 0-based endpoint indices, ``tag``
    is the index recorded when the segment is flagged (the pair's start, or
    the chain's first index) and ``scored`` is False for segments that are
    drawn but never flagged. Segments with an endpoint beyond ``num_points``
    are skipped.
    """
    for start, end in connections1:
        a, b = start - 1, end - 1
        if a < num_points and b < num_points:
            # Segments starting at 0-based index 5 are never flagged.
            yield a, b, a, a != 5
    for start, end in connections2:
        first, last = start - 1, end - 1
        for i in range(first, last):
            if i + 1 < num_points:
                yield i, i + 1, first, True


def _draw_joints(frame, keypoints):
    """Draw small green dots on keypoints 5-12 and 91-132 (0-based)."""
    for point in (*keypoints[5:13], *keypoints[91:133]):
        cv2.circle(frame, _pt(point), 1, _GREEN, -1)


def _draw_reference_pose(frame, keypoints, connections1, connections2, line_width):
    """Draw the standard skeleton in blue plus its joint dots onto ``frame``."""
    for a, b, _, _ in _skeleton_segments(len(keypoints), connections1, connections2):
        cv2.line(frame, _pt(keypoints[a]), _pt(keypoints[b]), _BLUE, line_width)
    _draw_joints(frame, keypoints)


def _draw_user_pose(frame, keypoints, reference, connections1, connections2, line_width):
    """Draw the user's skeleton, colouring segments that diverge from ``reference``.

    Each scored segment's endpoints are compared with the same endpoints of
    ``reference`` via findcos_single. Scores below _ERROR_THRESHOLD are drawn
    red (thickness 2) and their tag is collected in ``error``; scores below
    _BIG_ERROR_THRESHOLD are additionally collected in ``bigerror``. All other
    segments, and every segment when ``reference`` is None, are drawn blue.

    Returns:
        ``(error, bigerror)`` lists of segment tags.
    """
    error, bigerror = [], []
    for a, b, tag, scored in _skeleton_segments(len(keypoints), connections1, connections2):
        p_a, p_b = _pt(keypoints[a]), _pt(keypoints[b])
        score = None
        if scored and reference is not None:
            score = _to_float(findcos_single([p_a, p_b], [_pt(reference[a]), _pt(reference[b])])[0])
        # A NaN score (all-zero coordinates) compares False and stays blue.
        if score is not None and score < _ERROR_THRESHOLD:
            error.append(tag)
            cv2.line(frame, p_a, p_b, _RED, 2)
            if score < _BIG_ERROR_THRESHOLD:
                bigerror.append(tag)
        else:
            cv2.line(frame, p_a, p_b, _BLUE, line_width)
    _draw_joints(frame, keypoints)
    return error, bigerror


def _error_parts(indices):
    """Map flagged segment tags to a sorted, de-duplicated list of body-part labels.

    Sorting makes the result deterministic, so frames with the same set of
    parts compare equal when merge_intervals_with_breaks groups them.
    """
    parts = set()
    for i in indices:
        if i in (5, 7):
            parts.add('Left Arm')
        if i in (6, 8):
            parts.add('Right Arm')
        if 90 < i < 112:
            parts.add('Left Hand')
        if i >= 112:
            parts.add('Right Hand')
    return sorted(parts)


def eval(test, standard, tmpdir):
    """Compare the user's gesture video with the standard video frame by frame.

    Runs ``inferencer_demo.py`` on ``tmpdir/user.mp4`` (producing
    ``tmpdir/user.json``), aligns frames with ``align_filter``, then writes
    ``tmpdir/output.mp4``. Each output frame stacks the raw side-by-side
    frames on top, and below them the same frames with skeletons drawn:
    user segments that diverge from the standard are red, and a
    "Please check" list names the affected body parts.

    NOTE: the name shadows the builtin ``eval``; kept for existing callers.
    NOTE(review): ``test`` and ``standard`` are unused; videos are always
    read from ``tmpdir``.
    NOTE(review): similarity is cosine similarity of raw pixel coordinates,
    so it also depends on where the person is in the frame, not only on pose.

    Returns:
        ``(mean_score, merged_intervals, comments)``:
        mean_score -- mean per-frame findCosineSimilarity_1 score (0-100),
            or 0.0 when no frame had keypoints in both videos;
        merged_intervals -- merge_intervals_with_breaks output for frames
            with serious errors, times in seconds at ``fps`` = 5;
        comments -- 0 if the standard video ended first (user should speed
            up), 1 if the user video ended first (user should slow down),
            2 if both ended together.
    """
    import subprocess

    test_p = tmpdir + "/user.mp4"
    standard_p = tmpdir + "/standard.mp4"
    # Produce user.json. An argument list (no shell) keeps paths with spaces intact.
    subprocess.run(['python', 'inferencer_demo.py', test_p, '--pred-out-dir', tmpdir])

    # Frame alignment: produces the aligned videos.
    align_filter(tmpdir + '/standard', tmpdir + '/user', tmpdir)

    data_00 = load_json(tmpdir + '/standard.json')
    data_01 = load_json(tmpdir + '/user.json')
    # Only frames present in both keypoint files are annotated.
    min_length = min(len(data_00), len(data_01))

    # 1-based keypoint index pairs joined directly, and 1-based index ranges
    # drawn as chains of consecutive points.
    connections1 = [(9, 11), (7, 9), (6, 7), (6, 8), (8, 10), (7, 13), (6, 12), (12, 13)]
    connections2 = [(130, 133), (126, 129), (122, 125), (118, 121), (114, 117),
                    (93, 96), (97, 100), (101, 104), (105, 108), (109, 112)]

    fps = 5  # output frame rate; also used to convert frame numbers to seconds
    cap_00 = cv2.VideoCapture(standard_p)
    cap_01 = cv2.VideoCapture(test_p)
    frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'H264'), fps,
                          (frame_width * 2, frame_height * 2))
    line_width = max(1, frame_width // 300)

    scores = []
    error_dict = {}  # output frame number -> body parts with serious errors
    comments = -1
    cnt = 0
    try:
        # Start both videos from the first frame.
        cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0)
        cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
        while True:
            ret_00, frame_00 = cap_00.read()
            ret_01, frame_01 = cap_01.read()
            if not ret_00 and ret_01:
                comments = 0  # standard ended first: user should speed up
                break
            if ret_00 and not ret_01:
                comments = 1  # user ended first: user should slow down
                break
            if not ret_00 and not ret_01:
                comments = 2
                break
            combined_frame_ori = np.hstack((frame_00, frame_01))

            # CAP_PROP_POS_FRAMES is the index of the *next* frame to be read,
            # so subtract 1 to get the keypoints of the frame just read.
            keypoints_00 = _frame_keypoints(
                data_00, int(cap_00.get(cv2.CAP_PROP_POS_FRAMES)) - 1, min_length)
            keypoints_01 = _frame_keypoints(
                data_01, int(cap_01.get(cv2.CAP_PROP_POS_FRAMES)) - 1, min_length)

            if keypoints_00 is not None:
                _draw_reference_pose(frame_00, keypoints_00, connections1, connections2, line_width)

            # Reset every frame so a frame without keypoints never reuses the
            # previous frame's errors.
            error, bigerror = [], []
            if keypoints_01 is not None:
                error, bigerror = _draw_user_pose(
                    frame_01, keypoints_01, keypoints_00, connections1, connections2, line_width)

            combined_frame = np.hstack((frame_00, frame_01))

            if keypoints_00 is not None and keypoints_01 is not None:
                frame_score, _ = findCosineSimilarity_1(keypoints_00, keypoints_01)
                scores.append(_to_float(frame_score))

            # List the body parts with errors on the user's side of the frame.
            if error:
                cv2.putText(combined_frame, "Please check: ",
                            (int(frame_width * 1.75), int(frame_height * 0.2)),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.2, _RED, 2)
                start_x = int(frame_width * 1.75) + 10
                start_y = int(frame_height * 0.2) + 50
                line_height = 50
                for row, item in enumerate(_error_parts(error)):
                    cv2.putText(combined_frame, "- " + item,
                                (start_x, start_y + row * line_height),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.2, _RED, 2)

            # Remember this frame's serious-error body parts.
            if bigerror:
                error_dict[cnt] = _error_parts(bigerror)

            cnt += 1
            out.write(np.vstack((combined_frame_ori, combined_frame)))
    finally:
        cap_00.release()
        cap_01.release()
        out.release()

    # Convert serious-error frame numbers to time intervals (seconds) and merge
    # intervals that are close together and share the same body parts.
    frame_numbers = list(error_dict)
    time_intervals = [(frame / fps, (frame + 1) / fps) for frame in frame_numbers]
    errors = [error_dict[frame] for frame in frame_numbers]
    final_merged_intervals = merge_intervals_with_breaks(time_intervals, errors)

    mean_score = sum(scores) / len(scores) if scores else 0.0
    return mean_score, final_merged_intervals, comments

def install():
    """Install the pose-estimation dependencies into the current environment.

    Steps, in order:
      1. Uninstall NumPy and reinstall a 1.x release (``numpy<2``). These
         two pip calls use ``check=True`` and raise
         ``subprocess.CalledProcessError`` on failure.
      2. ``mim install mmengine``.
      3. Clone MMPose into ``./mmpose``, install its ``requirements.txt``
         and the package itself in editable mode from inside that directory,
         then change back to the parent directory.

    Commands in steps 2-3 run through ``os.system`` and their exit codes are
    not checked.
    NOTE(review): uses whichever ``pip``/``mim`` is first on PATH, which may
    not belong to the running interpreter.
    """
    import subprocess

    for pip_args in (("uninstall", "-y", "numpy"), ("install", "numpy<2")):
        subprocess.run(["pip", *pip_args], check=True)

    os.system('mim install mmengine')

    os.system('git clone https://github.com/open-mmlab/mmpose.git')
    os.chdir('mmpose')
    for command in ('pip install -r requirements.txt', 'pip install -v -e .'):
        os.system(command)
    os.chdir('../')