# Scraped from a Hugging Face Space file view (commit 6c0075d, 3,555 bytes).
import torch
import os
# ---------------------------------------------------------------------------
# Device
# ---------------------------------------------------------------------------
# Check GPU availability
use_cuda = torch.cuda.is_available()
gpu_ids = [0] if use_cuda else []
device = torch.device('cuda' if use_cuda else 'cpu')
# Disabled: torch.cuda.set_device(gpu_ids[0])
# NOTE: gpu_ids is [] without CUDA, so that call would raise IndexError on CPU.

# ---------------------------------------------------------------------------
# Dataset selection
# ---------------------------------------------------------------------------
# Options: 'DRIVE', 'LES', 'hrf', 'ukbb', 'all'.
# set_dataset() rebinds this and the dataset-dependent settings further down.
dataset_name = 'all'
dataset = dataset_name

# ---------------------------------------------------------------------------
# Training schedule
# ---------------------------------------------------------------------------
max_step = 30000  # 30000 for ukbb
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
# = max_step - save_iter, i.e. 25000 with the values above (default: 25000)
first_display_metric_iter = max_step - save_iter
lr = 0.0002  # default: 0.0002 (disabled alternative: 0.00005 for LES)
step_size = 7000  # 7000 for DRIVE
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False

# ---------------------------------------------------------------------------
# Discriminator
# ---------------------------------------------------------------------------
input_nc = 3
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
no_lsgan = False
init_type = 'normal'
init_gain = 0.02
use_sigmoid = no_lsgan  # sigmoid output is tied to disabling LSGAN
use_noise_input_D = False
use_dropout_D = False
use_GAN = True  # default: True
# adam
beta1 = 0.5

# ---------------------------------------------------------------------------
# GAN / segmentation loss weights
# ---------------------------------------------------------------------------
num_classes_D = 1
lambda_GAN_D = 0.01
lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = 5
lambda_DICE = 5
# NOTE(review): the +3 presumably matches n_classes (3) stacked onto the
# input image channels — confirm against the discriminator input.
input_nc_D = input_nc + 3

# ---------------------------------------------------------------------------
# Centerness
# ---------------------------------------------------------------------------
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]

# ---------------------------------------------------------------------------
# Pretrained model
# ---------------------------------------------------------------------------
use_pretrained_G = True
use_pretrained_D = False

# Dataset-dependent settings. These are placeholders until set_dataset()
# is called, which overwrites every one of them.
# (disabled alternative: model_path_pretrained_G = './log/patch_pretrain')
model_path_pretrained_G = ''
model_step_pretrained_G = 0
stride_height = 0
stride_width = 0
patch_size_list = []
# set_dataset() declares patch_size global; without this placeholder,
# reading it before set_dataset() runs raised AttributeError.
patch_size = None
def set_dataset(name):
    """Switch the module-level configuration to the preset for dataset *name*.

    Rebinds these module globals: ``dataset_name``, ``dataset``,
    ``model_path_pretrained_G``, ``model_step_pretrained_G``,
    ``patch_size_list``, ``patch_size``, ``stride_height`` and ``stride_width``.

    Recognised names (compared with ``==``, so case-sensitive) are 'DRIVE',
    'LES', 'hrf' and 'ukbb'. Any other value, e.g. 'all' or 'HRF', gets the
    combined 'ALL' preset. Returns None.
    """
    global dataset_name, model_path_pretrained_G, model_step_pretrained_G
    global stride_height, stride_width, patch_size, patch_size_list, dataset

    # key -> (pretrained-G checkpoint dir, checkpoint step, patch sizes, stride)
    # Built on every call so each call hands out fresh list objects.
    presets = (
        ('DRIVE', ('./AV/log/DRIVE-2023_10_20_08_36_50(6500)', 6500,
                   [64, 128, 256], 50)),
        ('LES', ('./AV/log/LES-2023_09_28_14_04_06(0)', 0,
                 [96, 384, 256], 50)),
        ('hrf', ('./AV/log/HRF-2023_10_19_11_07_31(1500)', 1500,
                 [64, 384, 256], 50)),
        ('ukbb', ('./AV/log/UKBB-2023_11_02_23_22_07(5000)', 5000,
                  [96, 384, 256], 150)),
    )
    # Used for every name not listed above.
    fallback = ('./AV/log/ALL-2024_09_06_09_17_18(9000)', 9000,
                [96, 384, 512], 150)

    dataset_name = name
    dataset = name

    chosen = next((cfg for key, cfg in presets if name == key), fallback)
    model_path_pretrained_G, model_step_pretrained_G, patch_size_list, stride = chosen

    # The working patch size is the third entry of the list.
    patch_size = patch_size_list[2]
    stride_height = stride
    stride_width = stride
# ---------------------------------------------------------------------------
# Output classes / resume step
# ---------------------------------------------------------------------------
n_classes = 3
model_step = 0

# ---------------------------------------------------------------------------
# Optional features
# ---------------------------------------------------------------------------
use_CAM = False  # use CAM

use_resize = True  # use resize
resize_w_h = (1920, 512)  # (w, h)

use_av_cross = False  # use av_cross
use_high_semantic = False
lambda_high = 1  # A, V, Vessel

use_global_semantic = False  # use global semantic
# Warm-up is skipped (0 steps) when a pretrained generator is used.
if use_pretrained_G:
    global_warmup_step = 0
else:
    global_warmup_step = 5000
# (end of scraped file listing)