# LEURN / TRAINER.py — first commit b9ba714 (CaglarAytekin)
"""
@author: Caglar Aytekin
contact: caglar@deepcause.ai
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score as accuracy
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import StepLR
import numpy as np
import copy
class Trainer:
def __init__(self, model, X_train, X_val, y_train, y_val,lr,batch_size,epochs,problem_type,verbose=True):
self.model = model
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.problem_type=problem_type
self.verbose=verbose
if self.problem_type==0:
self.criterion = nn.MSELoss()
elif self.problem_type==1:
self.criterion = nn.BCEWithLogitsLoss()
elif self.problem_type==2:
self.criterion = nn.CrossEntropyLoss()
y_train=y_train.squeeze().long()
y_val=y_val.squeeze().long()
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
self.train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
self.val_loader = DataLoader(dataset=val_dataset, batch_size=len(val_dataset), shuffle=False)
self.batch_size=batch_size
self.epochs=epochs
self.best_metric = float('inf') if problem_type == 0 else float('-inf')
self.scheduler = StepLR(self.optimizer, step_size=epochs//3, gamma=0.2)
def train_epoch(self):
self.model.train()
total_loss = 0
total=0
correct=0
for inputs, labels in self.train_loader:
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)# + torch.sum(torch.abs(self.model.causal_discovery()))*1
loss.backward()
self.optimizer.step()
total_loss += loss.item()
total += len(labels.squeeze())
if self.problem_type==1:
correct += (torch.round(torch.sigmoid(outputs.data)).squeeze() == labels.squeeze()).sum().item()
elif self.problem_type==2:
correct += (torch.max(outputs.data, 1)[1] == labels.squeeze()).sum().item()
return total_loss/len(self.train_loader) , correct/total
def validate(self):
self.model.eval()
val_loss = 0
total=0
val_predictions = []
val_targets = []
with torch.no_grad():
for inputs, labels in self.val_loader:
outputs = self.model(inputs)
val_loss += self.criterion(outputs, labels).item()
total += len(labels.squeeze())
if self.problem_type==1:
val_predictions.extend(torch.sigmoid(outputs).view(-1).cpu().numpy())
elif self.problem_type==2:
val_predictions.extend(torch.max(outputs.data, 1)[1].view(-1).cpu().numpy())
val_targets.extend(labels.view(-1).cpu().numpy())
if self.problem_type==1:
val_roc_auc =roc_auc_score(val_targets, val_predictions)
val_acc = accuracy(val_targets, np.round(val_predictions))
elif self.problem_type==2:
val_acc = accuracy(val_targets,val_predictions)
val_roc_auc=0
else:
val_roc_auc=0
val_acc=0
return val_loss /len(self.val_loader), val_acc,val_roc_auc
def train(self):
for epoch in range(self.epochs):
#Increase alpha up to 1-tenth of entire epochs
alpha_now=np.minimum(1.0,float(epoch)/float(self.epochs/10))
# print(alpha_now)
self.model.set_alpha(alpha_now)
if epoch>self.epochs//10:
save_permit=True
else:
save_permit=False
tr_loss, tr_acc = self.train_epoch()
val_loss, val_acc , val_roc_auc= self.validate()
if self.problem_type == 0:
if self.verbose:
print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Val Loss {val_loss:.4f}')
if (val_loss < self.best_metric)and(save_permit):
self.best_metric = val_loss
# Save model checkpoint
self.model.nninput=None #Delete data remaining from training
self.encodings=None
self.taus=None
# torch.save(self.model, 'best_model.pth')
# torch.save(self.model.state_dict(), 'best_model_weights.pth')
self.best_model=copy.deepcopy(self.model.state_dict())
# print("Saving model with best validation loss.")
# Problem type 1: Focus on loss, accuracy, and AUC
elif self.problem_type == 1:
if self.verbose:
print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Train Acc {tr_acc:.4f}, Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}, Val ROC AUC {val_roc_auc:.4f}')
if (val_roc_auc > self.best_metric)and(save_permit):
self.best_metric = val_roc_auc
# Save model checkpoint
self.model.nninput=None #Delete data remaining from training
self.encodings=None
self.taus=None
# torch.save(self.model, 'best_model.pth')
# torch.save(self.model.state_dict(), 'best_model_weights.pth')
self.best_model=copy.deepcopy(self.model.state_dict())
# print("Saving model with best validation ROC AUC.")
# Problem type 2: Focus on loss and accuracy
elif self.problem_type == 2:
if self.verbose:
print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Train Acc {tr_acc:.4f}, Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}')
if (val_acc > self.best_metric)and(save_permit):
self.best_metric = val_acc
# Save model checkpoint
self.model.nninput=None #Delete data remaining from training
self.encodings=None
self.taus=None
# torch.save(self.model, 'best_model.pth')
# torch.save(self.model.state_dict(), 'best_model_weights.pth')
self.best_model=copy.deepcopy(self.model.state_dict())
# print("Saving model with best validation accuracy.")
self.scheduler.step()
# Load best validation model
self.model.load_state_dict(self.best_model)
# self.model = torch.load('best_model.pth')
def evaluate(self,X_test, y_test,verbose=True):
test_loader=DataLoader(dataset=TensorDataset(X_test, y_test), batch_size=len(y_test), shuffle=True)
self.model.eval()
test_loss = 0
total=0
test_predictions = []
test_targets = []
with torch.no_grad():
for inputs, labels in test_loader:
outputs = self.model(inputs)
test_loss += self.criterion(outputs, labels).item()
total += len(labels.squeeze())
if self.problem_type==1:
test_predictions.extend(torch.sigmoid(outputs).view(-1).cpu().numpy())
elif self.problem_type==2:
test_predictions.extend(torch.max(outputs.data, 1)[1].view(-1).cpu().numpy())
test_targets.extend(labels.view(-1).cpu().numpy())
if self.problem_type==1:
test_roc_auc =roc_auc_score(test_targets, test_predictions)
test_acc = accuracy(test_targets, np.round(test_predictions))
if verbose:
print('ROC-AUC: ', test_roc_auc)
return test_roc_auc
elif self.problem_type==2:
test_acc = accuracy(test_targets,test_predictions)
test_roc_auc=0
if verbose:
print('ACC: ', test_acc)
return test_acc
else:
test_roc_auc=0
test_acc=0
if verbose:
print('MSE: ', test_loss /len(test_loader))
return test_loss /len(test_loader)