# Copyright (C) 2020 * Ltd. All rights reserved.
# author : Sanghyeon Jo
"""Gradio demo: segment one uploaded image with a pretrained segmentation model.

Pipeline per image: multi-scale + horizontal-flip inference, softmax over the
summed logits, optional dense-CRF refinement (``--iteration > 0``), then a
colour-coded label map is returned to the Gradio UI.
"""
import gradio as gr

import os
import sys
import copy
import shutil
import random
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import DataLoader

from core.puzzle_utils import *
from core.networks import *
from core.datasets import *

from tools.general.io_utils import *
from tools.general.time_utils import *
from tools.general.json_utils import *

from tools.ai.log_utils import *
from tools.ai.demo_utils import *
from tools.ai.optim_utils import *
from tools.ai.torch_utils import *
from tools.ai.evaluate_utils import *

from tools.ai.augment_utils import *
from tools.ai.randaugment import *

import PIL.Image

parser = argparse.ArgumentParser()

###############################################################################
# Dataset
###############################################################################
parser.add_argument('--seed', default=2606, type=int)
parser.add_argument('--num_workers', default=4, type=int)

###############################################################################
# Network
###############################################################################
parser.add_argument('--architecture', default='DeepLabv3+', type=str)
parser.add_argument('--backbone', default='resnet50', type=str)
parser.add_argument('--mode', default='fix', type=str)
parser.add_argument('--use_gn', default=True, type=str2bool)
# Empty string -> use the default checkpoint under ./experiments/models/.
parser.add_argument('--model_path', default='', type=str)

###############################################################################
# Inference parameters
###############################################################################
parser.add_argument('--tag', default='', type=str)
parser.add_argument('--domain', default='val', type=str)

parser.add_argument('--scales', default='0.5,1.0,1.5,2.0', type=str)
# Number of dense-CRF iterations; 0 disables CRF refinement.
parser.add_argument('--iteration', default=10, type=int)

# Foreground class names; label i + 1 in the model output is class_names[i].
class_names = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
    "train", "tvmonitor"
]

# RGB colour per label name.
cmap_dic = {
    "background": [0, 0, 0],
    "aeroplane": [128, 0, 0],
    "bicycle": [0, 128, 0],
    "bird": [128, 128, 0],
    "boat": [0, 0, 128],
    "bottle": [128, 0, 128],
    "bus": [0, 128, 128],
    "car": [128, 128, 128],
    "cat": [64, 0, 0],
    "chair": [192, 0, 0],
    "cow": [64, 128, 0],
    "diningtable": [192, 128, 0],
    "dog": [64, 0, 128],
    "horse": [192, 0, 128],
    "motorbike": [64, 128, 128],
    "person": [192, 128, 128],
    "pottedplant": [0, 64, 0],
    "sheep": [128, 64, 0],
    "sofa": [0, 192, 0],
    "train": [128, 192, 0],
    "tvmonitor": [0, 64, 128],
}

# Palette indexed directly by predicted label. The model predicts
# meta_dic['classes'] + 1 labels, with label 0 assumed to be background, so
# background must be row 0. Without it every colour is shifted by one and the
# last label indexes past the end of the palette.
colors = np.asarray([cmap_dic[class_name] for class_name in ['background'] + class_names])

if __name__ == '__main__':
    ###################################################################################
    # Arguments
    ###################################################################################
    args = parser.parse_args()

    model_dir = create_directory('./experiments/models/')
    model_path = args.model_path or (model_dir + 'DeepLabv3+@ResNet-50@Fix@GN.pth')

    set_seed(args.seed)
    log_func = lambda string='': print(string)

    ###################################################################################
    # Transform, Dataset, DataLoader
    ###################################################################################
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    normalize_fn = Normalize(imagenet_mean, imagenet_std)

    # Dataset metadata; only 'classes' (number of foreground classes) is used here.
    meta_dic = read_json('./data/VOC_2012.json')
    num_classes = meta_dic['classes'] + 1  # foreground classes + background

    ###################################################################################
    # Network
    ###################################################################################
    if args.architecture == 'DeepLabv3+':
        model = DeepLabv3_Plus(args.backbone, num_classes=num_classes, mode=args.mode, use_group_norm=args.use_gn)
    elif args.architecture == 'Seg_Model':
        model = Seg_Model(args.backbone, num_classes=num_classes)
    elif args.architecture == 'CSeg_Model':
        model = CSeg_Model(args.backbone, num_classes=num_classes)
    else:
        # Fail loudly instead of hitting a NameError on `model` below.
        raise ValueError(f'Unsupported architecture: {args.architecture}')

    log_func('[i] Architecture is {}'.format(args.architecture))
    log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
    log_func()

    load_model(model, model_path, parallel=False)
    model.eval()

    #################################################################################################
    # Evaluation
    #################################################################################################
    scales = [float(scale) for scale in args.scales.split(',')]

    def inference(images, image_size):
        """Run the model on an [image, h-flipped image] batch and merge the two.

        The flipped prediction is flipped back and added to the unflipped one,
        after both are resized to ``image_size`` (height, width).

        Returns:
            numpy array of logits with the class axis last (the model output
            is assumed to be (2, C, h, w) — TODO confirm for all architectures).
        """
        logits = model(images)
        logits = resize_for_tensors(logits, image_size)
        logits = logits[0] + logits[1].flip(-1)
        return get_numpy_from_tensor(logits).transpose((1, 2, 0))

    def predict_image(ori_image):
        """Segment one image and return a colour-coded label map.

        Args:
            ori_image: image array handed over by Gradio's "image" input
                (presumably H x W x 3 uint8 RGB — verify against the Gradio
                version in use), or None when nothing was uploaded.

        Returns:
            An RGB ``PIL.Image`` of the same size as the input, or None if
            there was no input.
        """
        if ori_image is None:
            return None

        ori_image = PIL.Image.fromarray(ori_image)
        ori_w, ori_h = ori_image.size

        with torch.no_grad():
            cams_list = []
            for scale in scales:
                # max(1, ...) keeps very small inputs from resizing to a 0-pixel side.
                scaled_size = (max(1, round(ori_w * scale)), max(1, round(ori_h * scale)))
                # PIL's resize returns a new image, so the original is left untouched.
                image = ori_image.resize(scaled_size, resample=PIL.Image.BICUBIC)
                image = normalize_fn(image)
                image = torch.from_numpy(image.transpose((2, 0, 1)))

                images = torch.stack([image, image.flip(-1)])
                cams_list.append(inference(images, (ori_h, ori_w)))

            # Sum logits over scales, then softmax over the class axis (last).
            preds = np.sum(cams_list, axis=0)
            preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()

        if args.iteration > 0:
            # crf_inference takes class-first probabilities and returns class-first output.
            preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
            pred_mask = np.argmax(preds, axis=0)
        else:
            pred_mask = np.argmax(preds, axis=-1)

        # `colors` is RGB and PIL expects RGB, so no [..., ::-1] channel flip:
        # that flip is only appropriate when writing through OpenCV (BGR).
        colored_mask = decode_from_colormap(pred_mask, colors)
        return PIL.Image.fromarray(colored_mask.astype(np.uint8)).convert("RGB")

    demo = gr.Interface(
        fn=predict_image,
        inputs="image",
        outputs="image",
    )
    demo.launch()