Spaces:
Paused
Paused
import numpy as np | |
from PIL import Image, ImageDraw, ImageFile | |
from .NetWork import VGG | |
import paddle | |
import cv2 | |
def get_color_map_list(num_classes):
    """Build a palette with one ``[r, g, b]`` colour per class index.

    The bits of each class index are consumed three at a time: bit 0 of a
    group goes to red, bit 1 to green, bit 2 to blue, and successive groups
    fill each channel from its most-significant bit (128) downward.

    Args:
        num_classes (int): number of classes to generate colours for.
    Returns:
        palette (list): ``num_classes`` lists of three ints, e.g.
            ``get_color_map_list(3) == [[0, 0, 0], [128, 0, 0], [0, 128, 0]]``.
    """
    palette = []
    for cls_idx in range(num_classes):
        rgb = [0, 0, 0]
        remaining = cls_idx
        shift = 7
        while remaining:
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.append(rgb)
    return palette
def draw_det(image, dt_bboxes, name_set):
    """Draw detection boxes on ``image`` and label each with an emotion.

    Each box region is cropped from the *original* frame and passed to
    ``emotic``. The returned label string, followed by the detector score,
    is written at the box's top-left corner on a filled background.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.
        dt_bboxes: iterable of ``(cls_id, score, xmin, ymin, xmax, ymax)``.
        name_set: sequence whose length sizes the colour palette; boxes are
            coloured by ``int(cls_id)``, which must index into that palette.
    Returns:
        numpy.ndarray: the image with boxes and labels drawn on it.
    """
    im = Image.fromarray(image)
    # Crop from an untouched copy: drawing onto ``im`` inside the loop would
    # otherwise leak earlier boxes/labels into later, overlapping crops.
    source = im.copy()
    draw_thickness = min(im.size) // 320
    draw = ImageDraw.Draw(im)
    color_list = get_color_map_list(len(name_set))
    for (cls_id, score, xmin, ymin, xmax, ymax) in dt_bboxes:
        label = emotic(source.crop((xmin, ymin, xmax, ymax)))
        color = tuple(color_list[int(cls_id)])
        # draw bbox as a closed polyline
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
             (xmin, ymin)],
            width=draw_thickness,
            fill=color)
        # draw label: the text uses the same anchor as its measured bbox so
        # the glyphs land exactly inside the filled rectangle
        text = "{} {:.4f}".format(label, score)
        box = draw.textbbox((xmin, ymin), text, anchor='lt')
        draw.rectangle(box, fill=color)
        draw.text((xmin, ymin), text, fill=(255, 255, 255), anchor='lt')
    return np.array(im)
# Emotion classifiers already loaded, keyed by weight-file path, so the VGG
# weights are read from disk once instead of once per detected box.
_EMOTIC_MODELS = {}

# Class index -> label string (same order as the original if/elif chain).
_EMOTIC_LABELS = ('SAD', 'DISGUST', 'HAPPY', 'FEAR', 'SUPERISE', 'NATUREAL', 'ANGRY')


def _load_emotic_model(params_file_path):
    """Return a cached VGG emotion classifier loaded from ``params_file_path``."""
    model = _EMOTIC_MODELS.get(params_file_path)
    if model is None:
        model = VGG(num_class=len(_EMOTIC_LABELS))
        model.load_dict(paddle.load(params_file_path))
        # NOTE(review): the original code kept ``model.eval()`` commented out,
        # so any dropout / batch-norm layers in VGG run in training mode at
        # inference time. Confirm whether that is intentional before enabling.
        _EMOTIC_MODELS[params_file_path] = model
    return model


def emotic(image, params_file_path=r'configs/vgg.pdparams'):
    """Classify the emotion shown in ``image``.

    Args:
        image: PIL image or array convertible with ``np.array`` to an
            ``[H, W, C]`` array with three channels.
        params_file_path: path of the VGG weight file to load (cached after
            the first call).
    Returns:
        str: one of 'SAD', 'DISGUST', 'HAPPY', 'FEAR', 'SUPERISE',
        'NATUREAL', 'ANGRY'.
    Raises:
        ValueError: if the model predicts an index with no known label.
    """
    def load_image(img):
        # Resize the image to 224x224.
        img = cv2.resize(img, (224, 224))
        # Input is [H, W, C]; transpose it to [C, H, W].
        img = np.transpose(img, (2, 0, 1))
        img = img.astype('float32')
        # Rescale pixel values from [0, 255] to [-1.0, 1.0].
        img = img / 255.
        img = img * 2.0 - 1.0
        return img

    model = _load_emotic_model(params_file_path)
    # NOTE(review): np.array keeps the caller's channel order (RGB for a PIL
    # image); confirm it matches the order the weights were trained with.
    tensor_img = np.expand_dims(load_image(np.array(image)), 0)
    # Inference only: no autograd graph is needed.
    with paddle.no_grad():
        results = model(paddle.to_tensor(tensor_img))
    # Take the most probable class as the prediction.
    tap = int(np.argmax(results.numpy()[0]))
    if not 0 <= tap < len(_EMOTIC_LABELS):
        raise ValueError('Unexpected emotion class index: {}'.format(tap))
    return _EMOTIC_LABELS[tap]