weidai00 commited on
Commit
a0dc447
·
verified ·
1 Parent(s): 3cb6d91

Update AV/config/config_test_general.py

Browse files
Files changed (1) hide show
  1. AV/config/config_test_general.py +137 -134
AV/config/config_test_general.py CHANGED
@@ -1,134 +1,137 @@
1
- import torch
2
- import os
3
-
4
- # Check GPU availability
5
- use_cuda = torch.cuda.is_available()
6
- gpu_ids = [0] if use_cuda else []
7
- device = torch.device('cuda' if use_cuda else 'cpu')
8
-
9
- dataset_name = 'all' # DRIVE
10
- #dataset_name = 'LES' # LES
11
- # dataset_name = 'hrf' # HRF
12
- # dataset_name = 'ukbb' # UKBB
13
- # dataset_name = 'all'
14
- dataset = dataset_name
15
- max_step = 30000 # 30000 for ukbb
16
-
17
- batch_size = 8 # default: 4
18
- print_iter = 100 # default: 100
19
- display_iter = 100 # default: 100
20
- save_iter = 5000 # default: 5000
21
- first_display_metric_iter = max_step - save_iter # default: 25000
22
- lr = 0.0002 # if dataset_name!='LES' else 0.00005 # default: 0.0002
23
- step_size = 7000 # 7000 for DRIVE
24
- lr_decay_gamma = 0.5 # default: 0.5
25
- use_SGD = False # default:False
26
-
27
- input_nc = 3
28
- ndf = 32
29
- netD_type = 'basic'
30
- n_layers_D = 5
31
- norm = 'instance'
32
- no_lsgan = False
33
- init_type = 'normal'
34
- init_gain = 0.02
35
- use_sigmoid = no_lsgan
36
- use_noise_input_D = False
37
- use_dropout_D = False
38
- # torch.cuda.set_device(gpu_ids[0])
39
- use_GAN = True # default: True
40
-
41
- # adam
42
- beta1 = 0.5
43
-
44
- # settings for GAN loss
45
- num_classes_D = 1
46
- lambda_GAN_D = 0.01
47
- lambda_GAN_G = 0.01
48
- lambda_GAN_gp = 100
49
- lambda_BCE = 5
50
- lambda_DICE = 5
51
-
52
- input_nc_D = input_nc + 3
53
-
54
- # settings for centerness
55
- use_centerness = True # default: True
56
- lambda_centerness = 1
57
- center_loss_type = 'centerness'
58
- centerness_map_size = [128, 128]
59
-
60
- # pretrained model
61
- use_pretrained_G = True
62
- use_pretrained_D = False
63
- # model_path_pretrained_G = './log/patch_pretrain'
64
- model_path_pretrained_G = ''
65
- model_step_pretrained_G = 0
66
- stride_height = 0
67
- stride_width = 0
68
- patch_size_list=[]
69
-
70
- def set_dataset(name):
71
- global dataset_name, model_path_pretrained_G, model_step_pretrained_G
72
- global stride_height, stride_width,patch_size,patch_size_list,dataset
73
- dataset_name = name
74
- dataset = name
75
- if dataset_name == 'DRIVE':
76
- model_path_pretrained_G = './AV/log/DRIVE-2023_10_20_08_36_50(6500)'
77
- model_step_pretrained_G = 6500
78
- elif dataset_name == 'LES':
79
- model_path_pretrained_G = './AV/log/LES-2023_09_28_14_04_06(0)'
80
- model_step_pretrained_G = 0
81
- elif dataset_name == 'hrf':
82
- model_path_pretrained_G = './AV/log/HRF-2023_10_19_11_07_31(1500)'
83
- model_step_pretrained_G = 1500
84
- elif dataset_name == 'ukbb':
85
- model_path_pretrained_G = './AV/log/UKBB-2023_11_02_23_22_07(5000)'
86
- model_step_pretrained_G = 5000
87
- else:
88
- model_path_pretrained_G = './AV/log/ALL-2024_09_06_09_17_18(9000)'
89
- model_step_pretrained_G = 9000
90
- if dataset_name == 'DRIVE':
91
- patch_size_list = [64, 128, 256]
92
- elif dataset_name == 'LES':
93
- patch_size_list = [96, 384, 256]
94
- elif dataset_name == 'hrf':
95
- patch_size_list = [64, 384, 256]
96
- elif dataset_name == 'ukbb':
97
- patch_size_list = [96, 384, 256]
98
- else:
99
- patch_size_list = [96, 384, 512]
100
- patch_size = patch_size_list[2]
101
-
102
- # path for dataset
103
- if dataset_name == 'DRIVE' or dataset_name == 'LES' or dataset_name == 'hrf':
104
- stride_height = 50
105
- stride_width = 50
106
- else:
107
- stride_height = 150
108
- stride_width = 150
109
-
110
- n_classes = 3
111
-
112
- model_step = 0
113
-
114
- # use CAM
115
- use_CAM = False
116
-
117
- #use resize
118
- use_resize = True
119
- resize_w_h = (1920,512)
120
-
121
- # use av_cross
122
- use_av_cross = False
123
-
124
- use_high_semantic = False
125
- lambda_high = 1 # A,V,Vessel
126
-
127
- # use global semantic
128
- use_global_semantic = False
129
- global_warmup_step = 0 if use_pretrained_G else 5000
130
-
131
-
132
-
133
-
134
-
 
 
 
 
1
+ import torch
2
+ import os
3
+
4
+ # Check GPU availability
5
+ use_cuda = torch.cuda.is_available()
6
+ gpu_ids = [0] if use_cuda else []
7
+ device = torch.device('cuda' if use_cuda else 'cpu')
8
+
9
+ dataset_name = 'all' # DRIVE
10
+ #dataset_name = 'LES' # LES
11
+ # dataset_name = 'hrf' # HRF
12
+ # dataset_name = 'ukbb' # UKBB
13
+ # dataset_name = 'all'
14
+ dataset = dataset_name
15
+ max_step = 30000 # 30000 for ukbb
16
+
17
+ batch_size = 8 # default: 4
18
+ print_iter = 100 # default: 100
19
+ display_iter = 100 # default: 100
20
+ save_iter = 5000 # default: 5000
21
+ first_display_metric_iter = max_step - save_iter # default: 25000
22
+ lr = 0.0002 # if dataset_name!='LES' else 0.00005 # default: 0.0002
23
+ step_size = 7000 # 7000 for DRIVE
24
+ lr_decay_gamma = 0.5 # default: 0.5
25
+ use_SGD = False # default:False
26
+
27
+ input_nc = 3
28
+ ndf = 32
29
+ netD_type = 'basic'
30
+ n_layers_D = 5
31
+ norm = 'instance'
32
+ no_lsgan = False
33
+ init_type = 'normal'
34
+ init_gain = 0.02
35
+ use_sigmoid = no_lsgan
36
+ use_noise_input_D = False
37
+ use_dropout_D = False
38
+ # torch.cuda.set_device(gpu_ids[0])
39
+ use_GAN = True # default: True
40
+
41
+ # adam
42
+ beta1 = 0.5
43
+
44
+ # settings for GAN loss
45
+ num_classes_D = 1
46
+ lambda_GAN_D = 0.01
47
+ lambda_GAN_G = 0.01
48
+ lambda_GAN_gp = 100
49
+ lambda_BCE = 5
50
+ lambda_DICE = 5
51
+
52
+ input_nc_D = input_nc + 3
53
+
54
+ # settings for centerness
55
+ use_centerness = True # default: True
56
+ lambda_centerness = 1
57
+ center_loss_type = 'centerness'
58
+ centerness_map_size = [128, 128]
59
+
60
+ # pretrained model
61
+ use_pretrained_G = True
62
+ use_pretrained_D = False
63
+ # model_path_pretrained_G = './log/patch_pretrain'
64
+ model_path_pretrained_G = ''
65
+ model_step_pretrained_G = 0
66
+ stride_height = 0
67
+ stride_width = 0
68
+ patch_size_list=[]
69
+ use_CAM = False
70
+
71
+ #use resize
72
+ use_resize = False
73
+ resize_w_h = (1920,512)
74
+ def set_dataset(name):
75
+ global dataset_name, model_path_pretrained_G, model_step_pretrained_G
76
+ global stride_height, stride_width,patch_size,patch_size_list,dataset,use_CAM,use_resize,resize_w_h
77
+ dataset_name = name
78
+ dataset = name
79
+ if dataset_name == 'DRIVE':
80
+ model_path_pretrained_G = './AV/log/DRIVE-2023_10_20_08_36_50(6500)'
81
+ model_step_pretrained_G = 6500
82
+ elif dataset_name == 'LES':
83
+ model_path_pretrained_G = './AV/log/LES-2023_09_28_14_04_06(0)'
84
+ model_step_pretrained_G = 0
85
+ elif dataset_name == 'hrf':
86
+ model_path_pretrained_G = './AV/log/HRF-2023_10_19_11_07_31(1500)'
87
+ model_step_pretrained_G = 1500
88
+ elif dataset_name == 'ukbb':
89
+ model_path_pretrained_G = './AV/log/UKBB-2023_11_02_23_22_07(5000)'
90
+ model_step_pretrained_G = 5000
91
+ else:
92
+ model_path_pretrained_G = './AV/log/ALL-2024_09_06_09_17_18(9000)'
93
+ model_step_pretrained_G = 9000
94
+ if dataset_name == 'DRIVE':
95
+ patch_size_list = [64, 128, 256]
96
+ elif dataset_name == 'LES':
97
+ patch_size_list = [96, 384, 256]
98
+ elif dataset_name == 'hrf':
99
+ patch_size_list = [64, 384, 256]
100
+ elif dataset_name == 'ukbb':
101
+ patch_size_list = [96, 384, 256]
102
+ else:
103
+ patch_size_list = [96, 384, 512]
104
+ patch_size = patch_size_list[2]
105
+
106
+ # path for dataset
107
+ if dataset_name == 'DRIVE' or dataset_name == 'LES' or dataset_name == 'hrf':
108
+ stride_height = 50
109
+ stride_width = 50
110
+ elif dataset_name == 'ukbb':
111
+ stride_height = 150
112
+ stride_width = 150
113
+ else:
114
+ use_CAM=True
115
+ stride_height = 150
116
+ stride_width = 150
117
+
118
+
119
+ n_classes = 3
120
+
121
+ model_step = 0
122
+
123
+
124
+ # use av_cross
125
+ use_av_cross = False
126
+
127
+ use_high_semantic = False
128
+ lambda_high = 1 # A,V,Vessel
129
+
130
+ # use global semantic
131
+ use_global_semantic = False
132
+ global_warmup_step = 0 if use_pretrained_G else 5000
133
+
134
+
135
+
136
+
137
+