"""
Training the model.
Extended from the original implementation of ALPNet.
"""
from scipy.ndimage import distance_transform_edt as eucl_distance | |
import os | |
import shutil | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim | |
from torch.utils.data import DataLoader | |
from torch.optim.lr_scheduler import MultiStepLR | |
import numpy as np | |
from models.grid_proto_fewshot import FewShotSeg | |
from torch.utils.tensorboard import SummaryWriter | |
from dataloaders.dev_customized_med import med_fewshot | |
from dataloaders.GenericSuperDatasetv2 import SuperpixelDataset | |
import dataloaders.augutils as myaug | |
from util.utils import set_seed, t2n, to01, compose_wt_simple | |
from util.metric import Metric | |
from config_ssl_upload import ex | |
from tqdm.auto import tqdm | |
# import Tensor | |
from torch import Tensor | |
from typing import List, Tuple, Union, cast, Iterable, Set, Any, Callable, TypeVar | |
def get_dice_loss(prediction: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss on the foreground channel.

    Args:
        prediction: logits of shape (B, C, H, W) with C >= 1, or (B, H, W).
        target: binary ground-truth mask of shape (B, H, W).
        smooth: additive smoothing term to avoid division by zero.

    Returns:
        Scalar tensor: 1 - mean Dice coefficient over the batch.
    """
    if prediction.dim() == 4:
        if prediction.shape[1] > 1:
            # multi-channel logits: use only the foreground channel
            prediction = prediction[:, 1, :, :]
        else:
            # single-channel: drop the channel dim so the product stays (B, H, W);
            # keeping it would broadcast (B, 1, H, W) * (B, H, W) into (B, B, H, W)
            prediction = prediction[:, 0, :, :]
    prediction = torch.sigmoid(prediction)
    intersection = (prediction * target).sum(dim=(-2, -1))
    # sum both terms over the same trailing spatial dims (the original mixed
    # dim=(-2, -1) and dim=(1, 2), which only coincide for 3-D targets)
    union = prediction.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1)) + smooth
    dice = (2.0 * intersection + smooth) / union
    return 1.0 - dice.mean()
def get_train_transforms(_config):
    """Build the training-time augmentation transform from the experiment config."""
    aug = myaug.get_aug(_config['which_aug'], _config['input_size'][0])
    return myaug.transform_with_label({'aug': aug})
def get_dataset_base_name(data_name):
    """Map a superpixel dataset identifier to its base dataset name.

    Args:
        data_name: dataset key from the experiment config.

    Returns:
        One of 'SABS', 'CHAOST2' or 'LITS17'.

    Raises:
        NotImplementedError: for the unsupported 'C0_Superpix' dataset
            (the original assignment after this raise was unreachable).
        ValueError: if the dataset identifier is unknown.
    """
    if data_name == 'C0_Superpix':
        raise NotImplementedError
    if 'lits' in data_name.lower():
        return 'LITS17'
    base_names = {
        'SABS_Superpix': 'SABS',
        'SABS_Superpix_448': 'SABS',
        'SABS_Superpix_672': 'SABS',
        'CHAOST2_Superpix': 'CHAOST2',
        'CHAOST2_Superpix_672': 'CHAOST2',
    }
    try:
        return base_names[data_name]
    except KeyError:
        raise ValueError(f'Dataset: {data_name} not found') from None
def get_nii_dataset(_config):
    """Construct the superpixel-based training dataset from the experiment config.

    Args:
        _config: experiment configuration dict (sacred config).

    Returns:
        A SuperpixelDataset instance configured for training.
    """
    data_name = _config['dataset']
    baseset_name = get_dataset_base_name(data_name)
    tr_transforms = get_train_transforms(_config)
    tr_parent = SuperpixelDataset(  # base dataset
        which_dataset=baseset_name,
        base_dir=_config['path'][data_name]['data_dir'],
        idx_split=_config['eval_fold'],
        mode='train',
        # dummy entry for superpixel dataset
        min_fg=str(_config["min_fg_data"]),
        image_size=_config["input_size"][0],
        transforms=tr_transforms,
        nsup=_config['task']['n_shots'],
        scan_per_load=_config['scan_per_load'],
        exclude_list=_config["exclude_cls_list"],
        superpix_scale=_config["superpix_scale"],
        # the original conditional selected the same value in both branches,
        # so it collapses to a plain lookup
        fix_length=_config["max_iters_per_load"],
        use_clahe=_config['use_clahe'],
        use_3_slices=_config["use_3_slices"],
        # when 3 neighbouring slices are stacked as channels, do not tile z again
        tile_z_dim=3 if not _config["use_3_slices"] else 1,
    )
    return tr_parent
def get_dataset(_config):
    """Dispatch to the appropriate dataset constructor (currently NIfTI-based only)."""
    dataset = get_nii_dataset(_config)
    return dataset
def main(_run, _config, _log):
    """Train the few-shot segmentation model with episodic supervision.

    Builds the model, dataloader, optimizer and LR schedule from the sacred
    config, runs the training loop with gradient accumulation, logs losses to
    sacred/tensorboard, periodically snapshots weights, and always saves a
    final checkpoint.

    Args:
        _run: sacred Run object (observers[0] provides the output directory).
        _config: sacred experiment configuration dict.
        _log: sacred logger.
    """
    precision = torch.float32
    # anomaly detection is expensive but surfaces NaN/inf gradients early
    torch.autograd.set_detect_anomaly(True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if _run.observers:
        # snapshot directory + a copy of the source files for reproducibility
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'), exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    writer = SummaryWriter(f'{_run.observers[0].dir}/logs')

    _log.info('###### Create model ######')
    if _config['reload_model_path'] != '':
        _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    else:
        # empty string means "train from scratch" downstream
        _config['reload_model_path'] = None
    model = FewShotSeg(image_size=_config['input_size'][0],
                       pretrained_path=_config['reload_model_path'],
                       cfg=_config['model'])
    model = model.to(device, precision)
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    tr_parent = get_dataset(_config)

    # dataloaders
    trainloader = DataLoader(
        tr_parent,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['num_workers'],
        pin_memory=True,
        drop_last=True
    )

    _log.info('###### Set optimizer ######')
    if _config['optim_type'] == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    elif _config['optim_type'] == 'adam':
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=_config['lr'], eps=1e-5)
    else:
        raise NotImplementedError
    scheduler = MultiStepLR(
        optimizer, milestones=_config['lr_milestones'], gamma=_config['lr_step_gamma'])

    my_weight = compose_wt_simple(_config["use_wce"], data_name)
    criterion = nn.CrossEntropyLoss(
        ignore_index=_config['ignore_label'], weight=my_weight)

    i_iter = 0  # total number of iterations
    # number of times for reloading
    n_sub_epoches = max(
        1, _config['n_steps'] // _config['max_iters_per_load'], _config["epochs"])
    log_loss = {'loss': 0, 'align_loss': 0}

    _log.info('###### Training ######')
    epoch_losses = []
    # NOTE(review): only a single sub-epoch is run even though n_sub_epoches is
    # computed above — presumably intentional (n_steps caps training via the
    # break below), but confirm against the experiment config.
    for sub_epoch in range(1):
        print(f"Epoch: {sub_epoch}")
        _log.info(
            f'###### This is epoch {sub_epoch} of {n_sub_epoches} epoches ######')
        # per-epoch batch losses; the original reset this inside the batch loop,
        # so the reported epoch mean was just the last batch's loss
        losses = []
        pbar = tqdm(trainloader)
        optimizer.zero_grad()
        # iterate the same bar that set_postfix updates (the original wrapped
        # the loader in a second tqdm, leaving pbar stuck at 0)
        for idx, sample_batched in enumerate(pbar):
            i_iter += 1
            # episodic batch: ways x shots of support images/masks, plus queries
            support_images = [[shot.to(device, precision) for shot in way]
                              for way in sample_batched['support_images']]
            support_fg_mask = [[shot['fg_mask'].float().to(device, precision) for shot in way]
                               for way in sample_batched['support_mask']]
            support_bg_mask = [[shot['bg_mask'].float().to(device, precision) for shot in way]
                               for way in sample_batched['support_mask']]
            query_images = [query_image.to(device, precision)
                            for query_image in sample_batched['query_images']]
            query_labels = torch.cat(
                [query_label.long().to(device) for query_label in sample_batched['query_labels']], dim=0)

            loss = 0.0
            try:
                out = model(support_images, support_fg_mask, support_bg_mask,
                            query_images, isval=False, val_wsize=None)
                query_pred, align_loss, _, _, _, _, _ = out
            except Exception as e:
                # skip degenerate episodes and release the batch tensors
                print(f'faulty batch detected, skip: {e}')
                del support_images, support_fg_mask, support_bg_mask, query_images, query_labels
                continue

            query_loss = criterion(query_pred.float(), query_labels.long())
            loss += query_loss + align_loss
            pbar.set_postfix({'loss': loss.item()})

            loss.backward()
            # step only every grad_accumulation_steps batches
            if (idx + 1) % _config['grad_accumulation_steps'] == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            losses.append(loss.item())
            query_loss = query_loss.detach().data.cpu().numpy()
            align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
            _run.log_scalar('loss', query_loss)
            _run.log_scalar('align_loss', align_loss)
            log_loss['loss'] += query_loss
            log_loss['align_loss'] += align_loss

            # print running-average loss and reset accumulators
            if (i_iter + 1) % _config['print_interval'] == 0:
                writer.add_scalar('loss', loss, i_iter)
                writer.add_scalar('query_loss', query_loss, i_iter)
                writer.add_scalar('align_loss', align_loss, i_iter)
                loss = log_loss['loss'] / _config['print_interval']
                align_loss = log_loss['align_loss'] / _config['print_interval']
                log_loss['loss'] = 0
                log_loss['align_loss'] = 0
                print(
                    f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss},')

            if (i_iter + 1) % _config['save_snapshot_every'] == 0:
                _log.info('###### Taking snapshot ######')
                torch.save(model.state_dict(),
                           os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))

            if (i_iter - 1) >= _config['n_steps']:
                break  # finish up

        # guard against an epoch in which every batch was faulty
        epoch_loss = np.mean(losses) if losses else float('nan')
        epoch_losses.append(epoch_loss)
        print(f"Epoch {sub_epoch} loss: {epoch_loss}")

    # Save the final model regardless of iteration count
    _log.info('###### Saving final model ######')
    final_save_path = os.path.join(
        f'{_run.observers[0].dir}/snapshots', 'final_model.pth')
    torch.save(model.state_dict(), final_save_path)
    print(f"Final model saved to: {final_save_path}")