Spaces:
Sleeping
Sleeping
#!/usr/bin/python | |
# -*- encoding: utf-8 -*- | |
import numpy as np | |
from model import BiSeNet | |
import torch | |
import os | |
import os.path as osp | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import cv2 | |
from pathlib import Path | |
import configargparse | |
import tqdm | |
# import ttach as tta | |
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    """Render a parsing label map as a flat-colour image and optionally save it.

    Each label is replaced by one fixed colour triple:

    * label 0            -> (255, 255, 255)
    * labels 1-13        -> (255, 0, 0)
    * labels 14-15       -> (0, 255, 0)
    * label 16           -> (0, 0, 255)
    * labels 17 and up   -> (255, 0, 0)

    The triples are handed to ``cv2.imwrite`` unchanged, which treats a
    3-channel array as BGR.

    Args:
        im: Unused. Kept so existing callers do not have to change.
        parsing_anno: 2-D array of integer class labels. It is cast to
            ``uint8`` before use.
        stride: Scale factor applied to both axes of the label map
            (nearest-neighbour) before colouring.
        save_im: When true, write the result to ``save_path``.
        save_path: Output file path; the image format follows the extension.
        img_size: Final size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.

    Returns:
        The coloured ``uint8`` image of shape ``(height, width, 3)``.
    """
    # uint8 is required both by cv2.resize and to index the 256-entry palette.
    labels = parsing_anno.copy().astype(np.uint8)
    labels = cv2.resize(
        labels, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)

    # One palette row per possible uint8 label; a single fancy-index lookup
    # replaces the per-label np.where passes over the image.
    palette = np.empty((256, 3), dtype=np.uint8)
    palette[:] = (255, 0, 0)          # labels 1-13 and every label >= 17
    palette[0] = (255, 255, 255)      # label 0 stays white
    palette[14:16] = (0, 255, 0)
    palette[16] = (0, 0, 255)
    colour_map = palette[labels]

    # Nearest-neighbour keeps the class colours pure (no blended edge pixels).
    vis_im = cv2.resize(colour_map, img_size,
                        interpolation=cv2.INTER_NEAREST)

    if save_im:
        # cv2.imwrite does not create missing directories, so make sure the
        # target folder exists first.
        out_dir = osp.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(save_path, vis_im)
    return vis_im
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Parse every ``.jpg``/``.png`` image in a directory and save colour maps.

    Each image is resized to 512x512, normalised, run through a ``BiSeNet``
    with 19 classes, and the per-pixel argmax is rendered by
    ``vis_parsing_maps`` at the image's original size.

    Output files are always written as ``.png`` into ``respth``. A file name
    whose stem parses as an integer is normalised through ``int()`` (so
    ``0007.jpg`` becomes ``7.png``); any other stem is kept as it is.

    Args:
        respth: Output directory; created (with parents) if missing.
        dspth: Directory that is listed for input images.
        cp: Path to the checkpoint passed to ``torch.load`` and then to
            ``load_state_dict``.
    """
    Path(respth).mkdir(parents=True, exist_ok=True)

    print('[INFO] loading model...')
    # Use the GPU when there is one, otherwise run on CPU instead of failing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    net.load_state_dict(torch.load(cp, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    image_paths = os.listdir(dspth)
    with torch.no_grad():
        for image_path in tqdm.tqdm(image_paths):
            if not image_path.endswith(('.jpg', '.png')):
                continue

            # Close the file handle as soon as the resized copy exists.
            with Image.open(osp.join(dspth, image_path)) as img:
                ori_size = img.size
                image = img.resize((512, 512), Image.BILINEAR)
                image = image.convert("RGB")

            inputs = torch.unsqueeze(to_tensor(image), 0)  # [1, 3, 512, 512]
            outputs = net(inputs.to(device))
            # No augmentation is applied here; the mean over dim 0 just
            # collapses the leading dimension before the class argmax.
            # NOTE(review): assumes the network returns a single tensor with
            # batch first and classes second - confirm against model.py.
            parsing = outputs.mean(0).cpu().numpy().argmax(0)

            stem = image_path[:-4]
            try:
                stem = str(int(stem))
            except ValueError:
                # Non-numeric file name: keep the stem rather than aborting
                # the whole run.
                pass

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, stem + '.png'),
                             img_size=ori_size)
if __name__ == "__main__":
    # Script entry point: read the three path options from the command line
    # (or a config file, via configargparse) and run the evaluation.
    cli = configargparse.ArgumentParser()
    option_table = (
        ('--respath', {'type': str, 'default': './result/',
                       'help': 'result path for label'}),
        ('--imgpath', {'type': str, 'default': './imgs/',
                       'help': 'path for input images'}),
        ('--modelpath', {'type': str,
                         'default': 'data_utils/face_parsing/79999_iter.pth'}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Positional order matches evaluate(respth, dspth, cp).
    evaluate(opts.respath, opts.imgpath, opts.modelpath)