Spaces:
Sleeping
Sleeping
File size: 6,571 Bytes
427d150 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
Experiment configuration file
Extended from the config file of the original PANet repository
"""
import os
import re
import glob
import itertools
import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
from platform import node
from datetime import datetime
from util.consts import IMG_SIZE
# Global Sacred settings: keep the config mutable inside captured functions
# and turn off Sacred's stdout/stderr capturing.
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'

# The experiment object that the config scope and hooks below attach to.
ex = Experiment('mySSL')
ex.captured_out_filter = apply_backspaces_and_linefeeds

# Register every top-level .py file of these folders as an experiment source
# file, so that the code is recorded together with each run.
source_folders = ['.', './dataloaders', './models', './util']
sources_to_save = [
    py_path
    for folder in source_folders
    for py_path in glob.glob(f'{folder}/*.py')
]
for source_file in sources_to_save:
    ex.add_source_file(source_file)
@ex.config
def cfg():
    """Default configurations

    This function is registered as a Sacred config scope (``@ex.config``):
    the local variables assigned below are the experiment's configuration
    entries. Several entries are derived from earlier ones (``lr_milestones``,
    ``feature_hw``, ``model``, ``task``, ``optim``, ``exp_str``), so the
    statement order matters, and any extra helper variable added here would
    show up as a config entry as well.
    """
    seed = 1234
    gpu_id = 0
    mode = 'train' # for now only allows 'train'
    do_validation=False
    num_workers = 4 # 0 for debugging.
    dataset = 'CHAOST2' # i.e. abdominal MRI
    use_coco_init = True # initialize backbone with MS_COCO initialization. Anyway coco does not contain medical images
    ### Training
    n_steps = 100100
    batch_size = 1
    # One milestone every 1000 steps: 1000, 2000, ..., (n_steps // 1000 - 1) * 1000
    # (i.e. up to 99000 with the default n_steps).
    lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
    lr_step_gamma = 0.95
    ignore_label = 255
    print_interval = 100
    save_snapshot_every = 25000
    max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
    epochs=1
    scan_per_load = -1 # number of 3D scans per load, for saving memory. If -1, load the entire dataset into memory
    which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
    input_size = (IMG_SIZE, IMG_SIZE)
    min_fg_data='100' # when training with manual annotations, indicates the number of foreground pixels in a single class single slice. This empirically stabilizes the training process. NOTE: kept as a string, not an int
    label_sets = 0 # which group of labels is taken for training (the rest are for testing)
    curr_cls = "" # choose between rk, lk, spleen and liver
    exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1
    usealign = True # see vanilla PANet
    use_wce = True
    use_dinov2_loss = False
    dice_loss = False
    ### Validation
    z_margin = 0
    eval_fold = 0 # which fold for 5 fold cross validation
    support_idx=[-1] # indicating which scan is used as support in testing.
    val_wsize=2 # L_H, L_W in testing
    n_sup_part = 3 # number of chunks in testing
    use_clahe = False
    use_slice_adapter = False
    adapter_layers=3
    debug=True
    skip_no_organ_slices=True
    # Network
    modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab
    clsname = None # forwarded as model['cls_name'] below
    reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization)
    proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
    # Both entries are computed from input_size[0], so this matches the input
    # only while input_size is square (as it is with the default above).
    feature_hw = [input_size[0]//8, input_size[0]//8] # feature map size, should couple this with backbone in future
    lora = 0
    use_3_slices=False
    do_cca=False
    use_edge_detector=False
    finetune_on_support=False
    sliding_window_confidence_segmentation=False
    finetune_model_on_single_slice=False
    online_finetuning=True
    use_bbox=True # for SAM
    use_points=True # for SAM
    use_mask=False # for SAM
    base_model="alpnet" # or "SAM"
    # SSL
    superpix_scale = 'MIDDLE' #MIDDLE/ LARGE
    use_pos_enc=False
    support_txt_file = None # path to a txt file containing support slices
    augment_support_set=False
    coarse_pred_only=False # for ProtoSAM
    point_mode="both" # for ProtoSAM, choose: both, conf, centroid
    use_neg_points=False
    n_support=1 # num support images
    protosam_sam_ver="sam_h" # or medsam
    grad_accumulation_steps=1
    ttt=False
    reset_after_slice=True # for TTT, whether to reset the model after finetuning on each slice
    # Model-related options gathered from the entries above into one dict.
    model = {
        'align': usealign,
        'dinov2_loss': use_dinov2_loss,
        'use_coco_init': use_coco_init,
        'which_model': modelname,
        'cls_name': clsname,
        'proto_grid_size' : proto_grid_size,
        'feature_hw': feature_hw,
        'reload_model_path': reload_model_path,
        'lora': lora,
        'use_slice_adapter': use_slice_adapter,
        'adapter_layers': adapter_layers,
        'debug': debug,
        'use_pos_enc': use_pos_enc
    }
    # Few-shot task description; 'npart' mirrors n_sup_part.
    task = {
        'n_ways': 1,
        'n_shots': 1,
        'n_queries': 1,
        'npart': n_sup_part
    }
    optim_type = 'sgd'
    lr=1e-3
    momentum=0.9
    weight_decay=0.0005
    # Optimizer hyper-parameters gathered into one dict.
    optim = {
        'lr': lr,
        'momentum': momentum,
        'weight_decay': weight_decay
    }
    exp_prefix = ''
    # Joined with '_', so with the default empty exp_prefix the result starts
    # with an underscore, e.g. '_CHAOST2_sets_0_1shot'.
    exp_str = '_'.join(
        [exp_prefix]
        + [dataset,]
        + [f'sets_{label_sets}_{task["n_shots"]}shot'])
    # 'log_dir' is the root directory for run logs; every other key is a
    # dataset name mapped to a dict holding that dataset's 'data_dir'.
    # The 'C0' entries still hold a placeholder path that must be filled in.
    path = {
        'log_dir': './runs',
        'SABS':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"
        },
        'SABS_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"
        },
        'SABS_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"
        },
        'C0':{'data_dir': "feed your dataset path here"
        },
        'CHAOST2':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"
        },
        'CHAOST2_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"
        },
        'SABS_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"},
        'C0_Superpix':{'data_dir': "feed your dataset path here"},
        'CHAOST2_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"},
        'CHAOST2_Superpix_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"},
        'SABS_Superpix_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"},
        'SABS_Superpix_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"},
    }
@ex.config_hook
def add_observer(config, command_name, logger):
    """Config hook that attaches a file-storage observer to the experiment.

    The observer writes under ``<config['path']['log_dir']>/<ex.path>_<exp_str>``.
    ``command_name`` and ``logger`` are part of the hook signature but are not
    used here. The received ``config`` is returned as-is.
    """
    run_name = f'{ex.path}_{config["exp_str"]}'
    log_root = config['path']['log_dir']
    ex.observers.append(
        FileStorageObserver.create(os.path.join(log_root, run_name))
    )
    return config
|