Spaces:
Runtime error
Runtime error
""" | |
@author: Caglar Aytekin | |
contact: caglar@deepcause.ai | |
""" | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader, TensorDataset | |
from sklearn.metrics import accuracy_score as accuracy | |
from sklearn.metrics import roc_auc_score | |
from torch.optim.lr_scheduler import StepLR | |
import numpy as np | |
import copy | |
class Trainer:
    """Train, validate and evaluate a model for one of three problem types.

    ``problem_type`` selects the loss function and the validation metric used
    to pick the checkpoint that :meth:`train` restores at the end:

    * ``0`` - regression: ``nn.MSELoss``; lowest validation loss wins.
    * ``1`` - binary classification: ``nn.BCEWithLogitsLoss`` (the model is
      expected to emit logits); highest validation ROC AUC wins.
    * ``2`` - multi-class classification: ``nn.CrossEntropyLoss`` with
      class-index targets; highest validation accuracy wins.

    The wrapped model must provide a ``set_alpha(value)`` method, which
    :meth:`train` calls at the start of every epoch.
    """

    def __init__(self, model, X_train, X_val, y_train, y_val, lr, batch_size, epochs, problem_type, verbose=True):
        """Build the optimiser, loss, data loaders and learning-rate schedule.

        Args:
            model: ``torch.nn.Module`` to train; must implement ``set_alpha``.
            X_train, X_val: feature tensors for the training / validation sets.
            y_train, y_val: target tensors aligned with ``X_train`` / ``X_val``.
                For ``problem_type == 2`` they are flattened to 1-D and cast
                to ``long`` class indices.
            lr: Adam learning rate.
            batch_size: training mini-batch size (validation uses one batch).
            epochs: number of training epochs.
            problem_type: 0 (regression), 1 (binary) or 2 (multi-class).
            verbose: print per-epoch metrics during :meth:`train`.

        Raises:
            ValueError: if ``problem_type`` is not 0, 1 or 2.
        """
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.problem_type = problem_type
        self.verbose = verbose
        if problem_type == 0:
            self.criterion = nn.MSELoss()
        elif problem_type == 1:
            self.criterion = nn.BCEWithLogitsLoss()
        elif problem_type == 2:
            self.criterion = nn.CrossEntropyLoss()
            y_train = self._class_targets(y_train)
            y_val = self._class_targets(y_val)
        else:
            # Fail fast instead of hitting an AttributeError on self.criterion later.
            raise ValueError(f"problem_type must be 0, 1 or 2, got {problem_type!r}")
        train_dataset = TensorDataset(X_train, y_train)
        val_dataset = TensorDataset(X_val, y_val)
        self.train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        # The whole validation set is scored as a single batch.
        self.val_loader = DataLoader(dataset=val_dataset, batch_size=len(val_dataset), shuffle=False)
        self.batch_size = batch_size
        self.epochs = epochs
        self.best_metric = float('inf') if problem_type == 0 else float('-inf')
        # Deep copy of the best weights; stays None until train() records one.
        self.best_model = None
        # Decay the LR by 0.2 about three times over training. step_size must be
        # >= 1: StepLR computes last_epoch % step_size, so 0 (epochs < 3) would
        # raise ZeroDivisionError on the first scheduler.step().
        self.scheduler = StepLR(self.optimizer, step_size=max(1, epochs // 3), gamma=0.2)

    @staticmethod
    def _class_targets(y):
        """Flatten class-index targets to 1-D ``long`` for ``CrossEntropyLoss``."""
        # reshape(-1) rather than squeeze(): a single-sample (1, 1) target must
        # stay 1-D, not collapse to a 0-d tensor.
        return y.reshape(-1).long()

    def _count_correct(self, outputs, labels):
        """Return how many predictions in a batch match ``labels`` (0 for regression)."""
        if self.problem_type == 1:
            predicted = torch.round(torch.sigmoid(outputs)).squeeze()
            return (predicted == labels.squeeze()).sum().item()
        if self.problem_type == 2:
            return (torch.max(outputs, 1)[1] == labels.squeeze()).sum().item()
        return 0

    def train_epoch(self):
        """Run one optimisation pass over the training set.

        Returns:
            tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the average
            of the per-batch losses and ``accuracy`` is the fraction of
            correctly classified training samples (always 0.0 for regression).
        """
        self.model.train()
        total_loss = 0.0
        total = 0
        correct = 0
        for inputs, labels in self.train_loader:
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            # An L1 penalty on self.model.causal_discovery() was once tried here (disabled).
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            # Count samples along the batch dimension; len(labels.squeeze())
            # fails for a one-sample batch because squeeze() gives a 0-d tensor.
            total += labels.shape[0]
            correct += self._count_correct(outputs.detach(), labels)
        return total_loss / len(self.train_loader), correct / total

    def _score(self, loader):
        """Evaluate the model on ``loader`` without gradient tracking.

        Returns:
            tuple: ``(mean_loss, accuracy, roc_auc)``. ROC AUC is computed only
            for problem type 1 and accuracy only for types 1 and 2; metrics
            that do not apply are reported as ``0``.
        """
        self.model.eval()
        loss_sum = 0.0
        predictions = []
        targets = []
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = self.model(inputs)
                loss_sum += self.criterion(outputs, labels).item()
                if self.problem_type == 1:
                    predictions.extend(torch.sigmoid(outputs).view(-1).cpu().numpy())
                    targets.extend(labels.view(-1).cpu().numpy())
                elif self.problem_type == 2:
                    predictions.extend(torch.max(outputs, 1)[1].view(-1).cpu().numpy())
                    targets.extend(labels.view(-1).cpu().numpy())
        mean_loss = loss_sum / len(loader)
        if self.problem_type == 1:
            roc_auc = roc_auc_score(targets, predictions)
            acc = accuracy(targets, np.round(predictions))
        elif self.problem_type == 2:
            acc = accuracy(targets, predictions)
            roc_auc = 0
        else:
            acc = 0
            roc_auc = 0
        return mean_loss, acc, roc_auc

    def validate(self):
        """Score the model on the validation set.

        Returns:
            tuple: ``(mean_loss, accuracy, roc_auc)``; see :meth:`_score`.
        """
        return self._score(self.val_loader)

    def _save_checkpoint(self):
        """Record an in-memory deep copy of the model weights in ``self.best_model``."""
        # Per the original author: "Delete data remaining from training".
        self.model.nninput = None
        # NOTE(review): these clear attributes on the Trainer, not on the model;
        # self.model.encodings / self.model.taus may have been intended - confirm.
        self.encodings = None
        self.taus = None
        self.best_model = copy.deepcopy(self.model.state_dict())

    def train(self):
        """Train for ``self.epochs`` epochs, then restore the best checkpoint.

        Each epoch sets the model's alpha, which ramps linearly from 0 to 1 over
        the first tenth of the epochs. It then runs one training pass and
        validates. Whenever the selection metric improves and
        ``epoch > epochs // 10``, the weights are snapshotted. If no snapshot
        was ever taken (e.g. very few epochs, or a NaN validation loss), the
        final weights are kept and recorded as ``best_model`` instead of
        failing.
        """
        for epoch in range(self.epochs):
            # Increase alpha up to 1 over the first tenth of the epochs.
            alpha_now = np.minimum(1.0, float(epoch) / float(self.epochs / 10))
            self.model.set_alpha(alpha_now)
            # Skip checkpointing during the alpha warm-up phase.
            save_permit = epoch > self.epochs // 10
            tr_loss, tr_acc = self.train_epoch()
            val_loss, val_acc, val_roc_auc = self.validate()
            if self.problem_type == 0:
                if self.verbose:
                    print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Val Loss {val_loss:.4f}')
                metric = val_loss
                improved = val_loss < self.best_metric
            elif self.problem_type == 1:
                if self.verbose:
                    print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Train Acc {tr_acc:.4f}, Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}, Val ROC AUC {val_roc_auc:.4f}')
                metric = val_roc_auc
                improved = val_roc_auc > self.best_metric
            else:
                if self.verbose:
                    print(f'Epoch {epoch}: Train Loss {tr_loss:.4f}, Train Acc {tr_acc:.4f}, Val Loss {val_loss:.4f}, Val Acc {val_acc:.4f}')
                metric = val_acc
                improved = val_acc > self.best_metric
            if improved and save_permit:
                self.best_metric = metric
                self._save_checkpoint()
            self.scheduler.step()
        if self.best_model is None:
            # No checkpoint qualified: keep (and record) the final weights.
            self._save_checkpoint()
        else:
            self.model.load_state_dict(self.best_model)

    def evaluate(self, X_test, y_test, verbose=True):
        """Score the model on a held-out set and return the headline metric.

        Args:
            X_test: feature tensor.
            y_test: target tensor in the same format as the training targets
                (converted to 1-D ``long`` for problem type 2, as in ``__init__``).
            verbose: print the returned metric.

        Returns:
            ROC AUC for problem type 1, accuracy for type 2, mean MSE loss for type 0.
        """
        if self.problem_type == 2:
            y_test = self._class_targets(y_test)
        # One unshuffled batch: metrics don't depend on order, and this keeps the
        # result deterministic without consuming the global RNG.
        test_loader = DataLoader(dataset=TensorDataset(X_test, y_test), batch_size=len(y_test), shuffle=False)
        test_loss, test_acc, test_roc_auc = self._score(test_loader)
        if self.problem_type == 1:
            if verbose:
                print('ROC-AUC: ', test_roc_auc)
            return test_roc_auc
        if self.problem_type == 2:
            if verbose:
                print('ACC: ', test_acc)
            return test_acc
        if verbose:
            print('MSE: ', test_loss)
        return test_loss