import argparse
import gc
import glob
import os
import sys
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets.utils.logging import disable_progress_bar
from sklearn.metrics import mean_squared_error, r2_score
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# Append the utils module path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from generation_utils import prepare_input
from models import ReactionT5Yield
from rdkit import RDLogger
from utils import (
    AverageMeter,
    add_new_tokens,
    canonicalize,
    filter_out,
    get_logger,
    get_optimizer_params,
    seed_everything,
    space_clean,
    timeSince,
)

# Suppress warnings and logging
warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")
disable_progress_bar()
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def parse_args():
    """
    Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed command line options.
    """
    parser = argparse.ArgumentParser(
        description="Training script for ReactionT5Yield model."
    )
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=True,
        help="Path to training data CSV file.",
    )
    parser.add_argument(
        "--valid_data_path",
        type=str,
        required=True,
        help="Path to validation data CSV file.",
    )
    parser.add_argument(
        "--test_data_path",
        type=str,
        help="Path to testing data CSV file.",
    )
    parser.add_argument(
        "--CN_test_data_path",
        type=str,
        help="Path to CN testing data CSV file.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="sagawa/CompoundT5",
        help="Pretrained model name or path.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="The model's name or path used for fine-tuning.",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode.")
    parser.add_argument(
        "--epochs", type=int, default=5, help="Number of training epochs."
    )
    parser.add_argument(
        "--patience", type=int, default=10, help="Early stopping patience."
    )
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size.")
    parser.add_argument(
        "--input_max_length", type=int, default=400, help="Maximum input token length."
    )
    parser.add_argument(
        "--num_workers", type=int, default=4, help="Number of data loading workers."
    )
    parser.add_argument(
        "--fc_dropout",
        type=float,
        default=0.0,
        help="Dropout rate after fully connected layers.",
    )
    parser.add_argument(
        "--eps", type=float, default=1e-6, help="Epsilon for Adam optimizer."
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.05, help="Weight decay for optimizer."
    )
    # A gradient-norm threshold is a real number; parsing it as float keeps the
    # old integer inputs working while also accepting values such as 0.5.
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1000,
        help="Maximum gradient norm for clipping.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of warmup steps."
    )
    parser.add_argument(
        "--batch_scheduler", action="store_true", help="Use batch scheduler."
    )
    parser.add_argument(
        "--print_freq", type=int, default=100, help="Logging frequency."
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="Use automatic mixed precision for training.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory to save the trained model.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--sampling_num",
        type=int,
        default=-1,
        help="Number of samples used for training. If you want to use all samples, set -1.",
    )
    parser.add_argument(
        "--sampling_frac",
        type=float,
        default=-1.0,
        help="Ratio of samples used for training. If you want to use all samples, set -1.0.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Path to the checkpoint file for resuming training.",
    )

    return parser.parse_args()


def preprocess_df(df, cfg, drop_duplicates=True):
    """
    Preprocess the input DataFrame for training.

    The frame passed in is modified in place (columns are added/overwritten)
    before a filtered copy is returned.

    Args:
        df (pd.DataFrame): Input DataFrame.
        cfg (argparse.Namespace): Configuration object. Only ``cfg.debug`` is read.
        drop_duplicates (bool): If True, keep only the first row of every
            distinct ("input", "YIELD") combination.

    Returns:
        pd.DataFrame: Preprocessed DataFrame with an added "input" column.
    """
    if "YIELD" in df.columns:
        # if max yield is 100, then normalize to [0, 1]
        if df["YIELD"].max() >= 100:
            df["YIELD"] = df["YIELD"].clip(0, 100) / 100
    else:
        df["YIELD"] = None

    # Make sure every component column exists; missing values become " ".
    for col in ["REACTANT", "PRODUCT", "CATALYST", "REAGENT", "SOLVENT"]:
        if col not in df.columns:
            df[col] = None
        df[col] = df[col].fillna(" ")

    # The catalyst is folded into the reagent field.
    df["REAGENT"] = df["CATALYST"] + "." + df["REAGENT"]

    for col in ["REAGENT", "REACTANT", "PRODUCT"]:
        df[col] = df[col].apply(lambda x: space_clean(x))
        df[col] = df[col].apply(lambda x: canonicalize(x) if x != " " else " ")
        # Rows for which the previous step produced NaN/None are dropped.
        # NOTE(review): presumably canonicalize() returns None for SMILES it
        # cannot parse -- confirm against utils.canonicalize.
        df = df[~df[col].isna()].reset_index(drop=True)
        # Sort the "."-separated parts so component order does not matter.
        df[col] = df[col].apply(lambda x: ".".join(sorted(x.split("."))))

    df["input"] = (
        "REACTANT:"
        + df["REACTANT"]
        + "REAGENT:"
        + df["REAGENT"]
        + "PRODUCT:"
        + df["PRODUCT"]
    )

    if drop_duplicates:
        df = df.loc[df[["input", "YIELD"]].drop_duplicates().index].reset_index(
            drop=True
        )

    if cfg.debug:
        # Debug mode works on the first 1000 rows only.
        df = df.head(1000)

    return df


def preprocess_CN(df):
    """
    Preprocess the CN test DataFrame.

    Sorts the "."-separated parts of the REACTANT, REAGENT and PRODUCT columns,
    builds the "input" string and copies it into a "pair" column. The frame is
    modified in place and also returned.

    Args:
        df (pd.DataFrame): Input DataFrame.

    Returns:
        pd.DataFrame: Preprocessed DataFrame.
    """
    for col in ["REACTANT", "REAGENT", "PRODUCT"]:
        df[col] = df[col].apply(lambda x: ".".join(sorted(x.split("."))))
    df["input"] = (
        "REACTANT:"
        + df["REACTANT"]
        + "REAGENT:"
        + df["REAGENT"]
        + "PRODUCT:"
        + df["PRODUCT"]
    )
    df["pair"] = df["input"]
    return df


class TrainDataset(Dataset):
    """
    Dataset class for training.

    Each item is a pair ``(inputs, label)`` where ``inputs`` is whatever
    ``prepare_input(cfg, text)`` returns for the row's "input" string and
    ``label`` is the row's "YIELD" value as a float tensor.
    """

    def __init__(self, cfg, df):
        self.cfg = cfg
        self.inputs = df["input"].values
        self.labels = df["YIELD"].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, self.inputs[item])
        label = torch.tensor(self.labels[item], dtype=torch.float)

        return inputs, label


def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """
    Save model checkpoint.

    Args:
        state (dict): Checkpoint state.
        filename (str): Filename to save the checkpoint.
    """
    torch.save(state, filename)


def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, cfg):
    """
    Training function for one epoch.

    Args:
        train_loader (DataLoader): DataLoader for training data.
        model (nn.Module): Model to be trained.
        criterion (nn.Module): Loss function.
        optimizer (Optimizer): Optimizer.
        epoch (int): Current epoch.
        scheduler (Scheduler): Learning rate scheduler.
        cfg (argparse.Namespace): Configuration object.

    Returns:
        float: Average training loss.
    """
    model.train()
    scaler = torch.amp.GradScaler(enabled=cfg.use_amp)
    losses = AverageMeter()
    start = time.time()
    # torch.autocast needs the device *type* as a string ("cuda"/"cpu");
    # cfg.device may be either a torch.device or a string.
    device_type = torch.device(cfg.device).type
    # Reported in the progress line until the first optimizer step has run.
    grad_norm = 0.0

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = {k: v.to(cfg.device) for k, v in inputs.items()}
        labels = labels.to(cfg.device)
        batch_size = labels.size(0)

        with torch.autocast(device_type, enabled=cfg.use_amp):
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))

        if cfg.gradient_accumulation_steps > 1:
            loss = loss / cfg.gradient_accumulation_steps

        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()

        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so the clipping threshold applies to the true
            # gradient norm. unscale_ is a no-op when AMP is disabled.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), cfg.max_grad_norm
            )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if cfg.batch_scheduler:
                scheduler.step()

        if step % cfg.print_freq == 0 or step == (len(train_loader) - 1):
            print(
                f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(train_loader))} "
                f"Loss: {losses.val:.4f}({losses.avg:.4f}) "
                f"Grad: {grad_norm:.4f} "
                f"LR: {scheduler.get_last_lr()[0]:.8f}"
            )

    return losses.avg


def valid_fn(valid_loader, model, cfg):
    """
    Validation function.

    Args:
        valid_loader (DataLoader): DataLoader for validation data.
        model (nn.Module): Model to be validated.
        cfg (argparse.Namespace): Configuration object.

    Returns:
        tuple: Mean squared error (not its square root) and R^2 score over the
        whole validation set. Only the progress line prints the RMSE.
    """
    model.eval()
    start = time.time()
    label_list = []
    pred_list = []

    for step, (inputs, labels) in enumerate(valid_loader):
        inputs = {k: v.to(cfg.device) for k, v in inputs.items()}
        with torch.no_grad():
            y_preds = model(inputs)
        label_list.extend(labels.tolist())
        pred_list.extend(y_preds.tolist())

        if step % cfg.print_freq == 0 or step == (len(valid_loader) - 1):
            # Running metrics over everything seen so far.
            print(
                f"EVAL: [{step}/{len(valid_loader)}] "
                f"Elapsed {timeSince(start, float(step + 1) / len(valid_loader))} "
                f"RMSE Loss: {np.sqrt(mean_squared_error(label_list, pred_list)):.4f} "
                f"R^2 Score: {r2_score(label_list, pred_list):.4f}"
            )

    return mean_squared_error(label_list, pred_list), r2_score(label_list, pred_list)


def train_loop(train_ds, valid_ds, cfg):
    """
    Training loop.

    Builds the data loaders, model, optimizer and scheduler, optionally resumes
    from ``cfg.checkpoint``, then trains for up to ``cfg.epochs`` epochs with
    early stopping. The best model weights, the model config and a resumable
    checkpoint are written to ``cfg.output_dir``.

    Args:
        train_ds (pd.DataFrame): Training data.
        valid_ds (pd.DataFrame): Validation data.
        cfg (argparse.Namespace): Configuration object. ``cfg.device`` and
            ``cfg.logger`` must already be set by the caller.
    """
    train_dataset = TrainDataset(cfg, train_ds)
    valid_dataset = TrainDataset(cfg, valid_ds)

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    if not cfg.model_name_or_path:
        model = ReactionT5Yield(cfg, config_path=None, pretrained=True)
        torch.save(model.config, os.path.join(cfg.output_dir, "config.pth"))
    else:
        model = ReactionT5Yield(
            cfg,
            config_path=os.path.join(cfg.model_name_or_path, "config.pth"),
            pretrained=False,
        )
        torch.save(model.config, os.path.join(cfg.output_dir, "config.pth"))

        # Try every *.pth file in the model directory until one loads as a
        # state dict. Sorted so the choice does not depend on directory order.
        weights_loaded = False
        for pth_file in sorted(glob.glob(os.path.join(cfg.model_name_or_path, "*.pth"))):
            if os.path.basename(pth_file) == "config.pth":
                # The model config matches the glob but is not a state dict.
                continue
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
            # only point --model_name_or_path at trusted directories.
            state = torch.load(
                pth_file, map_location=torch.device("cpu"), weights_only=False
            )
            try:
                model.load_state_dict(state)
            except Exception as err:  # best effort: fall through to the next file
                cfg.logger.warning(
                    f"Could not load weights from {pth_file}: {err}"
                )
                continue
            weights_loaded = True
            cfg.logger.info(f"Loaded model weights from {pth_file}")
            break
        if not weights_loaded:
            cfg.logger.warning(
                f"No loadable weights (*.pth) found in {cfg.model_name_or_path}; "
                "continuing with the weights the model was constructed with."
            )

    model.to(cfg.device)

    optimizer_parameters = get_optimizer_params(
        model, encoder_lr=cfg.lr, decoder_lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    optimizer = AdamW(optimizer_parameters, lr=cfg.lr, eps=cfg.eps, betas=(0.9, 0.999))

    # The scheduler is stepped once per optimizer step, i.e. once every
    # `gradient_accumulation_steps` batches, and the loader drops the last
    # incomplete batch -- count the schedule length accordingly.
    steps_per_epoch = len(train_loader) // cfg.gradient_accumulation_steps
    num_train_steps = steps_per_epoch * cfg.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg.num_warmup_steps,
        num_training_steps=num_train_steps,
    )

    criterion = nn.MSELoss(reduction="mean")
    best_loss = float("inf")
    start_epoch = 0
    es_count = 0

    if cfg.checkpoint:
        # map_location lets a checkpoint written on GPU be resumed on CPU (and
        # vice versa). weights_only=False keeps checkpoints written by earlier
        # runs loadable (their "loss" entry may be a numpy scalar).
        # SECURITY NOTE: only resume from checkpoints you produced yourself.
        checkpoint = torch.load(
            cfg.checkpoint, map_location=cfg.device, weights_only=False
        )
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["loss"]
        start_epoch = checkpoint["epoch"] + 1
        es_count = checkpoint["es_count"]
        del checkpoint

    for epoch in range(start_epoch, cfg.epochs):
        start_time = time.time()

        avg_loss = train_fn(
            train_loader, model, criterion, optimizer, epoch, scheduler, cfg
        )
        val_loss, val_r2_score = valid_fn(valid_loader, model, cfg)

        elapsed = time.time() - start_time

        # NOTE(review): the value logged as "val_rmse_loss" is the MSE returned
        # by valid_fn, not its square root. The label is left untouched so
        # existing log consumers keep working.
        cfg.logger.info(
            f"Epoch {epoch + 1} - avg_train_loss: {avg_loss:.4f} val_rmse_loss: {val_loss:.4f} val_r2_score: {val_r2_score:.4f} time: {elapsed:.0f}s"
        )

        if val_loss < best_loss:
            es_count = 0
            best_loss = val_loss
            cfg.logger.info(
                f"Epoch {epoch + 1} - Save Lowest Loss: {best_loss:.4f} Model"
            )
            torch.save(
                model.state_dict(),
                os.path.join(
                    cfg.output_dir,
                    f"{cfg.pretrained_model_name_or_path.split('/')[-1]}_best.pth",
                ),
            )
        else:
            es_count += 1
            if es_count >= cfg.patience:
                print("Early stopping")
                break

        save_checkpoint(
            {
                "epoch": epoch,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                # Plain Python float so the checkpoint holds no numpy scalar.
                "loss": float(best_loss),
                "es_count": es_count,
            },
            filename=os.path.join(cfg.output_dir, "checkpoint.pth.tar"),
        )

    torch.cuda.empty_cache()
    gc.collect()


if __name__ == "__main__":
    CFG = parse_args()
    # The scheduler is always stepped per batch, regardless of the CLI flag.
    CFG.batch_scheduler = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CFG.device = device
    os.makedirs(CFG.output_dir, exist_ok=True)
    seed_everything(seed=CFG.seed)

    train = preprocess_df(
        filter_out(pd.read_csv(CFG.train_data_path), ["YIELD", "REACTANT", "PRODUCT"]),
        CFG,
    )
    valid = preprocess_df(
        filter_out(pd.read_csv(CFG.valid_data_path), ["YIELD", "REACTANT", "PRODUCT"]),
        CFG,
    )

    if CFG.CN_test_data_path:
        # Remove from the training set every reaction that also appears in the
        # CN test set (compared on the "input" string).
        train_copy = preprocess_CN(train.copy())
        CN_test = preprocess_CN(pd.read_csv(CFG.CN_test_data_path))

        print(len(train))
        train = train[~train_copy["pair"].isin(CN_test["pair"])].reset_index(drop=True)
        print(len(train))

    # Drop validation rows whose (input, yield) pair also occurs in training.
    train["pair"] = train["input"] + " - " + train["YIELD"].astype(str)
    valid["pair"] = valid["input"] + " - " + valid["YIELD"].astype(str)
    valid = valid[~valid["pair"].isin(train["pair"])].reset_index(drop=True)

    if CFG.sampling_num > 0:
        train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
            drop=True
        )
    elif CFG.sampling_frac > 0:
        train = train.sample(frac=CFG.sampling_frac, random_state=CFG.seed).reset_index(
            drop=True
        )

    # The processed splits are written to the current working directory.
    train.to_csv("train.csv", index=False)
    valid.to_csv("valid.csv", index=False)

    if CFG.test_data_path:
        test = filter_out(
            pd.read_csv(CFG.test_data_path), ["YIELD", "REACTANT", "PRODUCT"]
        )
        test = preprocess_df(test, CFG)
        test["pair"] = test["input"] + " - " + test["YIELD"].astype(str)
        test = test[~test["pair"].isin(train["pair"])].reset_index(drop=True)
        test = test.drop_duplicates(subset=["pair"]).reset_index(drop=True)
        test.to_csv("test.csv", index=False)

    LOGGER = get_logger(os.path.join(CFG.output_dir, "train"))
    CFG.logger = LOGGER

    # load tokenizer
    # --model_name_or_path is optional (train_loop then starts from
    # --pretrained_model_name_or_path), so fall back to the pretrained name
    # instead of passing None to os.path.exists, which raises TypeError.
    tokenizer_source = CFG.model_name_or_path or CFG.pretrained_model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(tokenizer_source)
        if os.path.exists(tokenizer_source)
        else tokenizer_source,
        return_tensors="pt",
    )
    tokenizer = add_new_tokens(
        tokenizer,
        Path(__file__).resolve().parent.parent / "data" / "additional_tokens.txt",
    )
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": tokenizer.additional_special_tokens
            + ["REACTANT:", "PRODUCT:", "REAGENT:"]
        }
    )
    tokenizer.save_pretrained(CFG.output_dir)
    CFG.tokenizer = tokenizer

    train_loop(train, valid, CFG)