# LoGoSAM_demo / validation.py
# quandn2003's picture
# Upload folder using huggingface_hub
# 427d150 verified
"""
Validation script
"""
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import torch.backends.cudnn as cudnn
import numpy as np
import matplotlib.pyplot as plt
from models.grid_proto_fewshot import FewShotSeg
from dataloaders.dev_customized_med import med_fewshot_val
from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset
from dataloaders.dataset_utils import DATASET_INFO, get_normalize_op
from dataloaders.niftiio import convert_to_sitk
import dataloaders.augutils as myaug
from util.metric import Metric
from util.consts import IMG_SIZE
from util.utils import cca, sliding_window_confidence_segmentation, plot_3d_bar_probabilities, save_pred_gt_fig, plot_heatmap_of_probs
from config_ssl_upload import ex
from tqdm import tqdm
import SimpleITK as sitk
from torchvision.utils import make_grid
from tqdm.auto import tqdm
from util.utils import set_seed, t2n, to01, compose_wt_simple
# Pre-trained model caching: point TORCH_HOME at a folder relative to the
# current working directory.
os.environ.update({'TORCH_HOME': "./pretrained_model"})
def test_time_training(_config, model, image, prediction):
    """Fine-tune ``model`` on a single slice, using its own prediction as label.

    The slice and its predicted mask are stacked channel-wise, and two
    independently augmented copies are drawn at every step: one acts as the
    query (image + target), the other as the support (image + fg/bg masks).
    The loss is the cross-entropy on the query plus the alignment loss
    returned by the model.

    Args:
        _config: experiment configuration; reads ``dataset``, ``use_wce``,
            ``ignore_label``, ``optim_type``, ``optim`` or ``lr``,
            ``lr_milestones``, ``lr_step_gamma``, ``which_aug``,
            ``input_size``, ``n_steps`` and ``grad_accumulation_steps``.
        model: the few-shot segmentation network; updated in place.
        image: numpy array indexed (C, H, W).
        prediction: numpy array indexed (H, W) holding the predicted mask.

    Returns:
        The same ``model`` object, switched back to eval mode.

    Raises:
        NotImplementedError: if ``optim_type`` is neither 'sgd' nor 'adam'.
    """
    model.train()

    dataset_name = _config['dataset']
    class_weights = compose_wt_simple(_config["use_wce"], dataset_name)
    criterion = nn.CrossEntropyLoss(
        ignore_index=_config['ignore_label'], weight=class_weights)

    optim_kind = _config['optim_type']
    if optim_kind == 'adam':
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=_config['lr'], eps=1e-5)
    elif optim_kind == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    else:
        raise NotImplementedError
    optimizer.zero_grad()

    scheduler = MultiStepLR(
        optimizer,
        milestones=_config['lr_milestones'],
        gamma=_config['lr_step_gamma'],
    )

    augment = myaug.transform_with_label(
        {'aug': myaug.get_aug(_config['which_aug'], _config['input_size'][0])})

    # Channel-last stack: image channels first, predicted mask as last channel.
    stacked = np.concatenate(
        [image.transpose(1, 2, 0), prediction[None, ...].transpose(1, 2, 0)],
        axis=-1,
    )

    def _draw_augmented_pair():
        # One random augmentation of the stacked array, split into image/label.
        return augment(stacked, c_img=image.shape[0], c_label=1,
                       nclass=2, use_onehot=False)

    def _as_image_batch(array):
        # NOTE(review): transpose(2, 1, 0) reverses all axes, so the two
        # spatial axes end up swapped relative to `image` (presumably the
        # transform returns (H, W, C)). It is applied consistently to query
        # and support — confirm this is intended.
        tensor = torch.from_numpy(array.transpose(2, 1, 0)).unsqueeze(0)
        return tensor.cuda().float().requires_grad_(True)

    print("Test Time Training...")
    progress = tqdm(range(_config['n_steps']))
    for step in progress:
        # Order matters for the random stream: query first, then support.
        q_img_np, q_lbl_np = _draw_augmented_pair()
        s_img_np, s_lbl_np = _draw_augmented_pair()

        target = torch.from_numpy(q_lbl_np.transpose(2, 1, 0)).cuda().long()
        query_batch = [_as_image_batch(q_img_np)]
        fg_mask = [[
            torch.from_numpy(s_lbl_np.transpose(2, 1, 0))
            .cuda().float().requires_grad_(True)
        ]]
        bg_mask = [[
            torch.from_numpy(1 - s_lbl_np.transpose(2, 1, 0))
            .cuda().float().requires_grad_(True)
        ]]
        support_batch = [[_as_image_batch(s_img_np)]]

        outputs = model(support_batch, fg_mask, bg_mask, query_batch,
                        isval=False, val_wsize=None)
        query_pred, align_loss, _, _, _, _, _ = outputs

        loss = 0.0
        loss += criterion(query_pred.float(), target.long())
        loss += align_loss
        loss.backward()

        # Gradients accumulate across steps; parameters and the LR schedule
        # only advance every `grad_accumulation_steps` iterations.
        if (step + 1) % _config['grad_accumulation_steps'] == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        progress.set_postfix(loss=f"{loss.item():.4f}")

    model.eval()
    return model
def _resolve_baseset(data_name):
    """Map the configured dataset name to ``(baseset_name, max_label)``.

    Raises:
        NotImplementedError: for 'C0_Superpix', which this script rejects.
        ValueError: for any dataset name that is not recognised.
    """
    if data_name in ('SABS_Superpix', 'SABS_Superpix_448', 'SABS_Superpix_672'):
        return 'SABS', 13
    if data_name == 'C0_Superpix':
        raise NotImplementedError(
            f'Dataset: {data_name} is not supported by this validation script')
    if data_name in ('CHAOST2_Superpix', 'CHAOST2_Superpix_672'):
        return 'CHAOST2', 4
    if 'lits' in data_name.lower():
        return 'LITS17', 4
    raise ValueError(f'Dataset: {data_name} not found')


def _clone_state_dict(model):
    """Return a detached copy of every tensor in ``model.state_dict()``.

    ``state_dict()`` hands out references to the live parameters, so a plain
    ``model.state_dict()`` snapshot changes together with the model when an
    optimiser updates the weights in place. Cloning makes a real snapshot.
    """
    return {name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()}


@ex.automain
def main(_run, _config, _log):
    """Run few-shot validation for every unseen label and report the metrics.

    For each test label the support set is fetched once, every query slice is
    segmented (optionally after test-time training on that slice), slice
    predictions are stacked per scan, written as NIfTI files when an observer
    is attached, and Dice / precision / recall are logged.

    Args:
        _run: run object; its ``observers`` and ``experiment_info`` are read
            and ``log_scalar`` is used for metric logging.
        _config: experiment configuration mapping.
        _log: logger used for progress and result messages.

    Returns:
        Always ``1``.

    Raises:
        NotImplementedError: if the configured dataset is 'C0_Superpix'.
        ValueError: if the configured dataset name is not recognised.
    """
    # Directory for the saved predictions; stays None without an observer.
    interm_dir = None
    if _run.observers:
        observer = _run.observers[0]
        interm_dir = f'{observer.dir}/interm_preds'
        os.makedirs(interm_dir, exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{observer.dir}/source/{source_file}'),
                        exist_ok=True)
            observer.save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{observer.basedir}/_sources')

    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    model = FewShotSeg(image_size=_config['input_size'][0],
                       pretrained_path=_config['reload_model_path'],
                       cfg=_config['model'])
    model = model.cuda()
    model.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    baseset_name, max_label = _resolve_baseset(data_name)

    # Labels to evaluate: all labels minus the configured label set.
    test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP']['pa_all'] - \
        DATASET_INFO[baseset_name]['LABEL_GROUP'][_config["label_sets"]]

    _log.info(
        f'###### Labels excluded in training : {list(_config["exclude_cls_list"])} ######')
    _log.info(
        f'###### Unseen labels evaluated in testing: {list(test_labels)} ######')

    if baseset_name == 'SABS':
        # The superpixel dataset is built only to obtain its normalisation
        # function. The conditional arguments below (fix_length, norm_mean,
        # norm_std) all evaluate to None inside this SABS-only branch.
        tr_parent = SuperpixelDataset(  # base dataset
            which_dataset=baseset_name,
            base_dir=_config['path'][data_name]['data_dir'],
            idx_split=_config['eval_fold'],
            mode='val',  # 'train',
            # dummy entry for superpixel dataset
            min_fg=str(_config["min_fg_data"]),
            image_size=_config['input_size'][0],
            transforms=None,
            nsup=_config['task']['n_shots'],
            scan_per_load=_config['scan_per_load'],
            exclude_list=_config["exclude_cls_list"],
            superpix_scale=_config["superpix_scale"],
            fix_length=_config["max_iters_per_load"] if (data_name == 'C0_Superpix') or (
                data_name == 'CHAOST2_Superpix') else None,
            use_clahe=_config['use_clahe'],
            norm_mean=0.18792 * 256 if baseset_name == 'LITS17' else None,
            norm_std=0.25886 * 256 if baseset_name == 'LITS17' else None
        )
        norm_func = tr_parent.norm_func
    else:
        norm_func = get_normalize_op(modality='MR', fids=None)

    te_dataset, te_parent = med_fewshot_val(
        dataset_name=baseset_name,
        base_dir=_config['path'][data_name]['data_dir'],
        idx_split=_config['eval_fold'],
        scan_per_load=_config['scan_per_load'],
        act_labels=test_labels,
        npart=_config['task']['npart'],
        nsup=_config['task']['n_shots'],
        extern_normalize_func=norm_func,
        image_size=_config["input_size"][0],
        use_clahe=_config['use_clahe'],
        use_3_slices=_config["use_3_slices"]
    )

    # dataloaders
    testloader = DataLoader(
        te_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False,
        drop_last=False
    )

    _log.info('###### Set validation nodes ######')
    mar_val_metric_node = Metric(max_label=max_label, n_scans=len(
        te_dataset.dataset.pid_curr_load) - _config['task']['n_shots'])

    _log.info('###### Starting validation ######')
    mar_val_metric_node.reset()
    if _config["sliding_window_confidence_segmentation"]:
        print("Using sliding window confidence segmentation")  # TODO delete this

    save_pred_buffer = {}  # indexed by class

    for curr_lb in test_labels:
        te_dataset.set_curr_cls(curr_lb)
        support_batched = te_parent.get_support(
            curr_class=curr_lb, class_idx=[curr_lb],
            scan_idx=_config["support_idx"], npart=_config['task']['npart'])

        # way(1 for now) x part x shot x 3 x H x W
        support_images = [[shot.cuda() for shot in way]
                          for way in support_batched['support_images']]  # way x part x [shot x C x H x W]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in support_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in support_batched['support_mask']]

        curr_scan_count = -1  # counting for current scan
        _lb_buffer = {}  # indexed by scan

        for idx, sample_batched in enumerate(tqdm(testloader)):
            # we assume batch size for query is 1
            _scan_id = sample_batched["scan_id"][0]
            if _scan_id in te_parent.potential_support_sid:
                # skip the support scan, don't include that to query
                continue

            if sample_batched["is_start"]:
                ii = 0
                curr_scan_count += 1
                print(
                    f"Processing scan {curr_scan_count + 1} / {len(te_dataset.dataset.pid_curr_load)}")
                outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"]
                # original image read by itk: Z, H, W, in prediction we use H, W, Z
                outsize = (_config['input_size'][0],
                           _config['input_size'][1], outsize[0])
                # NaN marks slices that never receive a prediction.
                _pred = np.full(outsize, np.nan)

            # the chunk of query, for assignment with support
            q_part = sample_batched["part_assign"]
            query_images = [sample_batched['image'].cuda()]
            query_labels = torch.cat(
                [sample_batched['label'].cuda()], dim=0)

            # Cheap config check first; the tensor membership test comes last.
            if (_config["skip_no_organ_slices"] and not sample_batched["is_end"]
                    and 1 not in query_labels):
                ii += 1
                continue

            # [way, [part, [shot x C x H x W]]] ->
            # way(1) x shot x [B(1) x C x H x W]
            sup_img_part = [[shot_tensor.unsqueeze(0)
                             for shot_tensor in support_images[0][q_part]]]
            sup_fgm_part = [[shot_tensor.unsqueeze(0)
                             for shot_tensor in support_fg_mask[0][q_part]]]
            sup_bgm_part = [[shot_tensor.unsqueeze(0)
                             for shot_tensor in support_bg_mask[0][q_part]]]

            with torch.no_grad():
                out = model(sup_img_part, sup_fgm_part, sup_bgm_part,
                            query_images, isval=True, val_wsize=_config["val_wsize"])
            query_pred_logits, _, _, assign_mats, proto_grid, _, _ = out

            if _config["ttt"]:
                # The initial prediction serves as pseudo-label for adaptation.
                pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu())
                reset_after_slice = _config["reset_after_slice"]
                # Snapshot BEFORE training, with cloned tensors: a bare
                # state_dict() only references the live weights and would be
                # modified together with them, making the reset a no-op.
                saved_state = _clone_state_dict(model) if reset_after_slice else None

                model = test_time_training(
                    _config, model, sample_batched['image'].numpy()[0], pred)

                # Inference only: no autograd graph is needed here.
                with torch.no_grad():
                    out = model(sup_img_part, sup_fgm_part, sup_bgm_part,
                                query_images, isval=True, val_wsize=_config["val_wsize"])
                query_pred_logits, _, _, assign_mats, proto_grid, _, _ = out

                if reset_after_slice:
                    model.load_state_dict(saved_state)

            query_pred = query_pred_logits.argmax(dim=1).cpu()
            # Resize the label map to the ground-truth resolution.
            query_pred = F.interpolate(
                query_pred.unsqueeze(0).float(),
                size=query_labels.shape[-2:],
                mode='nearest').squeeze(0).long().numpy()[0]

            if _config["debug"]:
                save_pred_gt_fig(query_images, query_pred, query_labels, sup_img_part[0], sup_fgm_part[0][0],
                                 f'debug/preds/scan_{_scan_id}_label_{curr_lb}_{idx}_gt_vs_pred.png')

            if _config['do_cca']:
                query_pred = cca(query_pred, query_pred_logits)
                if _config["debug"]:
                    # NOTE(review): this call passes fewer arguments than the
                    # one above (no support image / mask) — confirm against
                    # the signature of save_pred_gt_fig.
                    save_pred_gt_fig(query_images, query_pred, query_labels,
                                     f'debug/scan_{_scan_id}_label_{curr_lb}_{idx}_gt_vs_pred_after_cca.png')

            _pred[..., ii] = query_pred.copy()

            # Only slices within the z-margin (and not the end slice) are scored.
            if (sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin']) and not sample_batched["is_end"]:
                mar_val_metric_node.record(query_pred, np.array(
                    query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count)

            ii += 1

            # now check data format
            if sample_batched["is_end"]:
                if _config['dataset'] != 'C0':
                    _lb_buffer[_scan_id] = _pred.transpose(
                        2, 0, 1)  # H, W, Z -> to Z H W
                else:
                    _lb_buffer[_scan_id] = _pred

        save_pred_buffer[str(curr_lb)] = _lb_buffer

    # save results
    if interm_dir is None:
        # Without an observer there is no output directory; do not crash
        # before the metrics are reported.
        _log.warning(
            '###### No observer attached, predictions are not saved ######')
    else:
        for curr_lb, _preds in save_pred_buffer.items():
            for _scan_id, _pred in _preds.items():
                # Encode the binary prediction with the label value.
                _pred *= float(curr_lb)
                itk_pred = convert_to_sitk(
                    _pred, te_dataset.dataset.info_by_scan[_scan_id])
                fid = os.path.join(
                    interm_dir, f'scan_{_scan_id}_label_{curr_lb}.nii.gz')
                sitk.WriteImage(itk_pred, fid, True)
                _log.info(f'###### {fid} has been saved ######')

    # compute dice scores by scan
    m_classDice, _, m_meanDice, _, m_rawDice = mar_val_metric_node.get_mDice(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    m_classPrec, _, m_meanPrec, _, m_classRec, _, m_meanRec, _, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    mar_val_metric_node.reset()  # reset this calculation node

    # write validation result to log file
    # NOTE(review): the two recall keys below carry an extra "_al" compared
    # with the other keys; kept unchanged in case consumers rely on them.
    _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist())
    _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist())
    _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist())

    _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist())
    _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist())
    _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist())

    _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist())
    _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec.tolist())
    _run.log_scalar('mar_val_al_batches_rawRec', m_rawRec.tolist())

    _log.info(f'mar_val batches classDice: {m_classDice}')
    _log.info(f'mar_val batches meanDice: {m_meanDice}')

    _log.info(f'mar_val batches classPrec: {m_classPrec}')
    _log.info(f'mar_val batches meanPrec: {m_meanPrec}')

    _log.info(f'mar_val batches classRec: {m_classRec}')
    _log.info(f'mar_val batches meanRec: {m_meanRec}')

    print("============ ============")

    _log.info('End of validation')
    return 1