""" | |
Validation script | |
""" | |
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import torch.backends.cudnn as cudnn
import numpy as np
import matplotlib.pyplot as plt

from models.grid_proto_fewshot import FewShotSeg
from dataloaders.dev_customized_med import med_fewshot_val
from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset
from dataloaders.dataset_utils import DATASET_INFO, get_normalize_op
from dataloaders.niftiio import convert_to_sitk
import dataloaders.augutils as myaug
from util.metric import Metric
from util.consts import IMG_SIZE
from util.utils import cca, sliding_window_confidence_segmentation, plot_3d_bar_probabilities, save_pred_gt_fig, plot_heatmap_of_probs
from util.utils import set_seed, t2n, to01, compose_wt_simple
from config_ssl_upload import ex

import SimpleITK as sitk
from torchvision.utils import make_grid
from tqdm.auto import tqdm

# config pre-trained model caching path
os.environ['TORCH_HOME'] = "./pretrained_model"
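

# Test-time training (TTT): the model is briefly fine-tuned on the query slice
# itself, using its own initial prediction as a pseudo-label. Each step draws two
# independent augmentations of the (image, pseudo-label) pair and treats one as
# the support and the other as the query, so the episodic loss can be computed
# without ground truth.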
def test_time_training(_config, model, image, prediction):
    model.train()
    data_name = _config['dataset']
    my_weight = compose_wt_simple(_config["use_wce"], data_name)
    criterion = nn.CrossEntropyLoss(
        ignore_index=_config['ignore_label'], weight=my_weight)

    if _config['optim_type'] == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    elif _config['optim_type'] == 'adam':
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=_config['lr'], eps=1e-5)
    else:
        raise NotImplementedError
    optimizer.zero_grad()

    scheduler = MultiStepLR(
        optimizer, milestones=_config['lr_milestones'], gamma=_config['lr_step_gamma'])

    tr_transforms = myaug.transform_with_label(
        {'aug': myaug.get_aug(_config['which_aug'], _config['input_size'][0])})

    # stack the image (C, H, W) and its pseudo-label (1, H, W) into one (H, W, C+1)
    # array so that image and label receive identical spatial augmentations
    comp = np.concatenate([image.transpose(1, 2, 0), prediction[None, ...].transpose(1, 2, 0)], axis=-1)

    print("Test Time Training...")
    pbar = tqdm(range(_config['n_steps']))
    for idx in pbar:
        query_image, query_label = tr_transforms(comp, c_img=image.shape[0], c_label=1, nclass=2, use_onehot=False)
        support_image, support_label = tr_transforms(comp, c_img=image.shape[0], c_label=1, nclass=2, use_onehot=False)

        query_label = torch.from_numpy(query_label.transpose(2, 1, 0)).cuda().long()
        query_images = [torch.from_numpy(query_image.transpose(2, 1, 0)).unsqueeze(0).cuda().float().requires_grad_(True)]
        support_fg_mask = [[torch.from_numpy(support_label.transpose(2, 1, 0)).cuda().float().requires_grad_(True)]]
        support_bg_mask = [[torch.from_numpy(1 - support_label.transpose(2, 1, 0)).cuda().float().requires_grad_(True)]]
        support_images = [[torch.from_numpy(support_image.transpose(2, 1, 0)).unsqueeze(0).cuda().float().requires_grad_(True)]]

        # fig, ax = plt.subplots(1, 2)
        # ax[0].imshow(query_images[0][0,0].cpu().numpy())
        # ax[1].imshow(support_image[...,0])
        # ax[1].imshow(support_label[...,0], alpha=0.5)
        # fig.savefig("debug/query_support_ttt.png")

        out = model(support_images, support_fg_mask, support_bg_mask, query_images, isval=False, val_wsize=None)
        query_pred, align_loss, _, _, _, _, _ = out

        # fig, ax = plt.subplots(1, 2)
        # pred = np.array(query_pred.argmax(dim=1)[0].cpu())
        # ax[0].imshow(query_images[0][0,0].cpu().numpy())
        # ax[0].imshow(pred, alpha=0.5)
        # ax[1].imshow(support_image[...,0])
        # ax[1].imshow(support_label[...,0], alpha=0.5)
        # fig.savefig("debug/ttt.png")

        loss = 0.0
        loss += criterion(query_pred.float(), query_label.long())
        loss += align_loss
        loss.backward()

        if (idx + 1) % _config['grad_accumulation_steps'] == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        pbar.set_postfix(loss=f"{loss.item():.4f}")

    model.eval()
    return model
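

# NOTE: it is assumed that this script is run as a Sacred experiment defined in
# config_ssl_upload (imported above as `ex`); the decorator below injects
# _run, _config and _log into main().
@ex.automain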
def main(_run, _config, _log):
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    model = FewShotSeg(image_size=_config['input_size'][0],
                       pretrained_path=_config['reload_model_path'], cfg=_config['model'])
    model = model.cuda()
    model.eval()
    _log.info('###### Load data ######')
    # Training set
    data_name = _config['dataset']
    if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672':
        baseset_name = 'SABS'
        max_label = 13
    elif data_name == 'C0_Superpix':
        raise NotImplementedError
        baseset_name = 'C0'
        max_label = 3
    elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672':
        baseset_name = 'CHAOST2'
        max_label = 4
    elif 'lits' in data_name.lower():
        baseset_name = 'LITS17'
        max_label = 4
    else:
        raise ValueError(f'Dataset: {data_name} not found')

    # unseen test labels = all labels minus the label group used during training
    test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - \
        DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]]
    _log.info(
        f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######')
    _log.info(
        f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######')
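
    # For SABS the superpixel dataset is instantiated only to reuse its normalization
    # function for the validation data; other datasets fall back to a generic MR
    # normalization op.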
    if baseset_name == 'SABS':
        tr_parent = SuperpixelDataset(  # base dataset
            which_dataset=baseset_name,
            base_dir=_config['path'][data_name]['data_dir'],
            idx_split=_config['eval_fold'],
            mode='val',  # 'train',
            # dummy entry for superpixel dataset
            min_fg=str(_config["min_fg_data"]),
            image_size=_config['input_size'][0],
            transforms=None,
            nsup=_config['task']['n_shots'],
            scan_per_load=_config['scan_per_load'],
            exclude_list=_config["exclude_cls_list"],
            superpix_scale=_config["superpix_scale"],
            fix_length=_config["max_iters_per_load"] if (data_name == 'C0_Superpix') or (
                data_name == 'CHAOST2_Superpix') else None,
            use_clahe=_config['use_clahe'],
            norm_mean=0.18792 * 256 if baseset_name == 'LITS17' else None,
            norm_std=0.25886 * 256 if baseset_name == 'LITS17' else None
        )
        norm_func = tr_parent.norm_func
    else:
        norm_func = get_normalize_op(modality='MR', fids=None)
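
    # Validation episodes: med_fewshot_val yields query slices scan by scan, together
    # with the parent dataset object that provides the support scans.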
    te_dataset, te_parent = med_fewshot_val(
        dataset_name=baseset_name,
        base_dir=_config['path'][data_name]['data_dir'],
        idx_split=_config['eval_fold'],
        scan_per_load=_config['scan_per_load'],
        act_labels=test_labels,
        npart=_config['task']['npart'],
        nsup=_config['task']['n_shots'],
        extern_normalize_func=norm_func,
        image_size=_config["input_size"][0],
        use_clahe=_config['use_clahe'],
        use_3_slices=_config["use_3_slices"]
    )

    # dataloaders
    testloader = DataLoader(
        te_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False,
        drop_last=False
    )
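
    # One Metric node tracks Dice / precision / recall per scan; support scans are
    # excluded from the scan count.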
    _log.info('###### Set validation nodes ######')
    mar_val_metric_node = Metric(max_label=max_label, n_scans=len(
        te_dataset.dataset.pid_curr_load) - _config['task']['n_shots'])

    _log.info('###### Starting validation ######')
    mar_val_metric_node.reset()

    if _config["sliding_window_confidence_segmentation"]:
        print("Using sliding window confidence segmentation")  # TODO delete this

    save_pred_buffer = {}  # indexed by class
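
    # Evaluate each unseen class independently: pick its support slices once, then
    # sweep every query scan slice by slice.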
    for curr_lb in test_labels:
        te_dataset.set_curr_cls(curr_lb)
        support_batched = te_parent.get_support(curr_class=curr_lb, class_idx=[
            curr_lb], scan_idx=_config["support_idx"], npart=_config['task']['npart'])

        # way (1 for now) x part x shot x 3 x H x W
        support_images = [[shot.cuda() for shot in way]
                          for way in support_batched['support_images']]  # way x part x [shot x C x H x W]
        suffix = 'mask'
        support_fg_mask = [[shot[f'fg_{suffix}'].float().cuda() for shot in way]
                           for way in support_batched['support_mask']]
        support_bg_mask = [[shot[f'bg_{suffix}'].float().cuda() for shot in way]
                           for way in support_batched['support_mask']]

        curr_scan_count = -1  # counting for current scan
        _lb_buffer = {}  # indexed by scan
        _lb_vis_buffer = {}
        last_qpart = 0  # used as indicator for adding result to buffer
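
        # Query loop: batch size is 1, so each batch is a single 2D slice tagged with
        # its scan id, its axial position (z_id) and its part assignment.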
        for idx, sample_batched in enumerate(tqdm(testloader)):
            # we assume batch size for query is 1
            _scan_id = sample_batched["scan_id"][0]
            if _scan_id in te_parent.potential_support_sid:  # skip the support scan, don't include it in the query set
                continue
            if sample_batched["is_start"]:
                ii = 0
                curr_scan_count += 1
                print(
                    f"Processing scan {curr_scan_count + 1} / {len(te_dataset.dataset.pid_curr_load)}")
                _scan_id = sample_batched["scan_id"][0]
                outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"]
                # original image read by itk: Z, H, W; in prediction we use H, W, Z
                outsize = (_config['input_size'][0],
                           _config['input_size'][1], outsize[0])
                _pred = np.zeros(outsize)
                _pred.fill(np.nan)
                # 'assigned_proto' shows which prototype is assigned to each query pixel,
                # 'proto_grid' holds the ids of the prototypes taken from the support image,
                # 'support_images' are the three support images, and
                # 'support_img_parts' records which support part was used for each query slice
                _vis = {'assigned_proto': [None] * _pred.shape[-1], 'proto_grid': [None] * _pred.shape[-1],
                        'support_images': support_images, 'support_img_parts': [None] * _pred.shape[-1]}

            # the chunk of query, for assignment with support
            q_part = sample_batched["part_assign"]
            query_images = [sample_batched['image'].cuda()]
            query_labels = torch.cat(
                [sample_batched['label'].cuda()], dim=0)

            if 1 not in query_labels and not sample_batched["is_end"] and _config["skip_no_organ_slices"]:
                ii += 1
                continue

            # [way, [part, [shot x C x H x W]]] ->
            # way(1) x shot x [B(1) x C x H x W]
            sup_img_part = [[shot_tensor.unsqueeze(
                0) for shot_tensor in support_images[0][q_part]]]
            sup_fgm_part = [[shot_tensor.unsqueeze(
                0) for shot_tensor in support_fg_mask[0][q_part]]]
            sup_bgm_part = [[shot_tensor.unsqueeze(
                0) for shot_tensor in support_bg_mask[0][q_part]]]
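
            # Initial prediction with gradients disabled; the model returns the query
            # logits together with prototype-assignment info (assign_mats, proto_grid).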
            # query_pred_logits, _, _, assign_mats, proto_grid, _, _ = model(
            #     sup_img_part, sup_fgm_part, sup_bgm_part, query_images, isval=True, val_wsize=_config["val_wsize"], show_viz=True)
            with torch.no_grad():
                out = model(sup_img_part, sup_fgm_part, sup_bgm_part,
                            query_images, isval=True, val_wsize=_config["val_wsize"])
                query_pred_logits, _, _, assign_mats, proto_grid, _, _ = out
                pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu())

            if _config["ttt"]:
                # test-time training on this slice, using the initial prediction as pseudo-label;
                # clone the weights so they can actually be restored afterwards
                # (state_dict() alone returns references that TTT would update in place)
                state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
                model = test_time_training(_config, model, sample_batched['image'].numpy()[0], pred)
                out = model(sup_img_part, sup_fgm_part, sup_bgm_part,
                            query_images, isval=True, val_wsize=_config["val_wsize"])
                query_pred_logits, _, _, assign_mats, proto_grid, _, _ = out
                pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu())
                if _config["reset_after_slice"]:
                    model.load_state_dict(state_dict)
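
            # Post-processing: upsample the argmax mask to the label resolution and
            # optionally apply connected-component analysis (cca) to clean the mask.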
            query_pred = query_pred_logits.argmax(dim=1).cpu()
            query_pred = F.interpolate(query_pred.unsqueeze(
                0).float(), size=query_labels.shape[-2:], mode='nearest').squeeze(0).long().numpy()[0]

            if _config["debug"]:
                save_pred_gt_fig(query_images, query_pred, query_labels, sup_img_part[0], sup_fgm_part[0][0],
                                 f'debug/preds/scan_{_scan_id}_label_{curr_lb}_{idx}_gt_vs_pred.png')

            if _config['do_cca']:
                query_pred = cca(query_pred, query_pred_logits)
                if _config["debug"]:
                    save_pred_gt_fig(query_images, query_pred, query_labels, sup_img_part[0], sup_fgm_part[0][0],
                                     f'debug/scan_{_scan_id}_label_{curr_lb}_{idx}_gt_vs_pred_after_cca.png')

            _pred[..., ii] = query_pred.copy()
            # _vis['assigned_proto'][ii] = assign_mats
            # _vis['proto_grid'][ii] = proto_grid.cpu()
            # proto_ids = torch.unique(proto_grid)
            # _vis['support_img_parts'][ii] = q_part

            # record the slice only if it lies within z_margin of the labelled extent of the organ
            if (sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin']) and not sample_batched["is_end"]:
                mar_val_metric_node.record(query_pred, np.array(
                    query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count)
            else:
                pass

            ii += 1
            # now check data format
            if sample_batched["is_end"]:
                if _config['dataset'] != 'C0':
                    _lb_buffer[_scan_id] = _pred.transpose(
                        2, 0, 1)  # H, W, Z -> Z, H, W
                else:
                    _lb_buffer[_scan_id] = _pred
                # _lb_vis_buffer[_scan_id] = _vis

        save_pred_buffer[str(curr_lb)] = _lb_buffer
    # save results
    for curr_lb, _preds in save_pred_buffer.items():
        for _scan_id, _pred in _preds.items():
            _pred *= float(curr_lb)
            itk_pred = convert_to_sitk(
                _pred, te_dataset.dataset.info_by_scan[_scan_id])
            fid = os.path.join(
                f'{_run.observers[0].dir}/interm_preds', f'scan_{_scan_id}_label_{curr_lb}.nii.gz')
            sitk.WriteImage(itk_pred, fid, True)
            _log.info(f'###### {fid} has been saved ######')
    # compute dice scores by scan
    m_classDice, _, m_meanDice, _, m_rawDice = mar_val_metric_node.get_mDice(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    m_classPrec, _, m_meanPrec, _, m_classRec, _, m_meanRec, _, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    mar_val_metric_node.reset()  # reset this calculation node

    # write validation result to log file
    _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist())
    _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist())
    _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist())
    _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist())
    _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist())
    _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist())
    _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist())
    _run.log_scalar('mar_val_batches_meanRec', m_meanRec.tolist())
    _run.log_scalar('mar_val_batches_rawRec', m_rawRec.tolist())
    _log.info(f'mar_val batches classDice: {m_classDice}')
    _log.info(f'mar_val batches meanDice: {m_meanDice}')
    _log.info(f'mar_val batches classPrec: {m_classPrec}')
    _log.info(f'mar_val batches meanPrec: {m_meanPrec}')
    _log.info(f'mar_val batches classRec: {m_classRec}')
    _log.info(f'mar_val batches meanRec: {m_meanRec}')

    print("============ ============")
    _log.info('End of validation')
    return 1
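

# Example invocation (hypothetical values; the actual config keys and defaults live in
# config_ssl_upload), using Sacred's `with` syntax to override config entries:
#   python <this_script>.py with dataset=CHAOST2_Superpix eval_fold=0 gpu_id=0 \
#       reload_model_path=./exps/my_experiment/snapshots/latest.pth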