Spaces:
Running
Running
Update AV/config/config_test_general.py
Browse files- AV/config/config_test_general.py +137 -134
AV/config/config_test_general.py
CHANGED
@@ -1,134 +1,137 @@
|
|
1 |
-
import torch
|
2 |
-
import os
|
3 |
-
|
4 |
-
# Check GPU availability
|
5 |
-
use_cuda = torch.cuda.is_available()
|
6 |
-
gpu_ids = [0] if use_cuda else []
|
7 |
-
device = torch.device('cuda' if use_cuda else 'cpu')
|
8 |
-
|
9 |
-
dataset_name = 'all' # DRIVE
|
10 |
-
#dataset_name = 'LES' # LES
|
11 |
-
# dataset_name = 'hrf' # HRF
|
12 |
-
# dataset_name = 'ukbb' # UKBB
|
13 |
-
# dataset_name = 'all'
|
14 |
-
dataset = dataset_name
|
15 |
-
max_step = 30000 # 30000 for ukbb
|
16 |
-
|
17 |
-
batch_size = 8 # default: 4
|
18 |
-
print_iter = 100 # default: 100
|
19 |
-
display_iter = 100 # default: 100
|
20 |
-
save_iter = 5000 # default: 5000
|
21 |
-
first_display_metric_iter = max_step - save_iter # default: 25000
|
22 |
-
lr = 0.0002 # if dataset_name!='LES' else 0.00005 # default: 0.0002
|
23 |
-
step_size = 7000 # 7000 for DRIVE
|
24 |
-
lr_decay_gamma = 0.5 # default: 0.5
|
25 |
-
use_SGD = False # default:False
|
26 |
-
|
27 |
-
input_nc = 3
|
28 |
-
ndf = 32
|
29 |
-
netD_type = 'basic'
|
30 |
-
n_layers_D = 5
|
31 |
-
norm = 'instance'
|
32 |
-
no_lsgan = False
|
33 |
-
init_type = 'normal'
|
34 |
-
init_gain = 0.02
|
35 |
-
use_sigmoid = no_lsgan
|
36 |
-
use_noise_input_D = False
|
37 |
-
use_dropout_D = False
|
38 |
-
# torch.cuda.set_device(gpu_ids[0])
|
39 |
-
use_GAN = True # default: True
|
40 |
-
|
41 |
-
# adam
|
42 |
-
beta1 = 0.5
|
43 |
-
|
44 |
-
# settings for GAN loss
|
45 |
-
num_classes_D = 1
|
46 |
-
lambda_GAN_D = 0.01
|
47 |
-
lambda_GAN_G = 0.01
|
48 |
-
lambda_GAN_gp = 100
|
49 |
-
lambda_BCE = 5
|
50 |
-
lambda_DICE = 5
|
51 |
-
|
52 |
-
input_nc_D = input_nc + 3
|
53 |
-
|
54 |
-
# settings for centerness
|
55 |
-
use_centerness = True # default: True
|
56 |
-
lambda_centerness = 1
|
57 |
-
center_loss_type = 'centerness'
|
58 |
-
centerness_map_size = [128, 128]
|
59 |
-
|
60 |
-
# pretrained model
|
61 |
-
use_pretrained_G = True
|
62 |
-
use_pretrained_D = False
|
63 |
-
# model_path_pretrained_G = './log/patch_pretrain'
|
64 |
-
model_path_pretrained_G = ''
|
65 |
-
model_step_pretrained_G = 0
|
66 |
-
stride_height = 0
|
67 |
-
stride_width = 0
|
68 |
-
patch_size_list=[]
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
patch_size_list = [64,
|
96 |
-
elif dataset_name == '
|
97 |
-
patch_size_list = [96, 384, 256]
|
98 |
-
|
99 |
-
patch_size_list = [
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
|
4 |
+
# Check GPU availability
|
5 |
+
use_cuda = torch.cuda.is_available()
|
6 |
+
gpu_ids = [0] if use_cuda else []
|
7 |
+
device = torch.device('cuda' if use_cuda else 'cpu')
|
8 |
+
|
9 |
+
dataset_name = 'all' # DRIVE
|
10 |
+
#dataset_name = 'LES' # LES
|
11 |
+
# dataset_name = 'hrf' # HRF
|
12 |
+
# dataset_name = 'ukbb' # UKBB
|
13 |
+
# dataset_name = 'all'
|
14 |
+
dataset = dataset_name
|
15 |
+
max_step = 30000 # 30000 for ukbb
|
16 |
+
|
17 |
+
batch_size = 8 # default: 4
|
18 |
+
print_iter = 100 # default: 100
|
19 |
+
display_iter = 100 # default: 100
|
20 |
+
save_iter = 5000 # default: 5000
|
21 |
+
first_display_metric_iter = max_step - save_iter # default: 25000
|
22 |
+
lr = 0.0002 # if dataset_name!='LES' else 0.00005 # default: 0.0002
|
23 |
+
step_size = 7000 # 7000 for DRIVE
|
24 |
+
lr_decay_gamma = 0.5 # default: 0.5
|
25 |
+
use_SGD = False # default:False
|
26 |
+
|
27 |
+
input_nc = 3
|
28 |
+
ndf = 32
|
29 |
+
netD_type = 'basic'
|
30 |
+
n_layers_D = 5
|
31 |
+
norm = 'instance'
|
32 |
+
no_lsgan = False
|
33 |
+
init_type = 'normal'
|
34 |
+
init_gain = 0.02
|
35 |
+
use_sigmoid = no_lsgan
|
36 |
+
use_noise_input_D = False
|
37 |
+
use_dropout_D = False
|
38 |
+
# torch.cuda.set_device(gpu_ids[0])
|
39 |
+
use_GAN = True # default: True
|
40 |
+
|
41 |
+
# adam
|
42 |
+
beta1 = 0.5
|
43 |
+
|
44 |
+
# settings for GAN loss
|
45 |
+
num_classes_D = 1
|
46 |
+
lambda_GAN_D = 0.01
|
47 |
+
lambda_GAN_G = 0.01
|
48 |
+
lambda_GAN_gp = 100
|
49 |
+
lambda_BCE = 5
|
50 |
+
lambda_DICE = 5
|
51 |
+
|
52 |
+
input_nc_D = input_nc + 3
|
53 |
+
|
54 |
+
# settings for centerness
|
55 |
+
use_centerness = True # default: True
|
56 |
+
lambda_centerness = 1
|
57 |
+
center_loss_type = 'centerness'
|
58 |
+
centerness_map_size = [128, 128]
|
59 |
+
|
60 |
+
# pretrained model
|
61 |
+
use_pretrained_G = True
|
62 |
+
use_pretrained_D = False
|
63 |
+
# model_path_pretrained_G = './log/patch_pretrain'
|
64 |
+
model_path_pretrained_G = ''
|
65 |
+
model_step_pretrained_G = 0
|
66 |
+
stride_height = 0
|
67 |
+
stride_width = 0
|
68 |
+
patch_size_list=[]
|
69 |
+
use_CAM = False
|
70 |
+
|
71 |
+
#use resize
|
72 |
+
use_resize = False
|
73 |
+
resize_w_h = (1920,512)
|
74 |
+
def set_dataset(name):
|
75 |
+
global dataset_name, model_path_pretrained_G, model_step_pretrained_G
|
76 |
+
global stride_height, stride_width,patch_size,patch_size_list,dataset,use_CAM,use_resize,resize_w_h
|
77 |
+
dataset_name = name
|
78 |
+
dataset = name
|
79 |
+
if dataset_name == 'DRIVE':
|
80 |
+
model_path_pretrained_G = './AV/log/DRIVE-2023_10_20_08_36_50(6500)'
|
81 |
+
model_step_pretrained_G = 6500
|
82 |
+
elif dataset_name == 'LES':
|
83 |
+
model_path_pretrained_G = './AV/log/LES-2023_09_28_14_04_06(0)'
|
84 |
+
model_step_pretrained_G = 0
|
85 |
+
elif dataset_name == 'hrf':
|
86 |
+
model_path_pretrained_G = './AV/log/HRF-2023_10_19_11_07_31(1500)'
|
87 |
+
model_step_pretrained_G = 1500
|
88 |
+
elif dataset_name == 'ukbb':
|
89 |
+
model_path_pretrained_G = './AV/log/UKBB-2023_11_02_23_22_07(5000)'
|
90 |
+
model_step_pretrained_G = 5000
|
91 |
+
else:
|
92 |
+
model_path_pretrained_G = './AV/log/ALL-2024_09_06_09_17_18(9000)'
|
93 |
+
model_step_pretrained_G = 9000
|
94 |
+
if dataset_name == 'DRIVE':
|
95 |
+
patch_size_list = [64, 128, 256]
|
96 |
+
elif dataset_name == 'LES':
|
97 |
+
patch_size_list = [96, 384, 256]
|
98 |
+
elif dataset_name == 'hrf':
|
99 |
+
patch_size_list = [64, 384, 256]
|
100 |
+
elif dataset_name == 'ukbb':
|
101 |
+
patch_size_list = [96, 384, 256]
|
102 |
+
else:
|
103 |
+
patch_size_list = [96, 384, 512]
|
104 |
+
patch_size = patch_size_list[2]
|
105 |
+
|
106 |
+
# path for dataset
|
107 |
+
if dataset_name == 'DRIVE' or dataset_name == 'LES' or dataset_name == 'hrf':
|
108 |
+
stride_height = 50
|
109 |
+
stride_width = 50
|
110 |
+
elif dataset_name == 'ukbb':
|
111 |
+
stride_height = 150
|
112 |
+
stride_width = 150
|
113 |
+
else:
|
114 |
+
use_CAM=True
|
115 |
+
stride_height = 150
|
116 |
+
stride_width = 150
|
117 |
+
|
118 |
+
|
119 |
+
n_classes = 3
|
120 |
+
|
121 |
+
model_step = 0
|
122 |
+
|
123 |
+
|
124 |
+
# use av_cross
|
125 |
+
use_av_cross = False
|
126 |
+
|
127 |
+
use_high_semantic = False
|
128 |
+
lambda_high = 1 # A,V,Vessel
|
129 |
+
|
130 |
+
# use global semantic
|
131 |
+
use_global_semantic = False
|
132 |
+
global_warmup_step = 0 if use_pretrained_G else 5000
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|