ermu2001 commited on
Commit
d8d3000
·
1 Parent(s): d0939b5

Upload DATA/test_landmark.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. DATA/test_landmark.py +158 -0
DATA/test_landmark.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import cv2
3
+ # import time
4
+ # import glob
5
+ # import argparse
6
+ # import scipy
7
+ # import numpy as np
8
+ # from PIL import Image
9
+ # import torch
10
+ # from tqdm import tqdm
11
+ # from itertools import cycle
12
+
13
+ # from extract_kp_videos_safe import KeypointExtractor
14
+ # import numpy as np
15
+ # from PIL import Image
16
+ # with torch.no_grad():
17
+ # img_np =cv2.imread('Strawberry Monster.png')
18
+ # predictor = KeypointExtractor('cuda')
19
+ # dets = predictor.det_net.detect_faces(img_np, 0.97)
20
+ # if len(dets) == 0:
21
+ # detect = False
22
+ # else:
23
+ # print("success")
24
+ # import os
25
+ # import cv2
26
+ # import torch
27
+ # from tqdm import tqdm
28
+ # from extract_kp_videos_safe import KeypointExtractor
29
+
30
+ # # 创建 KeypointExtractor 实例
31
+
32
+ # # 设置文件夹路径
33
+ # folder_path = 'control_inversion'
34
+ # landmark_detect_false=0
35
+ # landmark_detect_success=0
36
+ # # 遍历文件夹中的图像文件
37
+ # for filename in tqdm(os.listdir(path)):
38
+ # if filename.endswith('.png') or filename.endswith('.jpg'):
39
+ # # 读取图像
40
+ # image_path = os.path.join(folder_path, filename)
41
+ # img_np = cv2.imread(image_path)
42
+
43
+ # # 进行人脸检测和关键点提取
44
+ # with torch.no_grad():
45
+ # predictor = KeypointExtractor('cuda')
46
+ # dets = predictor.det_net.detect_faces(img_np, 0.97)
47
+ # if len(dets) == 0:
48
+ # landmark_detect_false += 1
49
+ # else:
50
+ # landmark_detect_success += 1
51
+
52
+ # detect_rate = landmark_detect_success/(landmark_detect_success+landmark_detect_false)
53
+ # print(detect_rate)
54
+ # import os
55
+ # import cv2
56
+ # import torch
57
+ # from tqdm import tqdm
58
+ # from extract_kp_videos_safe import KeypointExtractor
59
+
60
+ # # 设置文件夹路径
61
+ # folder_path = 'prompts'
62
+
63
+ # # 初始化成功和失败的计数
64
+ # total_landmark_detect_success = 0
65
+ # total_landmark_detect_false = 0
66
+
67
+ # # 遍历文件夹中的 txt 文件
68
+ # for txt_filename in os.listdir(folder_path):
69
+ # if txt_filename.endswith('.txt'):
70
+ # txt_file_path = os.path.join(folder_path, txt_filename)
71
+
72
+ # # 读取 txt 文件中的图片列表
73
+ # with open(txt_file_path, 'r') as file:
74
+ # image_list = file.read().splitlines()
75
+
76
+ # landmark_detect_success = 0
77
+ # landmark_detect_false = 0
78
+
79
+ # # 遍历 txt 文件中的图片列表
80
+ # for image_filename in tqdm(image_list, desc=f'Processing {txt_filename}'):
81
+ # image_path = os.path.join('control_inversion', image_filename+'.png')
82
+ # if image_path.endswith('.png') or image_path.endswith('.jpg'):
83
+ # img_np = cv2.imread(image_path)
84
+
85
+ # # 进行人脸检测和关键点提取
86
+ # with torch.no_grad():
87
+ # predictor = KeypointExtractor('cuda')
88
+ # dets = predictor.det_net.detect_faces(img_np, 0.97)
89
+
90
+ # if len(dets) == 0:
91
+ # landmark_detect_false += 1
92
+ # else:
93
+ # landmark_detect_success += 1
94
+
95
+ # # 计算检测率
96
+ # detect_rate = landmark_detect_success / (landmark_detect_success + landmark_detect_false)
97
+ # print(f'{txt_filename}: Detect Rate = {detect_rate}')
98
+
99
+ # # 更新总的计数
100
+ # total_landmark_detect_success += landmark_detect_success
101
+ # total_landmark_detect_false += landmark_detect_false
102
+
103
+ # # 计算总的检测率
104
+ # total_detect_rate = total_landmark_detect_success / (total_landmark_detect_success + total_landmark_detect_false)
105
+ # print(f'Total Detect Rate = {total_detect_rate}')
106
+ import os
107
+ import sys
108
+ import cv2
109
+ import torch
110
+ from tqdm import tqdm
111
+ from chat_anything.sad_talker.face3d.extract_kp_videos_safe import KeypointExtractor
112
+
113
+ # 设置文件夹路径
114
+ folder_path = sys.argv[1]
115
+
116
+ # 初始化成功和失败的计数
117
+ total_landmark_detect_success = 0
118
+ total_landmark_detect_false = 0
119
+
120
+ # 遍历文件夹中的 txt 文件
121
+ for txt_filename in os.listdir(folder_path):
122
+ if txt_filename.endswith('.txt'):
123
+ txt_file_path = os.path.join(folder_path, txt_filename)
124
+
125
+ # # 读取 txt 文件中的图片列表
126
+ # with open(txt_file_path, 'r') as file:
127
+ # image_list = file.read().splitlines()
128
+ image_list = os.listdir(txt_file_path)
129
+ landmark_detect_success = 0
130
+ landmark_detect_false = 0
131
+
132
+ # 遍历 txt 文件中的图片列表
133
+ for image_filename in tqdm(image_list, desc=f'Processing {txt_filename}'):
134
+ image_path = os.path.join(txt_file_path, image_filename)
135
+ if image_path.endswith('.png') or image_path.endswith('.jpg'):
136
+ img_np = cv2.imread(image_path)
137
+
138
+ # 进行人脸检测和关键点提取
139
+ with torch.no_grad():
140
+ predictor = KeypointExtractor('cuda')
141
+ dets = predictor.det_net.detect_faces(img_np, 0.97)
142
+
143
+ if len(dets) == 0:
144
+ landmark_detect_false += 1
145
+ else:
146
+ landmark_detect_success += 1
147
+
148
+ # 计算检测率
149
+ detect_rate = landmark_detect_success / (landmark_detect_success + landmark_detect_false)
150
+ print(f'{txt_filename}: Detect Rate = {detect_rate}')
151
+
152
+ # 更新总的计数
153
+ total_landmark_detect_success += landmark_detect_success
154
+ total_landmark_detect_false += landmark_detect_false
155
+
156
+ # 计算总的检测率
157
+ total_detect_rate = total_landmark_detect_success / (total_landmark_detect_success + total_landmark_detect_false)
158
+ print(f'Total Detect Rate = {total_detect_rate}')