import numpy as np
import torch
from tqdm import tqdm
from typing import List
from torchvision.utils import make_grid
from base import BaseTrainer
from utils import inf_loop, sigmoid_rampup
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)
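
# Trainer for ELR+ style training: two peer networks (model1/model2), each paired
# with an exponential-moving-average "teacher" copy, trained with mixup and a
# criterion that maintains a running history of per-example predictions.
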
class Trainer(BaseTrainer):
"""
Trainer class
Note:
Inherited from BaseTrainer.
"""
def __init__(self, model1, model2, model_ema1, model_ema2, train_criterion1, train_criterion2, metrics, optimizer1, optimizer2, config,
data_loader1, data_loader2,
valid_data_loader=None,
test_data_loader=None,
lr_scheduler1=None, lr_scheduler2=None,
len_epoch=None, val_criterion=None,
model_ema1_copy=None, model_ema2_copy=None):
super().__init__(model1, model2, model_ema1, model_ema2, train_criterion1, train_criterion2,
metrics, optimizer1, optimizer2, config, val_criterion, model_ema1_copy, model_ema2_copy)
self.config = config.config
self.data_loader1 = data_loader1
self.data_loader2 = data_loader2
if len_epoch is None:
# epoch-based training
self.len_epoch = len(self.data_loader1)
else:
# iteration-based training
self.data_loader1 = inf_loop(data_loader1)
self.data_loader2 = inf_loop(data_loader2)
self.len_epoch = len_epoch
self.valid_data_loader = valid_data_loader
self.test_data_loader = test_data_loader
self.do_validation = self.valid_data_loader is not None
self.do_test = self.test_data_loader is not None
self.lr_scheduler1 = lr_scheduler1
self.lr_scheduler2 = lr_scheduler2
        # use the raw loader here: in iteration-based mode self.data_loader1 is an
        # inf_loop generator and no longer exposes batch_size
        self.log_step = int(np.sqrt(data_loader1.batch_size))
self.train_loss_list: List[float] = []
self.val_loss_list: List[float] = []
self.test_loss_list: List[float] = []
def _eval_metrics(self, output, target):
acc_metrics = np.zeros(len(self.metrics))
for i, metric in enumerate(self.metrics):
acc_metrics[i] += metric(output, target)
self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
return acc_metrics
def _train_epoch(self, epoch, model, model_ema, model_ema2, data_loader, train_criterion, optimizer, lr_scheduler, device = 'cpu', queue = None):
"""
Training logic for an epoch
:param epoch: Current training epoch.
:return: A log that contains all information you want to save.
Note:
If you have additional information to record, for example:
> additional_log = {"x": x, "y": y}
merge it with log before return. i.e.
> log = {**log, **additional_log}
> return log
The metrics in log must have the key 'metrics'.
"""
model.train()
model_ema.train()
total_loss = 0
total_metrics = np.zeros(len(self.metrics))
total_metrics_ema = np.zeros(len(self.metrics))
if hasattr(data_loader.dataset, 'num_raw_example'):
num_examp = data_loader.dataset.num_raw_example
else:
num_examp = len(data_loader.dataset)
local_step = 0
with tqdm(data_loader) as progress:
for batch_idx, (data, target, indexs, _) in enumerate(progress):
progress.set_description_str(f'Train epoch {epoch}')
data_original = data
target_original = target
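                # Convert integer class labels to one-hot vectors so the targets can be
                # linearly interpolated by mixup below.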
target = torch.zeros(len(target), self.config['num_classes']).scatter_(1, target.view(-1,1), 1)
data, target, target_original = data.to(device), target.float().to(device), target_original.to(device)
data, target, mixup_l, mix_index = self._mixup_data(data, target, alpha = self.config['mixup_alpha'], device = device)
output = model(data)
data_original = data_original.to(device)
output_original = model_ema2(data_original)
output_original = output_original.data.detach()
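                # ELR-style history update: record the EMA model's predictions for these
                # example indices so the criterion can refresh its running target average.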
train_criterion.update_hist(epoch, output_original, indexs.numpy().tolist(), mix_index = mix_index, mixup_l = mixup_l)
local_step += 1
loss, probs = train_criterion(self.global_step + local_step, output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
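                # Mean-teacher style update: the EMA model's weights track an
                # exponential moving average of the student's weights.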
self.update_ema_variables(model, model_ema, self.global_step + local_step, self.config['ema_alpha'])
self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
self.writer.add_scalar('loss', loss.item())
self.train_loss_list.append(loss.item())
total_loss += loss.item()
total_metrics += self._eval_metrics(output, target.argmax(dim=1))
if output_original is not None:
total_metrics_ema += self._eval_metrics(output_original, target.argmax(dim=1))
if batch_idx % self.log_step == 0:
progress.set_postfix_str(' {} Loss: {:.6f}'.format(
self._progress(batch_idx),
loss.item()))
if batch_idx == self.len_epoch:
break
if hasattr(data_loader, 'run'):
data_loader.run()
        log = {
            'global step': self.global_step,
            'local_step': local_step,
            'loss': total_loss / self.len_epoch,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'metrics_ema': (total_metrics_ema / self.len_epoch).tolist(),
            'learning rate': lr_scheduler.get_lr() if lr_scheduler is not None else None
        }
        if lr_scheduler is not None:
            lr_scheduler.step()
if queue is None:
return log
else:
queue.put(log)
def _valid_epoch(self, epoch, model1, model2, device = 'cpu', queue = None):
"""
Validate after training an epoch
:return: A log that contains information about validation
Note:
The validation metrics in log must have the key 'val_metrics'.
"""
model1.eval()
model2.eval()
total_val_loss = 0
total_val_metrics = np.zeros(len(self.metrics))
with torch.no_grad():
with tqdm(self.valid_data_loader) as progress:
for batch_idx, (data, target, _, _) in enumerate(progress):
progress.set_description_str(f'Valid epoch {epoch}')
data, target = data.to(device), target.to(device)
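                    # Ensemble the two peer networks by averaging their predictions.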
output1 = model1(data)
output2 = model2(data)
output = 0.5*(output1 + output2)
loss = self.val_criterion(output, target)
self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
self.writer.add_scalar('loss', loss.item())
self.val_loss_list.append(loss.item())
total_val_loss += loss.item()
total_val_metrics += self._eval_metrics(output, target)
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        # # add histogram of model parameters to the tensorboard
        # for name, p in model1.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')
if queue is None:
return {
'val_loss': total_val_loss / len(self.valid_data_loader),
'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
}
else:
queue.put({
'val_loss': total_val_loss / len(self.valid_data_loader),
'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
})
def _test_epoch(self, epoch, model1, model2, device = 'cpu', queue = None):
"""
Test after training an epoch
:return: A log that contains information about test
Note:
            The test metrics in log must have the key 'test_metrics'.
"""
model1.eval()
model2.eval()
total_test_loss = 0
total_test_metrics = np.zeros(len(self.metrics))
with torch.no_grad():
with tqdm(self.test_data_loader) as progress:
                for batch_idx, (data, target, indexs, _) in enumerate(progress):
progress.set_description_str(f'Test epoch {epoch}')
data, target = data.to(device), target.to(device)
output1 = model1(data)
output2 = model2(data)
output = 0.5*(output1 + output2)
loss = self.val_criterion(output, target)
self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
self.writer.add_scalar('loss', loss.item())
self.test_loss_list.append(loss.item())
total_test_loss += loss.item()
total_test_metrics += self._eval_metrics(output, target)
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
#add histogram of model parameters to the tensorboard
for name, p in model1.named_parameters():
self.writer.add_histogram(name, p, bins='auto')
if queue is None:
return {
'test_loss': total_test_loss / len(self.test_data_loader),
'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
}
else:
queue.put({
'test_loss': total_test_loss / len(self.test_data_loader),
'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
})
def _warmup_epoch(self, epoch, model, data_loader, optimizer, train_criterion, lr_scheduler, device = 'cpu', queue = None):
total_loss = 0
total_metrics = np.zeros(len(self.metrics))
model.train()
with tqdm(data_loader) as progress:
            for batch_idx, (data, target, indexs, _) in enumerate(progress):
progress.set_description_str(f'Train epoch {epoch}')
data, target = data.to(device), target.long().to(device)
optimizer.zero_grad()
output = model(data)
                out_prob = output.data.detach()
                train_criterion.update_hist(epoch, out_prob, indexs.cpu().detach().numpy().tolist())
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
# self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
# self.writer.add_scalar('loss', loss.item())
self.train_loss_list.append(loss.item())
total_loss += loss.item()
total_metrics += self._eval_metrics(output, target)
if batch_idx % self.log_step == 0:
progress.set_postfix_str(' {} Loss: {:.6f}'.format(
self._progress(batch_idx),
loss.item()))
if batch_idx == self.len_epoch:
break
        log = {
            'loss': total_loss / self.len_epoch,
            'noise detection rate': 0.0,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'learning rate': lr_scheduler.get_lr() if lr_scheduler is not None else None
        }
if queue is None:
return log
else:
queue.put(log)
    def eval_train(self, epoch, model_ema2, train_criterion, data_loader, device='cpu'):
        """Record the EMA model's predictions over a pass of the training data so the
        criterion can refresh its per-example target history."""
        model_ema2.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets, indexs, _) in enumerate(data_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                output = model_ema2(inputs).data.detach()
                train_criterion.update_hist(epoch, output, indexs.numpy().tolist())
    def update_ema_variables(self, model, model_ema, global_step, alpha_=0.997):
        # Use the true average until the exponential average is more correct
        if self.config['ema_update']:
            alpha = sigmoid_rampup(global_step + 1, self.config['ema_step']) * alpha_
        else:
            alpha = min(1 - 1 / (global_step + 1), alpha_)
        for ema_param, param in zip(model_ema.parameters(), model.parameters()):
            if alpha_ == 0:
                # a decay of zero makes the teacher mirror the student exactly
                ema_param.data.copy_(param.data)
            else:
                ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
def _progress(self, batch_idx):
base = '[{}/{} ({:.0f}%)]'
if hasattr(self.data_loader1, 'n_samples'):
current = batch_idx * self.data_loader1.batch_size
total = self.data_loader1.n_samples
else:
current = batch_idx
total = self.len_epoch
return base.format(current, total, 100.0 * current / total)
    def _mixup_data(self, x, y, alpha=1.0, device='cpu'):
        '''Returns mixed inputs, mixed targets, lambda, and the mixing permutation'''
        batch_size = x.size(0)
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
            lam = max(lam, 1 - lam)  # keep lambda >= 0.5 so each sample stays dominant in its own mix
            mix_index = torch.randperm(batch_size).to(device)
            mixed_x = lam * x + (1 - lam) * x[mix_index, :]
            mixed_target = lam * y + (1 - lam) * y[mix_index, :]
            return mixed_x, mixed_target, lam, mix_index
        else:
            # no mixing: lambda = 1 and the identity permutation
            lam = 1
            return x, y, lam, torch.arange(batch_size, device=device)
    def _mixup_criterion(self, pred, y_a, y_b, lam, *args):
        loss_a, prob_a, entropy_a = self.train_criterion(pred, y_a, *args)
        loss_b, prob_b, entropy_b = self.train_criterion(pred, y_b, *args)
        return lam * loss_a + (1 - lam) * loss_b, lam * prob_a + (1 - lam) * prob_b, lam * entropy_a + (1 - lam) * entropy_b
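

# A minimal, illustrative sanity check of the mixup arithmetic used above (not part
# of the training pipeline): mixed one-hot targets should remain valid probability
# distributions. The 4-sample batch and lambda value are arbitrary assumptions.
if __name__ == '__main__':
    lam = 0.7
    y = torch.eye(4)                       # one-hot targets for a 4-sample batch
    perm = torch.randperm(4)               # random mixing partner for each sample
    mixed = lam * y + (1 - lam) * y[perm]  # same convex combination as _mixup_data
    assert torch.allclose(mixed.sum(dim=1), torch.ones(4))
    print('mixup targets remain normalized:', mixed.sum(dim=1).tolist())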