kittendev committed · verified
Commit 05672b0 · 1 Parent(s): 312317c

Create app.py

Files changed (1): app.py (+177, -0)
app.py ADDED
@@ -0,0 +1,177 @@
# Copyright (C) 2020 * Ltd. All rights reserved.
# author : Sanghyeon Jo <[email protected]>

import gradio as gr

import os
import sys
import copy
import shutil
import random
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

from torch.utils.data import DataLoader

from core.puzzle_utils import *
from core.networks import *
from core.datasets import *

from tools.general.io_utils import *
from tools.general.time_utils import *
from tools.general.json_utils import *

from tools.ai.log_utils import *
from tools.ai.demo_utils import *
from tools.ai.optim_utils import *
from tools.ai.torch_utils import *
from tools.ai.evaluate_utils import *

from tools.ai.augment_utils import *
from tools.ai.randaugment import *

import PIL.Image
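# Note: the wildcard imports above come from the PuzzleCAM-style helper packages bundled
# with this Space; the utilities used below (str2bool, create_directory, read_json,
# set_seed, Normalize, load_model, calculate_parameters, Timer, resize_for_tensors,
# get_numpy_from_tensor, crf_inference) are presumably provided by these modules.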
parser = argparse.ArgumentParser()

###############################################################################
# Dataset
###############################################################################
parser.add_argument('--seed', default=2606, type=int)
parser.add_argument('--num_workers', default=4, type=int)

###############################################################################
# Network
###############################################################################
parser.add_argument('--architecture', default='DeepLabv3+', type=str)
parser.add_argument('--backbone', default='resnet50', type=str)
parser.add_argument('--mode', default='fix', type=str)
parser.add_argument('--use_gn', default=True, type=str2bool)

###############################################################################
# Inference parameters
###############################################################################
parser.add_argument('--tag', default='', type=str)

parser.add_argument('--domain', default='val', type=str)

parser.add_argument('--scales', default='0.5,1.0,1.5,2.0', type=str)
parser.add_argument('--iteration', default=10, type=int)
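# Note: --scales controls the multi-scale test-time augmentation used in predict_image
# below, and --iteration is the number of CRF refinement steps (0 skips the CRF).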
if __name__ == '__main__':
    ###################################################################################
    # Arguments
    ###################################################################################
    args = parser.parse_args()

    model_dir = create_directory('./experiments/models/')
    model_path = model_dir + f'DeepLabv3+@ResNet-50@[email protected]'

    if 'train' in args.domain:
        args.tag += '@train'
    else:
        args.tag += '@' + args.domain

    args.tag += '@scale=%s' % args.scales
    args.tag += '@iteration=%d' % args.iteration

    set_seed(args.seed)
    log_func = lambda string='': print(string)

    ###################################################################################
    # Transform, Dataset, DataLoader
    ###################################################################################
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    normalize_fn = Normalize(imagenet_mean, imagenet_std)

    # for mIoU
    meta_dic = read_json('./data/VOC_2012.json')
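    # ./data/VOC_2012.json presumably stores dataset metadata; meta_dic['classes'] is the
    # number of foreground classes (20 for PASCAL VOC), and the +1 below adds the background class.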
    ###################################################################################
    # Network
    ###################################################################################
    if args.architecture == 'DeepLabv3+':
        model = DeepLabv3_Plus(args.backbone, num_classes=meta_dic['classes'] + 1, mode=args.mode,
                               use_group_norm=args.use_gn)
    elif args.architecture == 'Seg_Model':
        model = Seg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
    elif args.architecture == 'CSeg_Model':
        model = CSeg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)

    model.eval()

    log_func('[i] Architecture is {}'.format(args.architecture))
    log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
    log_func()

    load_model(model, model_path, parallel=False)

    #################################################################################################
    # Evaluation
    #################################################################################################
    eval_timer = Timer()
    scales = [float(scale) for scale in args.scales.split(',')]

    model.eval()
    eval_timer.tik()
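    # inference() below runs the model on a two-image batch (the image and its horizontal
    # flip), resizes the logits to the original resolution, and fuses the two views by
    # flipping the second set of logits back and summing them (flip test-time augmentation).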
    def inference(images, image_size):
        logits = model(images)
        logits = resize_for_tensors(logits, image_size)

        logits = logits[0] + logits[1].flip(-1)
        logits = get_numpy_from_tensor(logits).transpose((1, 2, 0))
        return logits
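    # predict_image() is the Gradio handler: it receives the uploaded image as a NumPy
    # array, runs flip TTA at every scale in --scales, sums the per-scale score maps,
    # converts them to probabilities with a softmax, optionally refines them with CRF
    # post-processing (crf_inference, --iteration steps), and returns the argmax class-index mask.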
    def predict_image(ori_image):
        ori_image = PIL.Image.fromarray(ori_image)
        with torch.no_grad():
            ori_w, ori_h = ori_image.size

            cams_list = []

            for scale in scales:
                image = copy.deepcopy(ori_image)
                image = image.resize((round(ori_w * scale), round(ori_h * scale)), resample=PIL.Image.BICUBIC)

                image = normalize_fn(image)
                image = image.transpose((2, 0, 1))

                image = torch.from_numpy(image)
                flipped_image = image.flip(-1)

                images = torch.stack([image, flipped_image])

                cams = inference(images, (ori_h, ori_w))
                cams_list.append(cams)

            preds = np.sum(cams_list, axis=0)
            preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()

            if args.iteration > 0:
                preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
                pred_mask = np.argmax(preds, axis=0)
            else:
                pred_mask = np.argmax(preds, axis=-1)

            return pred_mask.astype(np.uint8)
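    # A minimal Gradio UI: the "image" shorthand gives a single image-upload input and an
    # image output. Note that predict_image returns raw class indices (0-20), so the
    # rendered mask will likely look near-black unless it is colorized with a palette.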
    demo = gr.Interface(
        fn=predict_image,
        inputs="image",
        outputs="image"
    )

    demo.launch()