import argparse
import os
import subprocess
import sys
import warnings

import pandas as pd
import torch
from datasets.utils.logging import disable_progress_bar
from transformers import AutoTokenizer

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from train import preprocess_df, train_loop
from utils import get_logger, seed_everything

# Suppress warnings, dataset progress bars, and tokenizer parallelism warnings
warnings.filterwarnings("ignore")
disable_progress_bar()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
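
# Example invocation (the script and CSV names below are placeholders):
#   python finetune.py --train_data_path train.csv --valid_data_path valid.csv \
#       --output_dir output/
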

def parse_args():
    """
    Parse command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Training script for ReactionT5Yield model."
    )
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=True,
        help="Path to training data CSV file.",
    )
    parser.add_argument(
        "--valid_data_path",
        type=str,
        required=True,
        help="Path to validation data CSV file.",
    )
    parser.add_argument(
        "--similar_reaction_data_path",
        type=str,
        required=False,
        help="Path to similar reaction data CSV file, used to augment the training data.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="sagawa/CompoundT5",
        help="Pretrained model name or path.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Name or path of the model used for fine-tuning.",
    )
    parser.add_argument(
        "--download_pretrained_model",
        action="store_true",
        help="Download the pretrained model from the Hugging Face Hub and use it for fine-tuning.",
    )
parser.add_argument("--debug", action="store_true", help="Enable debug mode.") | |
parser.add_argument( | |
"--epochs", type=int, default=200, help="Number of training epochs." | |
) | |
parser.add_argument( | |
"--patience", type=int, default=10, help="Early stopping patience." | |
) | |
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") | |
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.") | |
parser.add_argument( | |
"--input_max_length", type=int, default=300, help="Maximum input token length." | |
) | |
parser.add_argument( | |
"--num_workers", type=int, default=4, help="Number of data loading workers." | |
) | |
parser.add_argument( | |
"--fc_dropout", | |
type=float, | |
default=0.0, | |
help="Dropout rate after fully connected layers.", | |
) | |
parser.add_argument( | |
"--eps", type=float, default=1e-6, help="Epsilon for Adam optimizer." | |
) | |
parser.add_argument( | |
"--weight_decay", type=float, default=0.05, help="Weight decay for optimizer." | |
) | |
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1000,
        help="Maximum gradient norm for clipping.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of warmup steps."
    )
    parser.add_argument(
        "--batch_scheduler", action="store_true", help="Use batch scheduler."
    )
    parser.add_argument(
        "--print_freq", type=int, default=100, help="Logging frequency."
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        help="Use automatic mixed precision for training.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory to save the trained model.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--sampling_num",
        type=int,
        default=-1,
        help="Number of training samples to use; set -1 to use all samples.",
    )
    parser.add_argument(
        "--sampling_frac",
        type=float,
        default=-1.0,
        help="Fraction of training samples to use; set -1.0 to use all samples.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Path to a checkpoint file for resuming training.",
    )
    return parser.parse_args()

def download_pretrained_model():
    """
    Download the pretrained model and tokenizer files from the Hugging Face Hub
    into the current working directory.
    """
    base_url = "https://huggingface.co/sagawa/ReactionT5v2-yield/resolve/main"
    for file_name in [
        "CompoundT5_best.pth",
        "config.pth",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        # check=True makes a failed download raise instead of passing silently
        subprocess.run(f"wget {base_url}/{file_name}", shell=True, check=True)
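
# The files above land in the current working directory; the main block below
# points model_name_or_path at "." so they are picked up for fine-tuning.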

if __name__ == "__main__":
    CFG = parse_args()
    CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(CFG.output_dir, exist_ok=True)
    seed_everything(seed=CFG.seed)

    if CFG.download_pretrained_model:
        download_pretrained_model()
        CFG.model_name_or_path = "."
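
    # Load CSVs, drop duplicate rows, and apply the shared preprocessing from train.py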
    train = pd.read_csv(CFG.train_data_path).drop_duplicates().reset_index(drop=True)
    valid = pd.read_csv(CFG.valid_data_path).drop_duplicates().reset_index(drop=True)
    train = preprocess_df(train, CFG)
    valid = preprocess_df(valid, CFG)

    # Subsample the training set, by count or by fraction, if requested
    if CFG.sampling_num > 0:
        train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
            drop=True
        )
    elif 0 < CFG.sampling_frac < 1:
        train = train.sample(
            frac=CFG.sampling_frac, random_state=CFG.seed
        ).reset_index(drop=True)
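
    # Augment the training set with retrieved similar reactions, if provided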
    if CFG.similar_reaction_data_path:
        similar = preprocess_df(pd.read_csv(CFG.similar_reaction_data_path), CFG)
        print(f"Training samples before augmentation: {len(train)}")
        train = pd.concat([train, similar], ignore_index=True)
        print(f"Training samples after augmentation: {len(train)}")
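
    # The tokenizer is saved alongside the model so the output directory is
    # self-contained for later use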
    LOGGER = get_logger(os.path.join(CFG.output_dir, "train"))
    CFG.logger = LOGGER

    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )
    tokenizer.save_pretrained(CFG.output_dir)
    CFG.tokenizer = tokenizer

    train_loop(train, valid, CFG)