import math
import os
import pickle
import random
import time

import numpy as np
import torch
from rdkit import Chem

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def space_clean(row):
    # Drop ". " / " ." fragments and collapse double spaces.
    row = row.replace(". ", "").replace(" .", "").replace("  ", " ")
    return row
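
# Usage sketch (illustrative): space_clean() is meant for decoded model output
# in which tokens are space-separated; the input string below is hypothetical.
#
#     space_clean("C C O . O")   # -> "C C O O"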

def canonicalize(smiles):
    # Chem.MolFromSmiles returns None for invalid SMILES, which makes
    # MolToSmiles raise; treat any failure as "not canonicalizable".
    try:
        new_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), canonical=True)
    except Exception:
        new_smiles = None
    return new_smiles

def canonicalize_str(smiles):
    """Try to canonicalize the molecule; return an empty string on failure."""
    # Strings containing "%" (e.g. yield tokens such as "50%") pass through unchanged.
    if "%" in smiles:
        return smiles
    else:
        try:
            return canonicalize(smiles) or ""
        except Exception:
            return ""

def uncanonicalize(smiles):
    # SMILES augmentation: re-root every "."-separated fragment at a random
    # atom and emit a non-canonical SMILES; returns None if RDKit fails.
    try:
        new_smiles = []
        for smiles_i in smiles.split("."):
            mol = Chem.MolFromSmiles(smiles_i)
            atom_indices = list(range(mol.GetNumAtoms()))
            random.shuffle(atom_indices)
            new_smiles_i = Chem.MolToSmiles(
                mol, rootedAtAtom=atom_indices[0], canonical=False
            )
            new_smiles.append(new_smiles_i)
        smiles = ".".join(new_smiles)
    except Exception:
        smiles = None
    return smiles
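
# Illustrative round trip (the molecule below is just an example, not taken
# from any dataset): canonicalize() maps a valid SMILES to RDKit's canonical
# form, and uncanonicalize() produces a randomly rooted variant of the same
# molecule that canonicalizes back to it.
#
#     can = canonicalize("C1=CC=CC=C1O")   # e.g. "Oc1ccccc1" (phenol)
#     aug = uncanonicalize(can)            # random, non-canonical SMILES
#     assert canonicalize(aug) == can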

def remove_atom_mapping(smi):
    # Clear atom-map numbers (e.g. "[CH3:1]" -> "C") and re-canonicalize.
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol, canonical=True)
    return canonicalize(smi)
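
# Illustrative example of stripping atom maps from a mapped SMILES (the input
# string below is hypothetical):
#
#     remove_atom_mapping("[CH3:1][OH:2]")   # -> "CO" (methanol, maps removed)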

def get_logger(filename="train"):
    from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger

    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
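
# Usage sketch: get_logger() logs to the console (stderr) and to
# "<filename>.log". Calling it repeatedly in the same process attaches
# duplicate handlers, so it is intended to be called once per run.
#
#     LOGGER = get_logger("train")
#     LOGGER.info("starting training")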

class AverageMeter(object):
    """Tracks the current value, sum, count and running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
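
# Typical training-loop usage (sketch; `model`, `criterion` and `train_loader`
# are assumed to exist elsewhere and are not defined in this module):
#
#     losses = AverageMeter()
#     for batch in train_loader:
#         loss = criterion(model(batch["input"]), batch["target"])
#         losses.update(loss.item(), n=batch["input"].size(0))
#     print(f"epoch loss: {losses.avg:.4f}")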

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return "%s (remain %s)" % (asMinutes(s), asMinutes(rs))

def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
    # Three parameter groups: encoder weights with weight decay, encoder
    # bias/LayerNorm weights without decay, and everything outside
    # `model.model` (e.g. a task head) at the decoder learning rate.
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p
                for n, p in model.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "lr": encoder_lr,
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "lr": encoder_lr,
            "weight_decay": 0.0,
        },
        {
            "params": [p for n, p in model.named_parameters() if "model" not in n],
            "lr": decoder_lr,
            "weight_decay": 0.0,
        },
    ]
    return optimizer_parameters
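
# Sketch of wiring the parameter groups into an optimizer. The learning rates
# below are placeholders; the function only assumes the wrapped Hugging Face
# model lives under `model.model`.
#
#     optimizer = torch.optim.AdamW(
#         get_optimizer_params(model, encoder_lr=1e-5, decoder_lr=1e-4, weight_decay=0.01),
#         lr=1e-5,
#     )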

def to_cpu(obj):
    # Recursively move tensors (possibly nested in dicts, lists, tuples or
    # sets) onto the CPU; anything else is returned unchanged.
    if torch.is_tensor(obj):
        return obj.to("cpu")
    elif isinstance(obj, dict):
        return {k: to_cpu(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple, set)):
        return [to_cpu(v) for v in obj]
    else:
        return obj

def get_accuracy_score(eval_preds, cfg):
    """Exact-match accuracy between canonicalized predicted and label SMILES."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=True)
    # -100 marks ignored label positions; replace it so the tokenizer can decode.
    labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = [
        canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
    ]
    decoded_labels = [
        [canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
    ]
    score = 0
    for i in range(len(decoded_preds)):
        if decoded_preds[i] == decoded_labels[i][0]:
            score += 1
    score /= len(decoded_preds)
    return {"accuracy": score}

def get_accuracy_score_multitask(eval_preds, cfg):
    """Exact-match accuracy that keeps the yield tokens ("0%".."100%") intact."""
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Strip eos/pad/unk and any additional special tokens except the yield
    # buckets, which are part of the target string.
    special_tokens = cfg.tokenizer.special_tokens_map
    special_tokens = [
        special_tokens["eos_token"],
        special_tokens["pad_token"],
        special_tokens["unk_token"],
    ] + list(
        set(special_tokens["additional_special_tokens"])
        - set(
            [
                "0%",
                "10%",
                "20%",
                "30%",
                "40%",
                "50%",
                "60%",
                "70%",
                "80%",
                "90%",
                "100%",
            ]
        )
    )
    decoded_preds = cfg.tokenizer.batch_decode(preds, skip_special_tokens=False)
    for special_token in special_tokens:
        decoded_preds = [pred.replace(special_token, "") for pred in decoded_preds]
    labels = np.where(labels != -100, labels, cfg.tokenizer.pad_token_id)
    decoded_labels = cfg.tokenizer.batch_decode(labels, skip_special_tokens=False)
    for special_token in special_tokens:
        decoded_labels = [
            label.replace(special_token, "") for label in decoded_labels
        ]
    decoded_preds = [
        canonicalize_str(pred.strip().replace(" ", "")) for pred in decoded_preds
    ]
    decoded_labels = [
        [canonicalize_str(label.strip().replace(" ", ""))] for label in decoded_labels
    ]
    score = 0
    for i in range(len(decoded_preds)):
        if decoded_preds[i] == decoded_labels[i][0]:
            score += 1
    score /= len(decoded_preds)
    return {"accuracy": score}

def preprocess_dataset(examples, cfg):
    inputs = examples["input"]
    targets = examples[cfg.target_column]
    model_inputs = cfg.tokenizer(
        inputs, max_length=cfg.input_max_length, truncation=True
    )
    labels = cfg.tokenizer(targets, max_length=cfg.target_max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
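
# Sketch of applying the preprocessing to a `datasets.Dataset`; the dataset
# variable and CFG object are assumptions used only for illustration.
#
#     from functools import partial
#
#     tokenized = raw_dataset.map(
#         partial(preprocess_dataset, cfg=CFG),
#         batched=True,
#         remove_columns=raw_dataset.column_names,
#     )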

def filter_out(df, col_names):
    """Drop rows in which any of the given columns is NaN."""
    for col_name in col_names:
        df = df[~df[col_name].isna()].reset_index(drop=True)
    return df

def save_pickle(path: str, contents):
    """Saves contents to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(contents, f)


def load_pickle(path: str):
    """Loads contents from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)

def add_new_tokens(tokenizer, file_path):
    """
    Adds new tokens to the tokenizer from a file.
    The file should contain one token per line.
    """
    with open(file_path, "r") as f:
        new_tokens = [line.strip() for line in f if line.strip()]
    tokenizer.add_tokens(new_tokens)
    return tokenizer
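
# Sketch: after extending the vocabulary, the model's embedding matrix has to
# be resized to match (the file path and model variable are assumptions):
#
#     tokenizer = add_new_tokens(tokenizer, "new_tokens.txt")
#     model.resize_token_embeddings(len(tokenizer))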