kittendev committed on
Commit 312317c · verified · 1 Parent(s): 270b80e

Delete app.py

Files changed (1)
  1. app.py +0 -176
app.py DELETED
@@ -1,176 +0,0 @@
- # Copyright (C) 2020 * Ltd. All rights reserved.
- # author : Sanghyeon Jo <[email protected]>
-
- import gradio as gr
-
- import os
- import sys
- import copy
- import shutil
- import random
- import argparse
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from torchvision import transforms
- from torch.utils.tensorboard import SummaryWriter
-
- from torch.utils.data import DataLoader
-
- from core.puzzle_utils import *
- from core.networks import *
- from core.datasets import *
-
- from tools.general.io_utils import *
- from tools.general.time_utils import *
- from tools.general.json_utils import *
-
- from tools.ai.log_utils import *
- from tools.ai.demo_utils import *
- from tools.ai.optim_utils import *
- from tools.ai.torch_utils import *
- from tools.ai.evaluate_utils import *
-
- from tools.ai.augment_utils import *
- from tools.ai.randaugment import *
-
- parser = argparse.ArgumentParser()
-
- ###############################################################################
- # Dataset
- ###############################################################################
- parser.add_argument('--seed', default=2606, type=int)
- parser.add_argument('--num_workers', default=4, type=int)
- parser.add_argument('--data_dir', default='../VOCtrainval_11-May-2012/', type=str)
-
- ###############################################################################
- # Network
- ###############################################################################
- parser.add_argument('--architecture', default='DeepLabv3+', type=str)
- parser.add_argument('--backbone', default='resnet50', type=str)
- parser.add_argument('--mode', default='fix', type=str)
- parser.add_argument('--use_gn', default=True, type=str2bool)
-
- ###############################################################################
- # Inference parameters
- ###############################################################################
- parser.add_argument('--tag', default='', type=str)
-
- parser.add_argument('--domain', default='val', type=str)
-
- parser.add_argument('--scales', default='0.5,1.0,1.5,2.0', type=str)
- parser.add_argument('--iteration', default=10, type=int)
-
- if __name__ == '__main__':
-     ###################################################################################
-     # Arguments
-     ###################################################################################
-     args = parser.parse_args()
-
-     model_dir = create_directory('./experiments/models/')
-     model_path = model_dir + f'DeepLabv3+@ResNet-50@[email protected]'
-
-     if 'train' in args.domain:
-         args.tag += '@train'
-     else:
-         args.tag += '@' + args.domain
-
-     args.tag += '@scale=%s' % args.scales
-     args.tag += '@iteration=%d' % args.iteration
-
-     set_seed(args.seed)
-     log_func = lambda string='': print(string)
-
-     ###################################################################################
-     # Transform, Dataset, DataLoader
-     ###################################################################################
-     imagenet_mean = [0.485, 0.456, 0.406]
-     imagenet_std = [0.229, 0.224, 0.225]
-
-     normalize_fn = Normalize(imagenet_mean, imagenet_std)
-
-     # for mIoU
-     meta_dic = read_json('./data/VOC_2012.json')
-
-     ###################################################################################
-     # Network
-     ###################################################################################
-     if args.architecture == 'DeepLabv3+':
-         model = DeepLabv3_Plus(args.backbone, num_classes=meta_dic['classes'] + 1, mode=args.mode,
-                                use_group_norm=args.use_gn)
-     elif args.architecture == 'Seg_Model':
-         model = Seg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
-     elif args.architecture == 'CSeg_Model':
-         model = CSeg_Model(args.backbone, num_classes=meta_dic['classes'] + 1)
-
-     model = model.cuda()
-     model.eval()
-
-     log_func('[i] Architecture is {}'.format(args.architecture))
-     log_func('[i] Total Params: %.2fM' % (calculate_parameters(model)))
-     log_func()
-
-     load_model(model, model_path, parallel=False)
-
-     #################################################################################################
-     # Evaluation
-     #################################################################################################
-     eval_timer = Timer()
-     scales = [float(scale) for scale in args.scales.split(',')]
-
-     model.eval()
-     eval_timer.tik()
-
-
-     def inference(images, image_size):
-         images = images.cuda()
-
-         logits = model(images)
-         logits = resize_for_tensors(logits, image_size)
-
-         logits = logits[0] + logits[1].flip(-1)
-         logits = get_numpy_from_tensor(logits).transpose((1, 2, 0))
-         return logits
-
-
-     def predict_image(ori_image):
-         with torch.no_grad():
-             ori_w, ori_h = ori_image.size
-
-             cams_list = []
-
-             for scale in scales:
-                 image = copy.deepcopy(ori_image)
-                 image = image.resize((round(ori_w * scale), round(ori_h * scale)), resample=PIL.Image.BICUBIC)
-
-                 image = normalize_fn(image)
-                 image = image.transpose((2, 0, 1))
-
-                 image = torch.from_numpy(image)
-                 flipped_image = image.flip(-1)
-
-                 images = torch.stack([image, flipped_image])
-
-                 cams = inference(images, (ori_h, ori_w))
-                 cams_list.append(cams)
-
-             preds = np.sum(cams_list, axis=0)
-             preds = F.softmax(torch.from_numpy(preds), dim=-1).numpy()
-
-             if args.iteration > 0:
-                 preds = crf_inference(np.asarray(ori_image), preds.transpose((2, 0, 1)), t=args.iteration)
-                 pred_mask = np.argmax(preds, axis=0)
-             else:
-                 pred_mask = np.argmax(preds, axis=-1)
-
-             return pred_mask.astype(np.uint8)
-
-
-     demo = gr.Interface(
-         fn=predict_image,
-         inputs="image",
-         outputs="image"
-     )
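
Note on the deleted script: as written it builds demo = gr.Interface(...) but never calls demo.launch(), and predict_image references PIL.Image.BICUBIC without importing PIL, even though it treats its input as a PIL image (.size, .resize). Below is a minimal sketch of how the removed interface would presumably be wired up and started when run locally; it assumes Gradio's standard Interface/Image API and reuses the predict_image handler defined in the deleted file, so it is an illustration rather than the author's original code.

    import PIL.Image      # predict_image uses PIL.Image.BICUBIC; this import was missing in app.py
    import gradio as gr

    # predict_image is assumed to be the handler defined in the deleted app.py above.
    demo = gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil"),   # pass the handler a PIL image, matching its .size/.resize calls
        outputs="image",               # the returned uint8 class-index mask is rendered as an image
    )

    if __name__ == '__main__':
        demo.launch()                  # start the Gradio server; without this the script never serves requests

Whether the original shortcut inputs="image" delivered a PIL image depends on the Gradio version's default; gr.Image(type="pil") makes that expectation explicit.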