import argparse
import os
import sys
import warnings

import datasets
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
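
# Make the sibling modules (train.py, utils.py) importable when this script
# is run from its own directory.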
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from train import preprocess_df
from utils import filter_out, get_accuracy_score, preprocess_dataset, seed_everything

# Suppress warnings and disable progress bars
warnings.filterwarnings("ignore")
datasets.utils.logging.disable_progress_bar()


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Training script for reaction prediction model."
    )
    parser.add_argument(
        "--train_data_path", type=str, required=True, help="Path to training data CSV."
    )
    parser.add_argument(
        "--valid_data_path",
        type=str,
        required=True,
        help="Path to validation data CSV.",
    )
    parser.add_argument(
        "--similar_reaction_data_path",
        type=str,
        required=False,
        help="Path to similar-reaction data CSV used to augment the training set.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="t5", help="Path of the output directory."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="sagawa/ReactionT5v2-forward",
        help="Name of a pretrained model, or path to a local model, to finetune on your dataset. Both local models and models on the Hugging Face Hub are supported.",
    )
    parser.add_argument(
        "--debug", action="store_true", default=False, help="Enable debug mode."
    )
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of epochs for training."
    )
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
    parser.add_argument(
        "--input_max_length", type=int, default=200, help="Max input token length."
    )
    parser.add_argument(
        "--target_max_length", type=int, default=150, help="Max target token length."
    )
    parser.add_argument(
        "--eval_beams",
        type=int,
        default=5,
        help="Number of beams used for beam search during evaluation.",
    )
    parser.add_argument(
        "--target_column",
        type=str,
        default="PRODUCT",
        help="Target column name.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay.",
    )
    parser.add_argument(
        "--evaluation_strategy",
        type=str,
        default="epoch",
        help="Evaluation strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also set --eval_steps.",
    )
    parser.add_argument(
        "--eval_steps",
        type=int,
        help="Evaluation steps. Used when --evaluation_strategy is 'steps'.",
    )
    parser.add_argument(
        "--save_strategy",
        type=str,
        default="epoch",
        help="Save strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also set --save_steps.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save steps. Used when --save_strategy is 'steps'.",
    )
    parser.add_argument(
        "--logging_strategy",
        type=str,
        default="epoch",
        help="Logging strategy used during training. Select from 'no', 'steps', or 'epoch'. If you select 'steps', also set --logging_steps.",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=500,
        help="Logging steps. Used when --logging_strategy is 'steps'.",
    )
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=2,
        help="Limit of saved checkpoints.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=False,
        help="Enable fp16 training.",
    )
    parser.add_argument(
        "--disable_tqdm",
        action="store_true",
        default=False,
        help="Disable tqdm progress bars.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Set seed for reproducibility."
    )
    parser.add_argument(
        "--sampling_num",
        type=int,
        default=-1,
        help="Number of samples used for training. Set to -1 to use all samples.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    CFG = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(seed=CFG.seed)
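
    # Load the train/validation CSVs, drop rows missing REACTANT or PRODUCT,
    # and apply the shared preprocessing from train.py.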
    train = preprocess_df(
        filter_out(pd.read_csv(CFG.train_data_path), ["REACTANT", "PRODUCT"])
    )
    valid = preprocess_df(
        filter_out(pd.read_csv(CFG.valid_data_path), ["REACTANT", "PRODUCT"])
    )
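
    # Optionally subsample the training data for faster experiments.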
    if CFG.sampling_num > 0:
        train = train.sample(n=CFG.sampling_num, random_state=CFG.seed).reset_index(
            drop=True
        )
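
    # Optionally augment the training set with similar reactions.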
    if CFG.similar_reaction_data_path:
        similar = preprocess_df(
            filter_out(
                pd.read_csv(CFG.similar_reaction_data_path), ["REACTANT", "PRODUCT"]
            )
        )
        print(f"train size before augmentation: {len(train)}")
        train = pd.concat([train, similar], ignore_index=True)
        print(f"train size after augmentation: {len(train)}")
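
    # Fill missing reagents with a single space so the string concatenation
    # below never hits NaN.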
    train["REAGENT"] = train["REAGENT"].fillna(" ")
    valid["REAGENT"] = valid["REAGENT"].fillna(" ")
    train["input"] = "REACTANT:" + train["REACTANT"] + "REAGENT:" + train["REAGENT"]
    valid["input"] = "REACTANT:" + valid["REACTANT"] + "REAGENT:" + valid["REAGENT"]
    if CFG.debug:
        train = train[: int(len(train) / 40)].reset_index(drop=True)
        valid = valid[: int(len(valid) / 40)].reset_index(drop=True)
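
    # Wrap the dataframes in a DatasetDict so they can be tokenized and fed
    # to the Trainer.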
    dataset = DatasetDict(
        {
            "train": Dataset.from_pandas(train[["input", "PRODUCT"]]),
            "validation": Dataset.from_pandas(valid[["input", "PRODUCT"]]),
        }
    )

    # Load the tokenizer (local path or Hugging Face Hub ID).
    tokenizer = AutoTokenizer.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path,
        return_tensors="pt",
    )
    CFG.tokenizer = tokenizer
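
    # Load the seq2seq model and move it to the available device.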
    model = AutoModelForSeq2SeqLM.from_pretrained(
        os.path.abspath(CFG.model_name_or_path)
        if os.path.exists(CFG.model_name_or_path)
        else CFG.model_name_or_path
    ).to(device)
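
    # Tokenize inputs and targets; drop the raw text columns afterwards.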
    tokenized_datasets = dataset.map(
        lambda examples: preprocess_dataset(examples, CFG),
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
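
    # Training configuration. The step counts are only used when the
    # corresponding strategy is set to 'steps'.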
    args = Seq2SeqTrainingArguments(
        CFG.output_dir,
        evaluation_strategy=CFG.evaluation_strategy,
        eval_steps=CFG.eval_steps,
        save_strategy=CFG.save_strategy,
        save_steps=CFG.save_steps,
        logging_strategy=CFG.logging_strategy,
        logging_steps=CFG.logging_steps,
        learning_rate=CFG.lr,
        per_device_train_batch_size=CFG.batch_size,
        per_device_eval_batch_size=CFG.batch_size * 4,
        weight_decay=CFG.weight_decay,
        save_total_limit=CFG.save_total_limit,
        num_train_epochs=CFG.epochs,
        predict_with_generate=True,
        fp16=CFG.fp16,
        disable_tqdm=CFG.disable_tqdm,
        push_to_hub=False,
        load_best_model_at_end=True,
    )
    # Generation settings used by predict_with_generate during evaluation.
    # Note: generate() reads num_beams, not eval_beams, from the config.
    model.config.num_beams = CFG.eval_beams
    model.config.max_length = CFG.target_max_length
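
    # Train with early stopping: stop if the evaluation metric does not
    # improve for 10 consecutive evaluations.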
    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=lambda eval_preds: get_accuracy_score(eval_preds, CFG),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
    )
    trainer.train()
    trainer.save_model("./best_model")
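
# A minimal example invocation (a sketch; the script and CSV file names below
# are placeholders, and the CSVs are assumed to contain REACTANT, REAGENT,
# and PRODUCT columns):
#
#   python finetune.py \
#       --train_data_path train.csv \
#       --valid_data_path valid.csv \
#       --output_dir t5 \
#       --epochs 3 \
#       --batch_size 32 \
#       --fp16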