import torch
import os

# ---------------------------------------------------------------------------
# Device: check GPU availability
# ---------------------------------------------------------------------------
use_cuda = torch.cuda.is_available()
gpu_ids = [0] if use_cuda else []
device = torch.device('cuda' if use_cuda else 'cpu')
# NOTE(review): torch.cuda.set_device(gpu_ids[0]) was disabled in the original;
# it would raise IndexError on CPU-only hosts because gpu_ids is then empty.

# ---------------------------------------------------------------------------
# Dataset selection.
# set_dataset() recognises 'DRIVE', 'LES', 'hrf' and 'ukbb'; any other name
# (including the default 'all') receives the combined-dataset settings.
# ---------------------------------------------------------------------------
dataset_name = 'all'
dataset = dataset_name

# ---------------------------------------------------------------------------
# Training schedule
# ---------------------------------------------------------------------------
max_step = 30000  # 30000 for ukbb
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
first_display_metric_iter = max_step - save_iter  # default: 25000

# ---------------------------------------------------------------------------
# Optimiser
# ---------------------------------------------------------------------------
lr = 0.0002  # default: 0.0002 (an alternative of 0.00005 for 'LES' was considered)
step_size = 7000  # 7000 for DRIVE
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False

# ---------------------------------------------------------------------------
# Discriminator architecture
# ---------------------------------------------------------------------------
input_nc = 3
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
no_lsgan = False
init_type = 'normal'
init_gain = 0.02
use_sigmoid = no_lsgan  # sigmoid output is tied to the no_lsgan switch
use_noise_input_D = False
use_dropout_D = False

use_GAN = True  # default: True

beta1 = 0.5  # Adam beta1

# ---------------------------------------------------------------------------
# GAN / segmentation loss weights
# ---------------------------------------------------------------------------
num_classes_D = 1
lambda_GAN_D = 0.01
lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = 5
lambda_DICE = 5
# Image channels plus 3 extra channels (presumably the n_classes=3 prediction
# maps -- confirm against the discriminator's caller).
input_nc_D = input_nc + 3

# ---------------------------------------------------------------------------
# Centerness
# ---------------------------------------------------------------------------
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]

# ---------------------------------------------------------------------------
# Pretrained models (generator path/step are filled in by set_dataset())
# ---------------------------------------------------------------------------
use_pretrained_G = True
use_pretrained_D = False
model_path_pretrained_G = ''
model_step_pretrained_G = 0

# ---------------------------------------------------------------------------
# Patch sampling placeholders; set_dataset() overwrites all of these.
# ---------------------------------------------------------------------------
stride_height = 0
stride_width = 0
patch_size_list = []
patch_size = 0  # set_dataset() sets this to patch_size_list[2]
use_CAM = False

# Resizing
use_resize = False
resize_w_h = (1920, 512)

# ---------------------------------------------------------------------------
# Per-dataset tables consumed by set_dataset(). Unlisted names use *_DEFAULT.
# ---------------------------------------------------------------------------
# (pretrained generator directory, pretrained generator step)
_PRETRAINED_G = {
    'DRIVE': ('./AV/log/DRIVE-2023_10_20_08_36_50(6500)', 6500),
    'LES': ('./AV/log/LES-2023_09_28_14_04_06(0)', 0),
    'hrf': ('./AV/log/HRF-2023_10_19_11_07_31(1500)', 1500),
    'ukbb': ('./AV/log/UKBB-2023_11_02_23_22_07(5000)', 5000),
}
_PRETRAINED_G_DEFAULT = ('./AV/log/ALL-2024_09_06_09_17_18(9000)', 9000)

# Three patch sizes; the third one becomes `patch_size`.
_PATCH_SIZES = {
    'DRIVE': (64, 128, 256),
    'LES': (96, 384, 256),
    'hrf': (64, 384, 256),
    'ukbb': (96, 384, 256),
}
_PATCH_SIZES_DEFAULT = (96, 384, 512)

# (stride used for both height and width, use_resize, use_CAM)
_SAMPLING = {
    'DRIVE': (50, False, False),
    'LES': (50, False, False),
    'hrf': (150, True, False),
    'ukbb': (150, True, False),
}
_SAMPLING_DEFAULT = (150, True, True)


def set_dataset(name):
    """Configure the dataset-dependent module globals for dataset *name*.

    Assigns dataset_name, dataset, model_path_pretrained_G,
    model_step_pretrained_G, patch_size_list, patch_size, stride_height,
    stride_width, use_resize and use_CAM. Names other than 'DRIVE', 'LES',
    'hrf' and 'ukbb' (for example 'all') get the combined-dataset settings.

    Every field above is assigned on every call. The result therefore does
    not depend on which dataset was configured before; previously, DRIVE/LES
    kept stale use_resize/use_CAM values left by an earlier call.

    Args:
        name: dataset identifier (case-sensitive).
    """
    global dataset_name, dataset
    global model_path_pretrained_G, model_step_pretrained_G
    global patch_size, patch_size_list
    global stride_height, stride_width, use_CAM, use_resize

    dataset_name = name
    dataset = name

    model_path_pretrained_G, model_step_pretrained_G = _PRETRAINED_G.get(
        name, _PRETRAINED_G_DEFAULT)

    # Build a fresh list each call so callers can mutate it without
    # altering the shared table.
    patch_size_list = list(_PATCH_SIZES.get(name, _PATCH_SIZES_DEFAULT))
    patch_size = patch_size_list[2]

    stride, use_resize, use_CAM = _SAMPLING.get(name, _SAMPLING_DEFAULT)
    stride_height = stride
    stride_width = stride


n_classes = 3

model_step = 0

# AV-cross option
use_av_cross = False

use_high_semantic = False
lambda_high = 1  # A, V, Vessel

# Use global semantics in the local branch; set to False for the Hugging Face setup.
use_global_semantic = False
global_warmup_step = 0 if use_pretrained_G else 5000