# ReactionT5 / utils.py
import math
import os
import pickle
import random
import time
import numpy as np
import torch
from rdkit import Chem
def seed_everything(seed=42):
    """Fix all random seeds (Python, NumPy, PyTorch) for reproducibility."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
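
# Illustrative usage (not part of the original file): seed_everything is typically
# called once at the top of a training script, before the tokenizer, model, and
# dataloaders are built, so that runs are reproducible.
#
#     seed_everything(seed=42)
#     # build tokenizer, model, dataloaders, then train
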
def space_clean(row):
    # Strip spaces adjacent to '.' separators and collapse double spaces.
    row = row.replace(". ", "").replace(" .", "").replace("  ", " ")
    return row
def canonicalize(smiles):
    """Return the canonical SMILES, or None if RDKit cannot parse the input."""
    try:
        new_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
    except Exception:
        new_smiles = None
    return new_smiles
def canonicalize_str(smiles):
    """Try to canonicalize the molecule; return an empty string if it fails."""
    if "%" in smiles:
        # Yield tokens such as "80%" are not SMILES; pass them through unchanged.
        return smiles
    try:
        return canonicalize(smiles) or ""
    except Exception:
        return ""
def uncanonicalize(smiles):
    """Randomize the atom ordering of each '.'-separated fragment (SMILES augmentation).

    Returns None if any fragment cannot be parsed.
    """
    try:
        new_smiles = []
        for smiles_i in smiles.split("."):
            mol = Chem.MolFromSmiles(smiles_i)
            # Re-root the SMILES at a randomly chosen atom to get a non-canonical form.
            atom_indices = list(range(mol.GetNumAtoms()))
            random.shuffle(atom_indices)
            new_smiles_i = Chem.MolToSmiles(
                mol, rootedAtAtom=atom_indices[0], canonical=False
            )
            new_smiles.append(new_smiles_i)
        smiles = ".".join(new_smiles)
    except Exception:
        smiles = None
    return smiles
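
# Illustrative usage (not part of the original file): uncanonicalize produces a random
# equivalent SMILES for data augmentation, e.g. "CCO" may come back as "OCC" or
# "C(C)O"; unparsable input returns None.
#
#     seed_everything(0)
#     augmented = uncanonicalize("CCO.CN")
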
def remove_atom_mapping(smi):
    """Strip atom-map numbers from a SMILES and return its canonical form."""
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol, canonical=True)
    return canonicalize(smi)
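
# Illustrative usage (not part of the original file): atom-mapped reaction SMILES
# (e.g. from USPTO) carry ':<n>' map numbers that are dropped before training.
#
#     remove_atom_mapping("[CH3:1][OH:2]")  # -> "CO"
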
def get_logger(filename="train"):
from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger
logger = getLogger(__name__)
logger.setLevel(INFO)
handler1 = StreamHandler()
handler1.setFormatter(Formatter("%(message)s"))
handler2 = FileHandler(filename=f"{filename}.log")
handler2.setFormatter(Formatter("%(message)s"))
logger.addHandler(handler1)
logger.addHandler(handler2)
return logger
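
# Illustrative usage (not part of the original file): the logger writes each message
# both to stdout and to "<filename>.log".
#
#     LOGGER = get_logger("train")
#     LOGGER.info("start training")
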
class AverageMeter(object):
    """Tracks the current value, running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def asMinutes(s):
    """Format a number of seconds as '<m>m <s>s'."""
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    """Return elapsed time and estimated remaining time, given progress `percent` in (0, 1]."""
    now = time.time()
    s = now - since
    es = s / percent  # estimated total duration
    rs = es - s  # estimated remaining duration
    return "%s (remain %s)" % (asMinutes(s), asMinutes(rs))
def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
    """Build parameter groups: the wrapped pretrained model (with and without weight
    decay) at encoder_lr, and all remaining parameters at decoder_lr."""
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        # Pretrained-model parameters that should receive weight decay.
        {
            "params": [
                p
                for n, p in model.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "lr": encoder_lr,
            "weight_decay": weight_decay,
        },
        # Biases and LayerNorm weights of the pretrained model: no weight decay.
        {
            "params": [
                p
                for n, p in model.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "lr": encoder_lr,
            "weight_decay": 0.0,
        },
        # Parameters outside the wrapped pretrained model (names not containing "model").
        {
            "params": [p for n, p in model.named_parameters() if "model" not in n],
            "lr": decoder_lr,
            "weight_decay": 0.0,
        },
    ]
    return optimizer_parameters
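
# Illustrative usage (not part of the original file): the parameter groups can be
# passed directly to torch.optim.AdamW; each group carries its own learning rate and
# weight decay. `model` and the CFG values are assumed to come from the training script.
#
#     optimizer_parameters = get_optimizer_params(
#         model,
#         encoder_lr=CFG.encoder_lr,
#         decoder_lr=CFG.decoder_lr,
#         weight_decay=CFG.weight_decay,
#     )
#     optimizer = torch.optim.AdamW(optimizer_parameters, lr=CFG.encoder_lr)
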
def to_cpu(obj):
    """Recursively move tensors (possibly nested in dicts/lists/tuples/sets) to CPU."""
    if torch.is_tensor(obj):
        return obj.to("cpu")
    elif isinstance(obj, dict):
        return {k: to_cpu(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple, set)):
        # Note: containers are returned as plain lists.
        return [to_cpu(v) for v in obj]
    else:
        return obj
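
# Illustrative usage (not part of the original file): handy for moving a batch of
# model outputs off the GPU before storing them; `logits` and `loss` are assumed to
# be CUDA tensors from a forward pass.
#
#     outputs = to_cpu({"logits": logits, "loss": loss})
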
def get_accuracy_score(eval_preds, cfg):
    """Exact-match accuracy between canonicalized predicted and reference SMILES."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 (ignored label positions) with the pad token id before decoding.
    labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = [
        canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
    ]
    decoded_labels = [
        [canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
    ]
    score = 0
    for i in range(len(decoded_preds)):
        if decoded_preds[i] == decoded_labels[i][0]:
            score += 1
    score /= len(decoded_preds)
    return {"accuracy": score}
def get_accuracy_score_multitask(eval_preds, cfg):
    """Exact-match accuracy for the multitask model, keeping yield tokens (e.g. '50%') in the text."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Strip every special token except the yield-percentage tokens, which are part of the target.
    special_tokens = cfg.tokenizer.special_tokens_map
    special_tokens = [
        special_tokens["eos_token"],
        special_tokens["pad_token"],
        special_tokens["unk_token"],
    ] + list(
        set(special_tokens["additional_special_tokens"])
        - {
            "0%",
            "10%",
            "20%",
            "30%",
            "40%",
            "50%",
            "60%",
            "70%",
            "80%",
            "90%",
            "100%",
        }
    )
    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=False)
    for special_token in special_tokens:
        decoded_preds = [pred.replace(special_token, "") for pred in decoded_preds]
    # Replace -100 (ignored label positions) with the pad token id before decoding.
    labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=False)
    for special_token in special_tokens:
        decoded_labels = [label.replace(special_token, "") for label in decoded_labels]
    decoded_preds = [
        canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
    ]
    decoded_labels = [
        [canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
    ]
    score = 0
    for i in range(len(decoded_preds)):
        if decoded_preds[i] == decoded_labels[i][0]:
            score += 1
    score /= len(decoded_preds)
    return {"accuracy": score}
def preprocess_dataset(examples, cfg):
    """Tokenize inputs and targets for seq2seq training; target ids go into 'labels'."""
    inputs = examples["input"]
    targets = examples[cfg.target_column]
    model_inputs = cfg.tokenizer(
        inputs, max_length=cfg.input_max_length, truncation=True
    )
    labels = cfg.tokenizer(targets, max_length=cfg.target_max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
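
# Illustrative usage (not part of the original file): preprocess_dataset is written for
# datasets.Dataset.map with batched=True; `cfg` must provide the tokenizer, target_column,
# and the two max-length settings. `raw_dataset` and CFG are assumed to exist.
#
#     tokenized = raw_dataset.map(
#         preprocess_dataset,
#         batched=True,
#         fn_kwargs={"cfg": CFG},
#         remove_columns=raw_dataset.column_names,
#     )
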
def filter_out(df, col_names):
    """Drop rows whose value is NaN in any of the given columns."""
    for col_name in col_names:
        df = df[~df[col_name].isna()].reset_index(drop=True)
    return df
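
# Illustrative usage (not part of the original file), assuming a pandas DataFrame with
# "REACTANT" and "PRODUCT" columns:
#
#     df = filter_out(df, ["REACTANT", "PRODUCT"])
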
def save_pickle(path: str, contents):
    """Saves contents to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(contents, f)


def load_pickle(path: str):
    """Loads contents from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)
def add_new_tokens(tokenizer, file_path):
    """
    Adds new tokens to the tokenizer from a file.
    The file should contain one token per line.
    """
    with open(file_path, "r") as f:
        new_tokens = [line.strip() for line in f if line.strip()]
    tokenizer.add_tokens(new_tokens)
    return tokenizer
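
# Illustrative usage (not part of the original file; the file name is hypothetical):
# after adding tokens, the model's embedding matrix usually needs to be resized to the
# new vocabulary size.
#
#     tokenizer = add_new_tokens(tokenizer, "new_tokens.txt")
#     model.resize_token_embeddings(len(tokenizer))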