# ReactionT5 / utils.py
import math
import os
import pickle
import random
import time
import numpy as np
import torch
from rdkit import Chem
# Seed Python, NumPy, and PyTorch (including CUDA/cuDNN) RNGs for reproducibility.
def seed_everything(seed=42):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
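# Usage sketch (illustrative): call once at the start of a run, before building
# the model or data loaders, e.g.
#   seed_everything(seed=42)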
# Remove stray ". " / " ." fragments and collapse double spaces in a SMILES string.
def space_clean(row):
    row = row.replace(". ", "").replace(" .", "").replace("  ", " ")
    return row
# Return the canonical SMILES of a molecule, or None if RDKit cannot parse it.
def canonicalize(smiles):
try:
new_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
    except Exception:
new_smiles = None
return new_smiles
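# Illustrative example (assumes RDKit can parse the input):
#   canonicalize("C(O)C")      # -> "CCO"
#   canonicalize("not-smiles") # -> None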
def canonicalize_str(smiles):
    """Try to canonicalize the molecule; return an empty string if it fails."""
    if "%" in smiles:
        # Strings containing "%" (e.g. yield tokens such as "90%") are returned unchanged.
        return smiles
    canonical = canonicalize(smiles)
    return canonical if canonical is not None else ""
# Randomly re-root each "."-separated fragment to produce a non-canonical SMILES
# (simple SMILES augmentation); return None if RDKit cannot parse the input.
def uncanonicalize(smiles):
    try:
new_smiles = []
for smiles_i in smiles.split("."):
mol = Chem.MolFromSmiles(smiles_i)
atom_indices = list(range(mol.GetNumAtoms()))
random.shuffle(atom_indices)
new_smiles_i = Chem.MolToSmiles(
mol, rootedAtAtom=atom_indices[0], canonical=False
)
new_smiles.append(new_smiles_i)
smiles = ".".join(new_smiles)
    except Exception:
smiles = None
return smiles
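# Illustrative example (output varies because the root atom is chosen at random):
#   uncanonicalize("CCO")  # may return e.g. "OCC" or "C(C)O"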
# Strip atom-map numbers from a SMILES string and return its canonical form.
def remove_atom_mapping(smi):
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol, canonical=True)
    return canonicalize(smi)
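# Illustrative example (assumes an atom-mapped reactant SMILES):
#   remove_atom_mapping("[CH3:1][OH:2]")  # -> "CO"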
# Build a logger that writes plain messages to both stdout and "<filename>.log".
def get_logger(filename="train"):
from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger
logger = getLogger(__name__)
logger.setLevel(INFO)
handler1 = StreamHandler()
handler1.setFormatter(Formatter("%(message)s"))
handler2 = FileHandler(filename=f"{filename}.log")
handler2.setFormatter(Formatter("%(message)s"))
logger.addHandler(handler1)
logger.addHandler(handler2)
return logger
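# Usage sketch (illustrative; the "train" filename and LOGGER name are arbitrary):
#   LOGGER = get_logger("train")
#   LOGGER.info("starting fold 0")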
class AverageMeter(object):
    """Track the latest value and a running average (sum / count)."""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
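# Usage sketch (illustrative; `loss` and `batch_size` come from an assumed training loop):
#   losses = AverageMeter()
#   losses.update(loss.item(), n=batch_size)  # losses.avg holds the running mean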
# Format a duration given in seconds as "Xm Ys".
def asMinutes(s):
m = math.floor(s / 60)
s -= m * 60
return "%dm %ds" % (m, s)
# Given a start time and the fraction of work completed, return the elapsed time
# and an estimate of the remaining time.
def timeSince(since, percent):
now = time.time()
s = now - since
es = s / (percent)
rs = es - s
return "%s (remain %s)" % (asMinutes(s), asMinutes(rs))
# Build optimizer parameter groups: parameters under `model.model` use `encoder_lr`
# (with `weight_decay`, except for biases and LayerNorm weights); all remaining
# parameters use `decoder_lr` with no weight decay.
def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_parameters = [
{
"params": [
p
for n, p in model.model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"lr": encoder_lr,
"weight_decay": weight_decay,
},
{
"params": [
p
for n, p in model.model.named_parameters()
if any(nd in n for nd in no_decay)
],
"lr": encoder_lr,
"weight_decay": 0.0,
},
{
"params": [p for n, p in model.named_parameters() if "model" not in n],
"lr": decoder_lr,
"weight_decay": 0.0,
},
]
return optimizer_parameters
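# Usage sketch (illustrative; the learning rates are placeholder values):
#   grouped_params = get_optimizer_params(model, encoder_lr=2e-5, decoder_lr=2e-5)
#   optimizer = torch.optim.AdamW(grouped_params)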
# Recursively move any tensors inside (possibly nested) containers to the CPU.
def to_cpu(obj):
    if torch.is_tensor(obj):
        return obj.to("cpu")
    elif isinstance(obj, dict):
        return {k: to_cpu(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple, set)):
        return [to_cpu(v) for v in obj]
    else:
        return obj
# Exact-match accuracy between canonicalized decoded predictions and labels.
def get_accuracy_score(eval_preds, cfg):
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = [
canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
]
decoded_labels = [
[canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
]
score = 0
for i in range(len(decoded_preds)):
if decoded_preds[i] == decoded_labels[i][0]:
score += 1
score /= len(decoded_preds)
return {"accuracy": score}
# Exact-match accuracy for the multitask model: special tokens are removed
# manually so that the yield tokens "0%"-"100%" are preserved in the decoded text.
def get_accuracy_score_multitask(eval_preds, cfg):
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
special_tokens = cfg.tokenizer.special_tokens_map
special_tokens = [
special_tokens["eos_token"],
special_tokens["pad_token"],
special_tokens["unk_token"],
] + list(
set(special_tokens["additional_special_tokens"])
- set(
[
"0%",
"10%",
"20%",
"30%",
"40%",
"50%",
"60%",
"70%",
"80%",
"90%",
"100%",
]
)
)
decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=False)
for special_token in special_tokens:
decoded_preds = [pred.replace(special_token, "") for pred in decoded_preds]
labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=False)
for special_token in special_tokens:
decoded_labels = [pred.replace(special_token, "") for pred in decoded_labels]
decoded_preds = [
canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
]
decoded_labels = [
[canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
]
score = 0
for i in range(len(decoded_preds)):
if decoded_preds[i] == decoded_labels[i][0]:
score += 1
score /= len(decoded_preds)
return {"accuracy": score}
# Tokenize the "input" column and `cfg.target_column` of a dataset batch for
# sequence-to-sequence training.
def preprocess_dataset(examples, cfg):
inputs = examples["input"]
targets = examples[cfg.target_column]
model_inputs = cfg.tokenizer(
inputs, max_length=cfg.input_max_length, truncation=True
)
labels = cfg.tokenizer(targets, max_length=cfg.target_max_length, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
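# Usage sketch (illustrative; assumes a `datasets.Dataset` with an "input" column
# and the column named by `cfg.target_column`):
#   from functools import partial
#   tokenized = dataset.map(partial(preprocess_dataset, cfg=cfg), batched=True)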
# Drop rows in which any of the given columns is NaN.
def filter_out(df, col_names):
for col_name in col_names:
df = df[~df[col_name].isna()].reset_index(drop=True)
return df
def save_pickle(path: str, contents):
"""Saves contents to a pickle file."""
with open(path, "wb") as f:
pickle.dump(contents, f)
def load_pickle(path: str):
"""Loads contents from a pickle file."""
with open(path, "rb") as f:
return pickle.load(f)
def add_new_tokens(tokenizer, file_path):
"""
Adds new tokens to the tokenizer from a file.
The file should contain one token per line.
"""
with open(file_path, "r") as f:
new_tokens = [line.strip() for line in f if line.strip()]
tokenizer.add_tokens(new_tokens)
return tokenizer
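# Usage sketch (illustrative; "new_tokens.txt" is a placeholder path). After adding
# tokens, the model's embedding matrix must be resized to match the new vocabulary:
#   tokenizer = add_new_tokens(tokenizer, "new_tokens.txt")
#   model.resize_token_embeddings(len(tokenizer))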