# ReactionT5 / task_yield / finetune.py
import argparse
import os
import subprocess
import sys
import warnings

import pandas as pd
import torch
from datasets.utils.logging import disable_progress_bar
from transformers import AutoTokenizer

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from train import preprocess_df, train_loop
from utils import get_logger, seed_everything

# Suppress warnings and logging
warnings.filterwarnings("ignore")
disable_progress_bar()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args():
"""
Parse command line arguments.
"""
parser = argparse.ArgumentParser(
        description="Fine-tuning script for the ReactionT5 yield prediction model."
)
parser.add_argument(
"--train_data_path",
type=str,
required=True,
help="Path to training data CSV file.",
)
parser.add_argument(
"--valid_data_path",
type=str,
required=True,
help="Path to validation data CSV file.",
)
parser.add_argument(
"--similar_reaction_data_path",
type=str,
required=False,
help="Path to similar data CSV.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default="sagawa/CompoundT5",
help="Pretrained model name or path.",
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="The model's name or path used for fine-tuning.",
)
parser.add_argument(
"--download_pretrained_model",
action="store_true",
default=False,
required=False,
help="Download pretrained model from hugging face hub and use it for fine-tuning.",
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode.")
parser.add_argument(
"--epochs", type=int, default=200, help="Number of training epochs."
)
parser.add_argument(
"--patience", type=int, default=10, help="Early stopping patience."
)
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
parser.add_argument(
"--input_max_length", type=int, default=300, help="Maximum input token length."
)
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of data loading workers."
)
parser.add_argument(
"--fc_dropout",
type=float,
default=0.0,
help="Dropout rate after fully connected layers.",
)
parser.add_argument(
"--eps", type=float, default=1e-6, help="Epsilon for Adam optimizer."
)
parser.add_argument(
"--weight_decay", type=float, default=0.05, help="Weight decay for optimizer."
)
parser.add_argument(
"--max_grad_norm",
        type=float,
default=1000,
help="Maximum gradient norm for clipping.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Gradient accumulation steps.",
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of warmup steps."
)
parser.add_argument(
"--batch_scheduler", action="store_true", help="Use batch scheduler."
)
parser.add_argument(
"--print_freq", type=int, default=100, help="Logging frequency."
)
parser.add_argument(
"--use_amp",
action="store_true",
help="Use automatic mixed precision for training.",
)
parser.add_argument(
"--output_dir",
type=str,
default="./",
help="Directory to save the trained model.",
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for reproducibility."
)
parser.add_argument(
"--sampling_num",
type=int,
default=-1,
help="Number of samples used for training. If you want to use all samples, set -1.",
)
parser.add_argument(
"--sampling_frac",
type=float,
default=-1.0,
help="Ratio of samples used for training. If you want to use all samples, set -1.0.",
)
parser.add_argument(
"--checkpoint",
type=str,
help="Path to the checkpoint file for resuming training.",
)
    return parser.parse_args()


def download_pretrained_model():
    """
    Download the pretrained model files from the Hugging Face Hub.
    """
    base_url = "https://huggingface.co/sagawa/ReactionT5v2-yield/resolve/main"
    for file_name in [
        "CompoundT5_best.pth",
        "config.pth",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        subprocess.run(f"wget {base_url}/{file_name}", shell=True)
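# Note: an equivalent, wget-free way to fetch these files (a sketch, assuming
# the huggingface_hub package is installed) would use hf_hub_download:
#
#     from huggingface_hub import hf_hub_download
#
#     for file_name in ["CompoundT5_best.pth", "config.pth",
#                       "special_tokens_map.json", "tokenizer.json",
#                       "tokenizer_config.json"]:
#         hf_hub_download(
#             repo_id="sagawa/ReactionT5v2-yield",
#             filename=file_name,
#             local_dir=".",
#         )

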
if __name__ == "__main__":
    CFG = parse_args()
    CFG.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the output directory if needed and fix random seeds for reproducibility.
    os.makedirs(CFG.output_dir, exist_ok=True)
    seed_everything(seed=CFG.seed)

    if CFG.download_pretrained_model:
        download_pretrained_model()
        CFG.model_name_or_path = "."
train = pd.read_csv(CFG.train_data_path).drop_duplicates().reset_index(drop=True)
valid = pd.read_csv(CFG.valid_data_path).drop_duplicates().reset_index(drop=True)
train = preprocess_df(train, CFG)
valid = preprocess_df(valid, CFG)
    # Optionally subsample the training data for quicker experiments.
    if CFG.sampling_num > 0:
        train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
            drop=True
        )
    elif 0 < CFG.sampling_frac < 1:
        train = train.sample(
            frac=CFG.sampling_frac, random_state=CFG.seed
        ).reset_index(drop=True)
    # Optionally augment the training set with similar reactions.
    if CFG.similar_reaction_data_path:
        similar = preprocess_df(pd.read_csv(CFG.similar_reaction_data_path), CFG)
        print(f"Training samples before augmentation: {len(train)}")
        train = pd.concat([train, similar], ignore_index=True)
        print(f"Training samples after augmentation: {len(train)}")
LOGGER = get_logger(os.path.join(CFG.output_dir, "train"))
CFG.logger = LOGGER
    # Load the tokenizer from a local directory when one exists, otherwise from
    # the Hub, and save a copy so the output directory is self-contained.
    # Fall back to the pretrained checkpoint when --model_name_or_path is not
    # given; os.path.exists(None) would otherwise raise a TypeError.
    tokenizer_source = CFG.model_name_or_path or CFG.pretrained_model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(tokenizer_source)
        if os.path.exists(tokenizer_source)
        else tokenizer_source,
        return_tensors="pt",
    )
    tokenizer.save_pretrained(CFG.output_dir)
    CFG.tokenizer = tokenizer
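    # Everything train_loop needs (device, logger, tokenizer) travels on the
    # CFG namespace, which is passed in alongside the dataframes.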
train_loop(train, valid, CFG)
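
# Example invocation (a minimal sketch; the CSV paths and hyperparameter values
# below are illustrative placeholders, not settings shipped with the repository):
#
#     python finetune.py \
#         --train_data_path data/train.csv \
#         --valid_data_path data/valid.csv \
#         --output_dir ./output \
#         --epochs 50 \
#         --batch_size 16 \
#         --use_amp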