# NOTE: page-status residue from the capture ("Spaces: Sleeping"); not part of the program.
# Copyright (C) 2020 * Ltd. All rights reserved.
# author : Sanghyeon Jo <[email protected]>
import gradio as gr | |
import os | |
import sys | |
import copy | |
import shutil | |
import random | |
import argparse | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from torch.utils.tensorboard import SummaryWriter | |
from torch.utils.data import DataLoader | |
from core.puzzle_utils import * | |
from core.networks import * | |
from core.datasets import * | |
from tools.general.io_utils import * | |
from tools.general.time_utils import * | |
from tools.general.json_utils import * | |
from tools.ai.log_utils import * | |
from tools.ai.demo_utils import * | |
from tools.ai.optim_utils import * | |
from tools.ai.torch_utils import * | |
from tools.ai.evaluate_utils import * | |
from tools.ai.augment_utils import * | |
from tools.ai.randaugment import * | |
import PIL.Image | |
# Command-line interface of the demo.
# Every option is declared once in the table below as (flag, default, type)
# and registered on the parser in declaration order.
parser = argparse.ArgumentParser()

_CLI_OPTIONS = (
    ###########################################################################
    # Dataset
    ###########################################################################
    ('--seed', 2606, int),
    ('--num_workers', 4, int),
    ###########################################################################
    # Network
    ###########################################################################
    ('--architecture', 'DeepLabv3+', str),
    ('--backbone', 'resnet50', str),
    ('--mode', 'fix', str),
    ('--use_gn', True, str2bool),
    ###########################################################################
    # Inference parameters
    ###########################################################################
    ('--tag', '', str),
    ('--domain', 'val', str),
    ('--scales', '0.5,1.0,1.5,2.0', str),
    ('--iteration', 10, int),
)

for _flag, _default, _type in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type)
# Names of the 20 foreground classes, in the order used to build ``colors``.
class_names = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# Three-channel colour assigned to every class name (background included).
cmap_dic = {
    "background": [0, 0, 0],
    "aeroplane": [128, 0, 0],
    "bicycle": [0, 128, 0],
    "bird": [128, 128, 0],
    "boat": [0, 0, 128],
    "bottle": [128, 0, 128],
    "bus": [0, 128, 128],
    "car": [128, 128, 128],
    "cat": [64, 0, 0],
    "chair": [192, 0, 0],
    "cow": [64, 128, 0],
    "diningtable": [192, 128, 0],
    "dog": [64, 0, 128],
    "horse": [192, 0, 128],
    "motorbike": [64, 128, 128],
    "person": [192, 128, 128],
    "pottedplant": [0, 64, 0],
    "sheep": [128, 64, 0],
    "sofa": [0, 192, 0],
    "train": [128, 192, 0],
    "tvmonitor": [0, 64, 128],
}

# Colour lookup table: row i is the colour of predicted class index i.
# The network is created with ``meta_dic['classes'] + 1`` outputs, so the table
# needs one more row than ``class_names``; without the background row every
# class is painted with its neighbour's colour and the last index has no row.
# NOTE(review): assumes index 0 is the background class - confirm against the
# label convention used for training.
colors = np.asarray(
    [cmap_dic[class_name] for class_name in ["background", *class_names]]
)
if __name__ == '__main__':
    ###################################################################################
    # Arguments
    ###################################################################################
    args = parser.parse_args()

    model_dir = create_directory('./experiments/models/')
    # NOTE(review): "[email protected]" looks like residue of an e-mail
    # obfuscation filter rather than a real checkpoint name; restore the actual
    # file name before relying on this path.
    model_path = model_dir + 'DeepLabv3+@ResNet-50@[email protected]'

    # Extend the user-supplied tag with the evaluation settings.
    if 'train' in args.domain:
        args.tag += '@train'
    else:
        args.tag += '@' + args.domain

    args.tag += '@scale=%s' % args.scales
    args.tag += '@iteration=%d' % args.iteration

    set_seed(args.seed)

    def log_func(string=''):
        """Print one log line (an empty line when called without arguments)."""
        print(string)

    ###################################################################################
    # Transform, Dataset, DataLoader
    ###################################################################################
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    normalize_fn = Normalize(imagenet_mean, imagenet_std)

    # for mIoU
    meta_dic = read_json('./data/VOC_2012.json')

    ###################################################################################
    # Network
    ###################################################################################
    # One extra output on top of meta_dic['classes'] for every architecture.
    if args.architecture == 'DeepLabv3+':
        model = DeepLabv3_Plus(args.backbone, num_classes=meta_dic['classes'] + 1, mode=args.mode,
                               use_group_norm=args.use_gn)
    elif args.architecture == 'Seg_Model':
        model = Seg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
    elif args.architecture == 'CSeg_Model':
        model = CSeg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
    else:
        # Fail with a clear message instead of a NameError on the unbound `model`.
        raise ValueError('Unsupported architecture: {}'.format(args.architecture))

    model.eval()

    log_func('[i] Architecture is {}'.format(args.architecture))
    log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
    log_func()

    load_model(model, model_path, parallel=False)

    #################################################################################################
    # Evaluation
    #################################################################################################
    eval_timer = Timer()

    # Scale factors applied to the input image at inference time.
    scales = [float(scale) for scale in args.scales.split(',')]

    model.eval()
    eval_timer.tik()
def inference(images, image_size):
    """Run the model on an (image, flipped image) pair and merge both outputs.

    Args:
        images: Batch whose entry 0 is the image and entry 1 the same image
            flipped along its last axis (as built by ``predict_image``).
        image_size: Target size handed to ``resize_for_tensors``.

    Returns:
        The merged logits converted by ``get_numpy_from_tensor`` and transposed
        with axes ``(1, 2, 0)``.
    """
    resized = resize_for_tensors(model(images), image_size)

    # Undo the flip on the second prediction before adding it to the first.
    plain, mirrored = resized[0], resized[1]
    merged = plain + mirrored.flip(-1)

    return get_numpy_from_tensor(merged).transpose((1, 2, 0))
def predict_image(ori_image):
    """Segment one image and return its colour-coded class mask.

    The image is evaluated at every factor in ``scales`` together with a copy
    flipped along the last axis; the per-scale outputs of ``inference`` are
    summed, passed through a softmax and, when ``args.iteration > 0``, refined
    by ``crf_inference`` before the per-pixel argmax is taken.

    Args:
        ori_image: Array accepted by ``PIL.Image.fromarray`` (this function is
            wired to the Gradio ``"image"`` input).

    Returns:
        PIL.Image.Image: The coloured mask converted to mode ``"RGB"``.
    """
    ori_image = PIL.Image.fromarray(ori_image)
    ori_w, ori_h = ori_image.size

    cams_list = []
    with torch.no_grad():
        for scale in scales:
            # `resize` returns a new image, so no defensive copy is needed.
            image = ori_image.resize(
                (round(ori_w * scale), round(ori_h * scale)),
                resample=PIL.Image.BICUBIC,
            )

            image = normalize_fn(image)
            # Move the last axis to the front before handing it to torch.
            image = torch.from_numpy(image.transpose((2, 0, 1)))

            images = torch.stack([image, image.flip(-1)])
            cams_list.append(inference(images, (ori_h, ori_w)))

    preds = np.sum(cams_list, axis=0)
    preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()

    if args.iteration > 0:
        # crf_inference receives the scores with the class axis first.
        preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
        pred_mask = np.argmax(preds, axis=0)
    else:
        pred_mask = np.argmax(preds, axis=-1)

    # NOTE(review): the [..., ::-1] reverses the channel order of the decoded
    # colours - confirm this matches the channel order of `colors`.
    pred_mask = decode_from_colormap(pred_mask, colors)[..., ::-1]

    # Use the module path that is actually imported (`import PIL.Image`).
    return PIL.Image.fromarray(pred_mask.astype(np.uint8)).convert("RGB")
# Web UI: one image in, one image out, served by `predict_image`.
# `predict_image` relies on the objects created in the `__main__` section.
demo = gr.Interface(fn=predict_image, inputs="image", outputs="image")
demo.launch()