import argparse
import os
import sys
import warnings
import datasets
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
EarlyStoppingCallback,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from train import preprocess_df
from utils import filter_out, get_accuracy_score, preprocess_dataset, seed_everything
# Suppress warnings and disable progress bars
warnings.filterwarnings("ignore")
datasets.utils.logging.disable_progress_bar()
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--train_data_path",
type=str,
required=True,
help="The path to data used for training. CSV file that contains ['REACTANT', 'PRODUCT'] columns is expected.",
)
parser.add_argument(
"--valid_data_path",
type=str,
required=True,
help="The path to data used for validation. CSV file that contains ['REACTANT', 'PRODUCT'] columns is expected.",
)
parser.add_argument(
"--similar_reaction_data_path",
type=str,
required=False,
help="Path to similar data CSV.",
)
parser.add_argument(
"--output_dir", type=str, default="t5", help="Path of the output directory."
)
parser.add_argument(
"--model_name_or_path",
type=str,
required=False,
default="sagawa/ReactionT5v2-retrosynthesis",
help="The name of a pretrained model or path to a model which you want to finetune on your dataset. You can use your local models or models uploaded to hugging face.",
)
parser.add_argument(
"--debug",
action="store_true",
default=False,
required=False,
help="Use debug mode.",
)
parser.add_argument(
"--epochs",
type=int,
default=20,
required=False,
help="Number of epochs for training.",
)
parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
parser.add_argument(
"--input_max_length",
type=int,
default=150,
required=False,
help="Max input token length.",
)
parser.add_argument(
"--target_max_length",
type=int,
default=150,
required=False,
help="Max target token length.",
)
parser.add_argument(
"--eval_beams",
type=int,
default=5,
help="Number of beams used for beam search during evaluation.",
)
parser.add_argument(
"--target_column",
type=str,
default="REACTANT",
help="Target column name.",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.01,
required=False,
help="weight_decay used for trainer",
)
parser.add_argument(
"--evaluation_strategy",
type=str,
default="epoch",
required=False,
help="Evaluation strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also give --eval_steps.",
)
parser.add_argument(
"--eval_steps",
type=int,
required=False,
help="Number of update steps between two evaluations",
)
parser.add_argument(
"--save_strategy",
type=str,
default="epoch",
required=False,
help="Save strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also give --save_steps.",
)
parser.add_argument(
"--save_steps",
type=int,
required=False,
default=500,
help="Number of steps between two saving",
)
parser.add_argument(
"--logging_strategy",
type=str,
default="epoch",
required=False,
help="Logging strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also give --logging_steps.",
)
parser.add_argument(
"--logging_steps",
type=int,
required=False,
default=500,
help="Number of steps between two logging",
)
parser.add_argument(
"--save_total_limit",
type=int,
default=3,
required=False,
help="Limit of the number of saved checkpoints. If limit is reached, the oldest checkpoint will be deleted.",
)
parser.add_argument(
"--fp16",
action="store_true",
default=False,
required=False,
help="Use fp16 during training",
)
parser.add_argument(
"--disable_tqdm",
action="store_true",
default=False,
required=False,
help="Disable tqdm during training",
)
parser.add_argument(
"--seed",
type=int,
default=42,
required=False,
help="Set seed for reproducibility.",
)
parser.add_argument(
"--sampling_num",
type=int,
default=-1,
help="Number of samples used for training. If you want to use all samples, set -1.",
)
return parser.parse_args()
if __name__ == "__main__":
CFG = parse_args()
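# progress bars are forced off here, regardless of the --disable_tqdm flag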
CFG.disable_tqdm = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed_everything(seed=CFG.seed)
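# load the training and validation CSVs and preprocess them into model inputs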
train = preprocess_df(
filter_out(pd.read_csv(CFG.train_data_path), ["REACTANT", "PRODUCT"])
)
valid = preprocess_df(
filter_out(pd.read_csv(CFG.valid_data_path), ["REACTANT", "PRODUCT"])
)
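# optionally subsample the training data for smaller-scale runs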
if CFG.sampling_num > 0:
train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
drop=True
)
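# optionally augment the training set with similar-reaction data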
if CFG.similar_reaction_data_path:
similar = preprocess_df(
filter_out(
pd.read_csv(CFG.similar_reaction_data_path), ["REACTANT", "PRODUCT"]
)
)
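# report the training set size before and after augmentation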
print(len(train))
train = pd.concat([train, similar], ignore_index=True)
print(len(train))
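# build Hugging Face datasets from the input/REACTANT columns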
dataset = DatasetDict(
{
"train": Dataset.from_pandas(train[["input", "REACTANT"]]),
"validation": Dataset.from_pandas(valid[["input", "REACTANT"]]),
}
)
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
os.path.abspath(CFG.model_name_or_path)
if os.path.exists(CFG.model_name_or_path)
else CFG.model_name_or_path,
return_tensors="pt",
)
CFG.tokenizer = tokenizer
# load model
model = AutoModelForSeq2SeqLM.from_pretrained(
os.path.abspath(CFG.model_name_or_path)
if os.path.exists(CFG.model_name_or_path)
else CFG.model_name_or_path
)
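# tokenize inputs and targets for seq2seq training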
tokenized_datasets = dataset.map(
lambda examples: preprocess_dataset(examples, CFG),
batched=True,
remove_columns=dataset["train"].column_names,
load_from_cache_file=False,
)
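# dynamically pad each batch and prepare labels for the seq2seq model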
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
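# training arguments for the Seq2SeqTrainer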
args = Seq2SeqTrainingArguments(
CFG.output_dir,
evaluation_strategy=CFG.evaluation_strategy,
eval_steps=CFG.eval_steps,
save_strategy=CFG.save_strategy,
save_steps=CFG.save_steps,
logging_strategy=CFG.logging_strategy,
logging_steps=CFG.logging_steps,
learning_rate=CFG.lr,
per_device_train_batch_size=CFG.batch_size,
per_device_eval_batch_size=CFG.batch_size,
weight_decay=CFG.weight_decay,
save_total_limit=CFG.save_total_limit,
num_train_epochs=CFG.epochs,
predict_with_generate=True,
fp16=CFG.fp16,
disable_tqdm=CFG.disable_tqdm,
push_to_hub=False,
load_best_model_at_end=True,
)
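# record the beam count and maximum generation length on the model config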
model.config.eval_beams = CFG.eval_beams
model.config.max_length = CFG.target_max_length
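# trainer with an accuracy-based metric and early stopping (patience of 10 evaluations)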
trainer = Seq2SeqTrainer(
model,
args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=lambda eval_preds: get_accuracy_score(eval_preds, CFG),
callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)
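# train without resuming from a checkpoint, then save the best model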
trainer.train(resume_from_checkpoint=False)
trainer.save_model("./best_model")