|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import Transformer |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.nn.utils.rnn import pad_sequence |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning.callbacks import ( |
|
ModelCheckpoint, |
|
EarlyStopping, |
|
) |
|
import math |
|
import os |
|
import pandas as pd |
|
from sklearn.model_selection import train_test_split |
|
import time |
|
import wandb |
|
|
|
|
|
from tokenizers import ( |
|
Tokenizer, |
|
models, |
|
pre_tokenizers, |
|
decoders, |
|
trainers, |
|
) |
|
|
|
import logging |
|
import gc |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
SRC_VOCAB_SIZE_ESTIMATE = 10000 |
|
TGT_VOCAB_SIZE_ESTIMATE = 14938 |
|
EMB_SIZE = 2048 |
|
NHEAD = 8 |
|
FFN_HID_DIM = 4096
|
NUM_ENCODER_LAYERS = 12 |
|
NUM_DECODER_LAYERS = 12 |
|
DROPOUT = 0.1 |
|
MAX_LEN = 384 |
|
|
|
|
|
ACCELERATOR = "gpu" |
|
DEVICES = 6 |
|
STRATEGY = "ddp" |
|
PRECISION = "16-mixed" |
|
BATCH_SIZE_PER_GPU = 48 |
|
ACCUMULATE_GRAD_BATCHES = 1
|
NUM_EPOCHS = 50 |
|
LEARNING_RATE = 5e-5 |
|
WEIGHT_DECAY = 1e-2 |
|
GRAD_CLIP_NORM = 1.0 |
|
VALIDATION_SPLIT = 0.05 |
|
RANDOM_SEED = 42 |
|
PATIENCE = 5 |
|
NUM_WORKERS = 8 |
|
|
|
|
|
PAD_IDX = 0 |
|
SOS_IDX = 1 |
|
EOS_IDX = 2 |
|
UNK_IDX = 3 |
|
|
|
|
|
|
|
SMILES_TOKENIZER_FILE = "smiles_bytelevel_bpe_tokenizer_scaled.json" |
|
IUPAC_TOKENIZER_FILE = "iupac_unigram_tokenizer_scaled.json" |
|
INPUT_CSV_FILE = "data_clean.csv" |
|
|
|
|
|
TRAIN_SMILES_FILE = "train.smi" |
|
TRAIN_IUPAC_FILE = "train.iupac" |
|
VAL_SMILES_FILE = "val.smi" |
|
VAL_IUPAC_FILE = "val.iupac" |
|
CHECKPOINT_DIR = "checkpoints" |
|
BEST_MODEL_FILENAME = "smiles-to-iupac-transformer-best"
|
|
|
|
|
WANDB_PROJECT = "SMILES-to-IUPAC-Large-BPE" |
|
WANDB_ENTITY = "adrianmirza"
|
WANDB_RUN_NAME = f"transformer_BPE_E{EMB_SIZE}_H{NHEAD}_L{NUM_ENCODER_LAYERS}_BS{BATCH_SIZE_PER_GPU * DEVICES}_LR{LEARNING_RATE}" |
|
|
|
|
|
hparams = { |
|
"src_tokenizer_type": "ByteLevelBPE", |
|
"tgt_tokenizer_type": "Unigram", |
|
"src_vocab_size_estimate": SRC_VOCAB_SIZE_ESTIMATE, |
|
"tgt_vocab_size_estimate": TGT_VOCAB_SIZE_ESTIMATE, |
|
"emb_size": EMB_SIZE, |
|
"nhead": NHEAD, |
|
"ffn_hid_dim": FFN_HID_DIM, |
|
"num_encoder_layers": NUM_ENCODER_LAYERS, |
|
"num_decoder_layers": NUM_DECODER_LAYERS, |
|
"dropout": DROPOUT, |
|
"max_len": MAX_LEN, |
|
"batch_size_per_gpu": BATCH_SIZE_PER_GPU, |
|
"effective_batch_size": BATCH_SIZE_PER_GPU * DEVICES * ACCUMULATE_GRAD_BATCHES, |
|
"num_epochs": NUM_EPOCHS, |
|
"learning_rate": LEARNING_RATE, |
|
"weight_decay": WEIGHT_DECAY, |
|
"grad_clip_norm": GRAD_CLIP_NORM, |
|
"validation_split": VALIDATION_SPLIT, |
|
"random_seed": RANDOM_SEED, |
|
"patience": PATIENCE, |
|
"precision": PRECISION, |
|
"gpus": DEVICES, |
|
"strategy": STRATEGY, |
|
"num_workers": NUM_WORKERS, |
|
} |
|
|
|
|
|
|
|
|
|
|
|
def get_smiles_tokenizer( |
|
train_files=None, |
|
vocab_size=30000, |
|
min_frequency=2, |
|
tokenizer_path=SMILES_TOKENIZER_FILE, |
|
): |
|
"""Creates or loads a Byte-Level BPE tokenizer for SMILES.""" |
|
if os.path.exists(tokenizer_path): |
|
logging.info(f"Loading existing SMILES tokenizer from {tokenizer_path}") |
|
try: |
|
tokenizer = Tokenizer.from_file(tokenizer_path) |
|
|
|
if ( |
|
tokenizer.token_to_id("<pad>") != PAD_IDX |
|
or tokenizer.token_to_id("<sos>") != SOS_IDX |
|
or tokenizer.token_to_id("<eos>") != EOS_IDX |
|
or tokenizer.token_to_id("<unk>") != UNK_IDX |
|
): |
|
logging.warning( |
|
"Special token ID mismatch after loading SMILES tokenizer. Re-check config." |
|
) |
|
|
|
if not isinstance(tokenizer.model, models.BPE): |
|
logging.warning( |
|
f"Loaded tokenizer from {tokenizer_path} is not a BPE model. Retraining." |
|
) |
|
raise TypeError("Incorrect tokenizer model type loaded.") |
|
return tokenizer |
|
except Exception as e: |
|
logging.error(f"Failed to load SMILES tokenizer: {e}. Retraining...") |
|
|
|
logging.info("Creating and training SMILES Byte-Level BPE tokenizer...") |
|
|
|
tokenizer = Tokenizer(models.BPE(unk_token="<unk>")) |
|
|
|
|
|
|
|
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) |
|
|
|
tokenizer.decoder = decoders.ByteLevel() |
|
|
|
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"] |
|
|
|
trainer = trainers.BpeTrainer( |
|
vocab_size=vocab_size, |
|
min_frequency=min_frequency, |
|
special_tokens=special_tokens, |
|
|
|
|
|
|
|
) |
|
|
|
if train_files and all(os.path.exists(f) for f in train_files): |
|
logging.info(f"Training SMILES BPE tokenizer on: {train_files}") |
|
tokenizer.train(files=train_files, trainer=trainer) |
|
logging.info( |
|
f"SMILES BPE tokenizer trained. Final Vocab size: {tokenizer.get_vocab_size()}" |
|
) |
|
|
|
if ( |
|
tokenizer.token_to_id("<pad>") != PAD_IDX |
|
or tokenizer.token_to_id("<sos>") != SOS_IDX |
|
or tokenizer.token_to_id("<eos>") != EOS_IDX |
|
or tokenizer.token_to_id("<unk>") != UNK_IDX |
|
): |
|
logging.warning( |
|
"Special token ID mismatch after training SMILES BPE tokenizer. Check trainer setup." |
|
) |
|
try: |
|
tokenizer.save(tokenizer_path) |
|
logging.info(f"SMILES BPE tokenizer saved to {tokenizer_path}") |
|
except Exception as e: |
|
logging.error(f"Failed to save SMILES BPE tokenizer: {e}") |
|
else: |
|
logging.error( |
|
"Training files not provided or not found for SMILES tokenizer. Cannot train." |
|
) |
|
|
|
tokenizer.add_special_tokens(special_tokens) |
|
|
|
return tokenizer |
|
|
|
|
|
|
|
def get_iupac_tokenizer( |
|
train_files=None, |
|
vocab_size=30000, |
|
min_frequency=2, |
|
tokenizer_path=IUPAC_TOKENIZER_FILE, |
|
): |
|
"""Creates or loads a Unigram tokenizer for IUPAC names.""" |
|
if os.path.exists(tokenizer_path): |
|
logging.info(f"Loading existing IUPAC tokenizer from {tokenizer_path}") |
|
try: |
|
tokenizer = Tokenizer.from_file(tokenizer_path) |
|
if ( |
|
tokenizer.token_to_id("<pad>") != PAD_IDX |
|
or tokenizer.token_to_id("<sos>") != SOS_IDX |
|
or tokenizer.token_to_id("<eos>") != EOS_IDX |
|
or tokenizer.token_to_id("<unk>") != UNK_IDX |
|
): |
|
logging.warning( |
|
"Special token ID mismatch after loading IUPAC tokenizer. Re-check config." |
|
) |
|
return tokenizer |
|
except Exception as e: |
|
logging.error(f"Failed to load IUPAC tokenizer: {e}. Retraining...") |
|
|
|
logging.info("Creating and training IUPAC Unigram tokenizer...") |
|
tokenizer = Tokenizer(models.Unigram()) |
|
|
|
pre_tokenizer_list = [ |
|
pre_tokenizers.WhitespaceSplit(), |
|
pre_tokenizers.Punctuation(), |
|
pre_tokenizers.Digits(individual_digits=True), |
|
] |
|
|
|
|
|
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(pre_tokenizer_list) |
|
    tokenizer.decoder = decoders.Metaspace()
|
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"] |
|
trainer = trainers.UnigramTrainer( |
|
vocab_size=vocab_size, |
|
special_tokens=special_tokens, |
|
unk_token="<unk>", |
|
|
|
|
|
|
|
) |
|
|
|
if train_files and all(os.path.exists(f) for f in train_files): |
|
logging.info(f"Training IUPAC tokenizer on: {train_files}") |
|
tokenizer.train(files=train_files, trainer=trainer) |
|
logging.info( |
|
f"IUPAC tokenizer trained. Final Vocab size: {tokenizer.get_vocab_size()}" |
|
) |
|
|
|
if ( |
|
tokenizer.token_to_id("<pad>") != PAD_IDX |
|
or tokenizer.token_to_id("<sos>") != SOS_IDX |
|
or tokenizer.token_to_id("<eos>") != EOS_IDX |
|
or tokenizer.token_to_id("<unk>") != UNK_IDX |
|
): |
|
logging.warning( |
|
"Special token ID mismatch after training IUPAC tokenizer. Check trainer setup." |
|
) |
|
try: |
|
tokenizer.save(tokenizer_path) |
|
logging.info(f"IUPAC tokenizer saved to {tokenizer_path}") |
|
except Exception as e: |
|
logging.error(f"Failed to save IUPAC tokenizer: {e}") |
|
else: |
|
logging.error( |
|
"Training files not provided or not found for IUPAC tokenizer. Cannot train." |
|
) |
|
tokenizer.add_special_tokens(special_tokens) |
|
|
|
return tokenizer |
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module): |
|
"""Injects positional information into the input embeddings.""" |
|
|
|
def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000): |
|
super().__init__() |
|
den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size) |
|
pos = torch.arange(0, maxlen).reshape(maxlen, 1) |
|
pos_embedding = torch.zeros((maxlen, emb_size)) |
|
pos_embedding[:, 0::2] = torch.sin(pos * den) |
|
pos_embedding[:, 1::2] = torch.cos(pos * den) |
|
        pos_embedding = pos_embedding.unsqueeze(0)  # (1, maxlen, emb_size) to broadcast over the batch
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)
|
|
|
def forward(self, token_embedding: torch.Tensor): |
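        # token_embedding is (batch, seq_len, emb_size); the model is built with batch_first=True.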
|
|
|
seq_len = token_embedding.size(1) |
|
|
|
|
|
if seq_len > self.pos_embedding.size(1): |
|
logging.warning( |
|
f"Input sequence length ({seq_len}) exceeds PositionalEncoding maxlen ({self.pos_embedding.size(1)}). Truncating positional encoding." |
|
) |
|
pos_to_add = self.pos_embedding[:, : self.pos_embedding.size(1), :] |
|
|
|
|
|
output = token_embedding[:, : self.pos_embedding.size(1), :] + pos_to_add |
|
else: |
|
pos_to_add = self.pos_embedding[:, :seq_len, :] |
|
output = token_embedding + pos_to_add |
|
|
|
return self.dropout(output) |
|
|
|
|
|
class TokenEmbedding(nn.Module): |
|
"""Converts token indices to embeddings.""" |
|
|
|
    def __init__(self, vocab_size: int, emb_size: int):
|
super().__init__() |
|
self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=PAD_IDX) |
|
self.emb_size = emb_size |
|
|
|
def forward(self, tokens: torch.Tensor): |
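        # Scale embeddings by sqrt(emb_size), following the original Transformer paper.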
|
return self.embedding(tokens.long()) * math.sqrt(self.emb_size) |
|
|
|
|
|
class Seq2SeqTransformer(nn.Module): |
|
"""The main Encoder-Decoder Transformer model.""" |
|
|
|
def __init__( |
|
self, |
|
num_encoder_layers: int, |
|
num_decoder_layers: int, |
|
emb_size: int, |
|
nhead: int, |
|
src_vocab_size: int, |
|
tgt_vocab_size: int, |
|
dim_feedforward: int, |
|
dropout: float = 0.1, |
|
max_len: int = MAX_LEN, |
|
): |
|
super().__init__() |
|
|
|
if emb_size % nhead != 0: |
|
raise ValueError( |
|
f"Embedding size ({emb_size}) must be divisible by the number of heads ({nhead})" |
|
) |
|
|
|
self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size) |
|
self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size) |
|
|
|
|
|
        pe_maxlen = max(max_len, 5000)
|
self.positional_encoding = PositionalEncoding( |
|
emb_size, dropout=dropout, maxlen=pe_maxlen |
|
) |
|
|
|
self.transformer = Transformer( |
|
d_model=emb_size, |
|
nhead=nhead, |
|
num_encoder_layers=num_encoder_layers, |
|
num_decoder_layers=num_decoder_layers, |
|
dim_feedforward=dim_feedforward, |
|
dropout=dropout, |
|
batch_first=True, |
|
) |
|
|
|
self.generator = nn.Linear(emb_size, tgt_vocab_size) |
|
self._init_weights() |
|
|
|
def _init_weights(self): |
|
for p in self.parameters(): |
|
if p.dim() > 1: |
|
nn.init.xavier_uniform_(p) |
|
|
|
def forward( |
|
self, |
|
src: torch.Tensor, |
|
trg: torch.Tensor, |
|
tgt_mask: torch.Tensor, |
|
src_padding_mask: torch.Tensor, |
|
tgt_padding_mask: torch.Tensor, |
|
memory_key_padding_mask: torch.Tensor, |
|
): |
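        # Expected shapes (batch_first): src/trg are (batch, seq_len) token ids,
        # the padding masks are (batch, seq_len) bool, tgt_mask is (tgt_len, tgt_len).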
|
|
|
|
|
src_padding_mask = src_padding_mask.to(src.device) |
|
tgt_padding_mask = tgt_padding_mask.to(trg.device) |
|
memory_key_padding_mask = memory_key_padding_mask.to(src.device) |
|
|
|
tgt_mask = tgt_mask.to(trg.device) |
|
|
|
src_emb = self.positional_encoding( |
|
self.src_tok_emb(src) |
|
) |
|
tgt_emb = self.positional_encoding( |
|
self.tgt_tok_emb(trg) |
|
) |
|
|
|
outs = self.transformer( |
|
src=src_emb, |
|
tgt=tgt_emb, |
|
src_mask=None, |
|
tgt_mask=tgt_mask, |
|
memory_mask=None, |
|
src_key_padding_mask=src_padding_mask, |
|
tgt_key_padding_mask=tgt_padding_mask, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
) |
|
|
|
return self.generator(outs) |
|
|
|
def encode(self, src: torch.Tensor, src_padding_mask: torch.Tensor): |
|
src_padding_mask = src_padding_mask.to( |
|
src.device |
|
) |
|
src_emb = self.positional_encoding( |
|
self.src_tok_emb(src) |
|
) |
|
memory = self.transformer.encoder( |
|
src_emb, mask=None, src_key_padding_mask=src_padding_mask |
|
) |
|
return memory |
|
|
|
def decode( |
|
self, |
|
tgt: torch.Tensor, |
|
memory: torch.Tensor, |
|
tgt_mask: torch.Tensor, |
|
tgt_padding_mask: torch.Tensor, |
|
memory_key_padding_mask: torch.Tensor, |
|
): |
|
|
|
tgt_mask = tgt_mask.to(tgt.device) |
|
tgt_padding_mask = tgt_padding_mask.to(tgt.device) |
|
memory_key_padding_mask = memory_key_padding_mask.to(memory.device) |
|
|
|
tgt_emb = self.positional_encoding( |
|
self.tgt_tok_emb(tgt) |
|
) |
|
output = self.transformer.decoder( |
|
tgt=tgt_emb, |
|
memory=memory, |
|
tgt_mask=tgt_mask, |
|
memory_mask=None, |
|
tgt_key_padding_mask=tgt_padding_mask, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
) |
|
return output |
|
|
|
|
|
|
|
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor: |
|
"""Generates an upper-triangular matrix for causal masking.""" |
|
mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1) |
|
mask = ( |
|
mask.float() |
|
.masked_fill(mask == 0, float("-inf")) |
|
.masked_fill(mask == 1, float(0.0)) |
|
) |
|
return mask |
|
|
|
|
|
def create_masks( |
|
src: torch.Tensor, tgt: torch.Tensor, pad_idx: int, device: torch.device |
|
): |
|
""" |
|
Creates all necessary masks for the Transformer model. |
|
Assumes src and tgt are inputs to the forward pass (tgt includes SOS, excludes EOS). |
|
Returns boolean masks where True indicates the position should be masked (ignored). |
|
""" |
|
src_seq_len = src.shape[1] |
|
tgt_seq_len = tgt.shape[1] |
|
|
|
|
|
tgt_mask = generate_square_subsequent_mask( |
|
tgt_seq_len, device |
|
) |
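    # Padding masks are boolean, shaped (batch, seq_len), with True at <pad> positions.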
|
|
|
|
|
src_padding_mask = src == pad_idx |
|
tgt_padding_mask = tgt == pad_idx |
|
    memory_key_padding_mask = src_padding_mask
|
|
|
return tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask |
|
|
|
|
|
|
|
class SmilesIupacDataset(Dataset): |
|
"""Dataset class for SMILES-IUPAC pairs, reading from pre-split files.""" |
|
|
|
def __init__(self, smiles_file: str, iupac_file: str): |
|
logging.info(f"Loading data from {smiles_file} and {iupac_file}") |
|
try: |
|
with open(smiles_file, "r", encoding="utf-8") as f_smi: |
|
self.smiles = [line.strip() for line in f_smi if line.strip()] |
|
with open(iupac_file, "r", encoding="utf-8") as f_iupac: |
|
self.iupac = [line.strip() for line in f_iupac if line.strip()] |
|
|
|
if len(self.smiles) != len(self.iupac): |
|
logging.warning( |
|
f"Mismatch in number of lines: {smiles_file} ({len(self.smiles)}) vs {iupac_file} ({len(self.iupac)}). Trimming." |
|
) |
|
min_len = min(len(self.smiles), len(self.iupac)) |
|
self.smiles = self.smiles[:min_len] |
|
self.iupac = self.iupac[:min_len] |
|
|
|
logging.info( |
|
f"Loaded {len(self.smiles)} pairs from {smiles_file}/{iupac_file}." |
|
) |
|
if len(self.smiles) == 0: |
|
logging.warning(f"Loaded 0 data pairs. Check files.") |
|
|
|
except FileNotFoundError: |
|
logging.error( |
|
f"Error: One or both files not found: {smiles_file}, {iupac_file}" |
|
) |
|
raise |
|
except Exception as e: |
|
logging.error(f"Error loading data: {e}") |
|
raise |
|
|
|
def __len__(self): |
|
return len(self.smiles) |
|
|
|
def __getitem__(self, idx): |
|
return self.smiles[idx], self.iupac[idx] |
|
|
|
|
|
def collate_fn( |
|
batch, smiles_tokenizer, iupac_tokenizer, pad_idx, sos_idx, eos_idx, max_len |
|
): |
|
"""Collates data samples into batches.""" |
|
src_batch, tgt_batch = [], [] |
|
skipped_count = 0 |
|
for src_sample, tgt_sample in batch: |
|
try: |
|
|
|
src_encoded = smiles_tokenizer.encode(src_sample) |
|
|
|
src_ids = src_encoded.ids[:max_len] |
|
if not src_ids: |
|
skipped_count += 1 |
|
continue |
|
src_tensor = torch.tensor(src_ids, dtype=torch.long) |
|
|
|
|
|
tgt_encoded = iupac_tokenizer.encode(tgt_sample) |
|
|
|
tgt_ids = tgt_encoded.ids[: max_len - 2] |
|
            if not tgt_ids:
|
skipped_count += 1 |
|
continue |
|
|
|
tgt_tensor = torch.tensor([sos_idx] + tgt_ids + [eos_idx], dtype=torch.long) |
|
|
|
src_batch.append(src_tensor) |
|
tgt_batch.append(tgt_tensor) |
|
        except Exception as e:
            # Skip samples that fail to tokenize or convert; they are counted below.
            logging.debug(f"collate_fn: skipping sample due to error: {e}")
            skipped_count += 1
            continue
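
    # Surface how many samples were dropped in this batch (debug level to avoid log spam).
    if skipped_count > 0:
        logging.debug(f"collate_fn: skipped {skipped_count} sample(s) in this batch.")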
|
|
|
|
|
|
|
|
|
if not src_batch or not tgt_batch: |
|
|
|
return torch.tensor([]), torch.tensor([]) |
|
|
|
try: |
|
|
|
src_batch_padded = pad_sequence( |
|
src_batch, batch_first=True, padding_value=pad_idx |
|
) |
|
tgt_batch_padded = pad_sequence( |
|
tgt_batch, batch_first=True, padding_value=pad_idx |
|
) |
|
except Exception as e: |
|
logging.error( |
|
f"Error during padding: {e}. Src lengths: {[len(s) for s in src_batch]}, Tgt lengths: {[len(t) for t in tgt_batch]}" |
|
) |
|
|
|
return torch.tensor([]), torch.tensor([]) |
|
|
|
return src_batch_padded, tgt_batch_padded |
|
|
|
|
|
|
|
class SmilesIupacLitModule(pl.LightningModule): |
|
def __init__( |
|
self, src_vocab_size: int, tgt_vocab_size: int, hparams_dict: dict |
|
): |
|
super().__init__() |
|
|
|
|
|
self.save_hyperparameters(hparams_dict) |
|
|
|
self.model = Seq2SeqTransformer( |
|
num_encoder_layers=self.hparams.num_encoder_layers, |
|
num_decoder_layers=self.hparams.num_decoder_layers, |
|
emb_size=self.hparams.emb_size, |
|
nhead=self.hparams.nhead, |
|
src_vocab_size=src_vocab_size, |
|
tgt_vocab_size=tgt_vocab_size, |
|
dim_feedforward=self.hparams.ffn_hid_dim, |
|
dropout=self.hparams.dropout, |
|
max_len=self.hparams.max_len, |
|
) |
|
|
|
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX) |
|
|
|
|
|
total_params = sum(p.numel() for p in self.model.parameters()) |
|
trainable_params = sum( |
|
p.numel() for p in self.model.parameters() if p.requires_grad |
|
) |
|
logging.info(f"Model Initialized:") |
|
logging.info(f" Total Parameters: {total_params / 1_000_000:.2f} M") |
|
logging.info(f" Trainable Parameters: {trainable_params / 1_000_000:.2f} M") |
|
|
|
|
|
|
|
self.hparams.total_params_M = round(total_params / 1_000_000, 2) |
|
self.hparams.trainable_params_M = round(trainable_params / 1_000_000, 2) |
|
|
|
def forward(self, src, tgt): |
|
|
|
|
|
|
|
tgt_input = tgt[:, :-1] |
|
tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = ( |
|
create_masks( |
|
src, |
|
tgt_input, |
|
PAD_IDX, |
|
self.device, |
|
) |
|
) |
|
logits = self.model( |
|
src, |
|
tgt_input, |
|
tgt_mask, |
|
src_padding_mask, |
|
tgt_padding_mask, |
|
memory_key_padding_mask, |
|
) |
|
return logits |
|
|
|
def training_step(self, batch, batch_idx): |
|
src, tgt = batch |
|
if src.numel() == 0 or tgt.numel() == 0: |
|
|
|
return None |
|
|
|
tgt_input = tgt[:, :-1] |
|
tgt_out = tgt[:, 1:] |
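        # Teacher forcing: the decoder input drops the final token, while the loss target drops <sos>,
        # so position i of tgt_input predicts position i of tgt_out.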
|
|
|
|
|
tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = ( |
|
create_masks(src, tgt_input, PAD_IDX, self.device) |
|
) |
|
|
|
try: |
|
logits = self.model( |
|
src=src, |
|
trg=tgt_input, |
|
tgt_mask=tgt_mask, |
|
src_padding_mask=src_padding_mask, |
|
tgt_padding_mask=tgt_padding_mask, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
loss = self.criterion( |
|
logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1) |
|
) |
|
|
|
|
|
if not torch.isfinite(loss): |
|
logging.warning( |
|
f"Non-finite loss encountered in training step {batch_idx}: {loss.item()}. Skipping update." |
|
) |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
self.log( |
|
"train_loss", |
|
loss, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
logger=True, |
|
sync_dist=True, |
|
batch_size=src.size(0), |
|
) |
|
|
|
return loss |
|
|
|
except RuntimeError as e: |
|
if "CUDA out of memory" in str(e): |
|
logging.warning( |
|
f"CUDA OOM error during training step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch." |
|
) |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
return None |
|
else: |
|
logging.error(f"Runtime error during training step {batch_idx}: {e}") |
|
|
|
logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}") |
|
return None |
|
|
|
def validation_step(self, batch, batch_idx): |
|
src, tgt = batch |
|
if src.numel() == 0 or tgt.numel() == 0: |
|
|
|
return None |
|
|
|
tgt_input = tgt[:, :-1] |
|
tgt_out = tgt[:, 1:] |
|
|
|
tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = ( |
|
create_masks(src, tgt_input, PAD_IDX, self.device) |
|
) |
|
|
|
try: |
|
logits = self.model( |
|
src, |
|
tgt_input, |
|
tgt_mask, |
|
src_padding_mask, |
|
tgt_padding_mask, |
|
memory_key_padding_mask, |
|
) |
|
loss = self.criterion( |
|
logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1) |
|
) |
|
|
|
if torch.isfinite(loss): |
|
|
|
|
|
self.log( |
|
"val_loss", |
|
loss, |
|
on_step=False, |
|
on_epoch=True, |
|
prog_bar=True, |
|
logger=True, |
|
sync_dist=True, |
|
batch_size=src.size(0), |
|
) |
|
else: |
|
logging.warning( |
|
f"Non-finite loss encountered during validation step {batch_idx}: {loss.item()}." |
|
) |
|
|
|
|
|
|
|
|
|
except RuntimeError as e: |
|
|
|
logging.error(f"Runtime error during validation step {batch_idx}: {e}") |
|
if "CUDA out of memory" in str(e): |
|
logging.warning( |
|
f"CUDA OOM error during validation step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch." |
|
) |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
else: |
|
logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}") |
|
|
|
|
|
return None |
|
|
|
def configure_optimizers(self): |
|
optimizer = torch.optim.AdamW( |
|
self.parameters(), |
|
lr=self.hparams.learning_rate, |
|
weight_decay=self.hparams.weight_decay, |
|
) |
|
|
|
|
|
|
|
|
|
try: |
|
from transformers import get_linear_schedule_with_warmup |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
num_training_steps = self.trainer.estimated_stepping_batches |
|
logging.info( |
|
f"Estimated stepping batches for LR schedule: {num_training_steps}" |
|
) |
|
if num_training_steps is None or num_training_steps <= 0: |
|
logging.warning( |
|
"Could not estimate stepping batches, using fallback for LR schedule." |
|
) |
|
|
|
|
|
|
|
|
|
|
|
num_training_steps = 1_000_000 |
|
except AttributeError: |
|
logging.warning( |
|
"self.trainer not available yet in configure_optimizers. Using fallback step count for LR schedule." |
|
) |
|
num_training_steps = 1_000_000 |
|
|
|
|
|
num_warmup_steps = int(0.05 * num_training_steps) |
|
logging.info( |
|
f"LR Scheduler: Total steps ~{num_training_steps}, Warmup steps: {num_warmup_steps}" |
|
) |
|
|
|
scheduler = get_linear_schedule_with_warmup( |
|
optimizer, |
|
num_warmup_steps=num_warmup_steps, |
|
num_training_steps=num_training_steps, |
|
) |
|
|
|
lr_scheduler_config = { |
|
"scheduler": scheduler, |
|
"interval": "step", |
|
"frequency": 1, |
|
"name": "linear_warmup_decay_lr", |
|
} |
|
logging.info("Using Linear Warmup/Decay LR Scheduler.") |
|
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} |
|
|
|
except ImportError: |
|
logging.warning( |
|
"'transformers' library not found. Cannot create linear warmup scheduler. Using constant LR." |
|
) |
|
return optimizer |
|
except Exception as e: |
|
logging.error( |
|
f"Error setting up LR scheduler: {e}. Using constant LR.", exc_info=True |
|
) |
|
return optimizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
def greedy_decode( |
|
model: pl.LightningModule, |
|
src: torch.Tensor, |
|
src_padding_mask: torch.Tensor, |
|
max_len: int, |
|
sos_idx: int, |
|
eos_idx: int, |
|
device: torch.device, |
|
) -> torch.Tensor: |
|
"""Performs greedy decoding using the LightningModule's model.""" |
|
|
|
transformer_model = model.model |
|
|
|
try: |
|
with torch.no_grad(): |
|
|
|
memory = transformer_model.encode( |
|
src, src_padding_mask |
|
) |
|
memory = memory.to(device) |
|
|
|
memory_key_padding_mask = src_padding_mask.to(memory.device) |
|
|
|
            ys = torch.ones(1, 1).fill_(sos_idx).type(torch.long).to(device)
|
|
|
for i in range(max_len - 1): |
|
tgt_seq_len = ys.shape[1] |
|
|
|
                tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
|
|
|
                tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool).to(device)
|
|
|
|
|
out = transformer_model.decode( |
|
ys, memory, tgt_mask, tgt_padding_mask, memory_key_padding_mask |
|
) |
|
|
|
|
|
|
|
                last_token_logits = transformer_model.generator(out[:, -1, :])
                _, next_word = torch.max(last_token_logits, dim=1)
                next_word = next_word.item()
|
|
|
|
|
ys = torch.cat( |
|
[ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1 |
|
) |
|
|
|
|
|
if next_word == eos_idx: |
|
break |
|
|
|
return ys[:, 1:] |
|
|
|
except RuntimeError as e: |
|
logging.error(f"Runtime error during greedy decode: {e}") |
|
if "CUDA out of memory" in str(e): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
return torch.tensor([[]], dtype=torch.long, device=device) |
|
|
|
|
|
def translate( |
|
model: pl.LightningModule, |
|
src_sentence: str, |
|
smiles_tokenizer, |
|
iupac_tokenizer, |
|
device: torch.device, |
|
max_len: int, |
|
sos_idx: int, |
|
eos_idx: int, |
|
pad_idx: int, |
|
) -> str: |
|
"""Translates a single SMILES string using the LightningModule.""" |
|
model.eval() |
|
|
|
try: |
|
src_encoded = smiles_tokenizer.encode(src_sentence) |
|
if not src_encoded or len(src_encoded.ids) == 0: |
|
logging.warning(f"Encoding failed for SMILES: {src_sentence}") |
|
return "[Encoding Error]" |
|
|
|
src_ids = src_encoded.ids[:max_len] |
|
if not src_ids: |
|
logging.warning( |
|
f"Source sequence empty after truncation for SMILES: {src_sentence}" |
|
) |
|
return "[Encoding Error - Empty Src]" |
|
|
|
except Exception as e: |
|
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}") |
|
return "[Encoding Error]" |
|
|
|
|
|
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
|
src_padding_mask = src == pad_idx |
|
|
|
|
|
tgt_tokens_tensor = greedy_decode( |
|
model=model, |
|
src=src, |
|
src_padding_mask=src_padding_mask, |
|
max_len=max_len, |
|
sos_idx=sos_idx, |
|
eos_idx=eos_idx, |
|
device=device, |
|
) |
|
|
|
|
|
if tgt_tokens_tensor.numel() > 0: |
|
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist() |
|
try: |
|
|
|
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True) |
|
return translation |
|
except Exception as e: |
|
logging.error(f"Error decoding target tokens {tgt_tokens}: {e}") |
|
return "[Decoding Error]" |
|
else: |
|
|
|
|
|
return "[Decoding Error - Empty Output]" |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
pl.seed_everything(RANDOM_SEED, workers=True) |
|
|
|
|
|
os.makedirs(CHECKPOINT_DIR, exist_ok=True) |
|
|
|
|
|
|
|
logging.info(f"Loading and splitting data from {INPUT_CSV_FILE}...") |
|
|
|
try: |
|
|
|
df = pd.read_csv(INPUT_CSV_FILE, dtype={"SMILES": str, "Systematic": str}) |
|
logging.info(f"Initial rows loaded: {len(df)}") |
|
if "SMILES" not in df.columns: |
|
raise ValueError("CSV must contain 'SMILES' column.") |
|
if "Systematic" not in df.columns: |
|
raise ValueError("CSV must contain 'Systematic' (IUPAC name) column.") |
|
df.rename(columns={"Systematic": "IUPAC"}, inplace=True) |
|
|
|
initial_rows = len(df) |
|
df.dropna(subset=["SMILES", "IUPAC"], inplace=True) |
|
rows_after_na = len(df) |
|
if initial_rows > rows_after_na: |
|
logging.info( |
|
f"Dropped {initial_rows - rows_after_na} rows with missing values." |
|
) |
|
|
|
df = df[df["SMILES"].str.strip().astype(bool)] |
|
df = df[df["IUPAC"].str.strip().astype(bool)] |
|
df["SMILES"] = df["SMILES"].str.strip() |
|
df["IUPAC"] = df["IUPAC"].str.strip() |
|
rows_after_empty = len(df) |
|
if rows_after_na > rows_after_empty: |
|
logging.info( |
|
f"Dropped {rows_after_na - rows_after_empty} rows with empty strings after stripping." |
|
) |
|
|
|
smiles_data = df["SMILES"].tolist() |
|
iupac_data = df["IUPAC"].tolist() |
|
logging.info(f"Loaded {len(smiles_data)} valid pairs from CSV.") |
|
del df |
|
gc.collect() |
|
|
|
if len(smiles_data) < 10: |
|
raise ValueError( |
|
f"Not enough valid data ({len(smiles_data)}) for split. Need at least 10." |
|
) |
|
|
|
train_smi, val_smi, train_iupac, val_iupac = train_test_split( |
|
smiles_data, |
|
iupac_data, |
|
test_size=VALIDATION_SPLIT, |
|
random_state=RANDOM_SEED, |
|
) |
|
logging.info(f"Split: {len(train_smi)} train, {len(val_smi)} validation.") |
|
del smiles_data, iupac_data |
|
gc.collect() |
|
|
|
logging.info("Writing split data to files...") |
|
with open(TRAIN_SMILES_FILE, "w", encoding="utf-8") as f: |
|
f.write("\n".join(train_smi)) |
|
with open(TRAIN_IUPAC_FILE, "w", encoding="utf-8") as f: |
|
f.write("\n".join(train_iupac)) |
|
with open(VAL_SMILES_FILE, "w", encoding="utf-8") as f: |
|
f.write("\n".join(val_smi)) |
|
with open(VAL_IUPAC_FILE, "w", encoding="utf-8") as f: |
|
f.write("\n".join(val_iupac)) |
|
logging.info( |
|
f"Split files written: {TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}, {VAL_SMILES_FILE}, {VAL_IUPAC_FILE}" |
|
) |
|
del train_smi, val_smi, train_iupac, val_iupac |
|
gc.collect() |
|
|
|
except FileNotFoundError: |
|
logging.error(f"Fatal error: Input CSV file not found at {INPUT_CSV_FILE}") |
|
exit(1) |
|
except ValueError as ve: |
|
logging.error(f"Fatal error during data preparation: {ve}") |
|
exit(1) |
|
except Exception as e: |
|
logging.error(f"Fatal error during data preparation: {e}", exc_info=True) |
|
exit(1) |
|
|
|
|
|
|
|
logging.info("Initializing Tokenizers...") |
|
|
|
if not os.path.exists(TRAIN_SMILES_FILE) or not os.path.exists(TRAIN_IUPAC_FILE): |
|
logging.error( |
|
f"Training files ({TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}) not found. Cannot train tokenizers." |
|
) |
|
exit(1) |
|
|
|
smiles_tokenizer = get_smiles_tokenizer( |
|
train_files=[TRAIN_SMILES_FILE], |
|
vocab_size=SRC_VOCAB_SIZE_ESTIMATE, |
|
tokenizer_path=SMILES_TOKENIZER_FILE, |
|
) |
|
iupac_tokenizer = get_iupac_tokenizer( |
|
train_files=[TRAIN_IUPAC_FILE], |
|
vocab_size=TGT_VOCAB_SIZE_ESTIMATE, |
|
tokenizer_path=IUPAC_TOKENIZER_FILE, |
|
) |
|
|
|
ACTUAL_SRC_VOCAB_SIZE = smiles_tokenizer.get_vocab_size() |
|
ACTUAL_TGT_VOCAB_SIZE = iupac_tokenizer.get_vocab_size() |
|
logging.info(f"Actual SMILES Vocab Size: {ACTUAL_SRC_VOCAB_SIZE}") |
|
logging.info(f"Actual IUPAC Vocab Size: {ACTUAL_TGT_VOCAB_SIZE}") |
|
|
|
hparams["actual_src_vocab_size"] = ACTUAL_SRC_VOCAB_SIZE |
|
hparams["actual_tgt_vocab_size"] = ACTUAL_TGT_VOCAB_SIZE |
|
|
|
|
|
|
|
if WANDB_ENTITY is None: |
|
logging.warning( |
|
"WANDB_ENTITY not set. WandB will log to your default entity. Set WANDB_ENTITY='your_username_or_team' to specify." |
|
) |
|
|
|
wandb_logger = WandbLogger( |
|
project=WANDB_PROJECT, |
|
entity=WANDB_ENTITY, |
|
name=WANDB_RUN_NAME, |
|
config=hparams, |
|
|
|
|
|
) |
|
|
|
|
|
logging.info("Creating Datasets and DataLoaders...") |
|
try: |
|
train_dataset = SmilesIupacDataset(TRAIN_SMILES_FILE, TRAIN_IUPAC_FILE) |
|
val_dataset = SmilesIupacDataset(VAL_SMILES_FILE, VAL_IUPAC_FILE) |
|
if len(train_dataset) == 0 or len(val_dataset) == 0: |
|
logging.error( |
|
"Training or validation dataset is empty. Check data splitting and file content." |
|
) |
|
exit(1) |
|
except Exception as e: |
|
logging.error(f"Error creating Datasets: {e}", exc_info=True) |
|
exit(1) |
|
|
|
|
|
def collate_fn_partial(batch): |
|
return collate_fn( |
|
batch, |
|
smiles_tokenizer, |
|
iupac_tokenizer, |
|
PAD_IDX, |
|
SOS_IDX, |
|
EOS_IDX, |
|
hparams["max_len"], |
|
) |
|
|
|
|
|
persistent_workers = NUM_WORKERS > 0 and STRATEGY == "ddp" |
|
|
|
train_dataloader = DataLoader( |
|
train_dataset, |
|
batch_size=BATCH_SIZE_PER_GPU, |
|
shuffle=True, |
|
collate_fn=collate_fn_partial, |
|
num_workers=NUM_WORKERS, |
|
pin_memory=True, |
|
persistent_workers=persistent_workers, |
|
drop_last=True, |
|
) |
|
val_dataloader = DataLoader( |
|
val_dataset, |
|
batch_size=BATCH_SIZE_PER_GPU, |
|
shuffle=False, |
|
collate_fn=collate_fn_partial, |
|
num_workers=NUM_WORKERS, |
|
pin_memory=True, |
|
persistent_workers=persistent_workers, |
|
drop_last=False, |
|
) |
|
|
|
|
|
logging.info("Initializing Lightning Module...") |
|
|
|
model = SmilesIupacLitModule( |
|
src_vocab_size=ACTUAL_SRC_VOCAB_SIZE, |
|
tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE, |
|
hparams_dict=hparams, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint_callback = ModelCheckpoint( |
|
dirpath=CHECKPOINT_DIR, |
|
filename=BEST_MODEL_FILENAME + "-{epoch:02d}-{val_loss:.4f}", |
|
save_top_k=1, |
|
verbose=True, |
|
monitor="val_loss", |
|
mode="min", |
|
save_last=True, |
|
) |
|
early_stopping_callback = EarlyStopping( |
|
monitor="val_loss", |
|
patience=PATIENCE, |
|
verbose=True, |
|
mode="min", |
|
) |
|
|
|
|
|
logging.info( |
|
f"Initializing PyTorch Lightning Trainer (GPUs={DEVICES}, Strategy='{STRATEGY}', Precision='{PRECISION}')..." |
|
) |
|
trainer = pl.Trainer( |
|
accelerator=ACCELERATOR, |
|
devices=DEVICES, |
|
strategy=STRATEGY, |
|
precision=PRECISION, |
|
max_epochs=NUM_EPOCHS, |
|
logger=wandb_logger, |
|
callbacks=[checkpoint_callback, early_stopping_callback], |
|
gradient_clip_val=GRAD_CLIP_NORM, |
|
accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES, |
|
log_every_n_steps=50, |
|
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
logging.info( |
|
f"Starting training with Effective Batch Size: {hparams['effective_batch_size']}..." |
|
) |
|
start_time = time.time() |
|
try: |
|
trainer.fit(model, train_dataloader, val_dataloader) |
|
training_duration = time.time() - start_time |
|
logging.info( |
|
f"Training finished in {training_duration / 3600:.2f} hours ({training_duration:.2f} seconds)." |
|
) |
|
|
|
|
|
best_path = checkpoint_callback.best_model_path |
|
best_score = checkpoint_callback.best_model_score |
|
if best_score is not None: |
|
logging.info( |
|
f"Best model checkpoint saved at: {best_path} with val_loss: {best_score.item():.4f}" |
|
) |
|
|
|
wandb_logger.experiment.summary["best_val_loss"] = best_score.item() |
|
wandb_logger.experiment.summary["best_model_path"] = best_path |
|
else: |
|
logging.warning( |
|
"Could not retrieve best model score from checkpoint callback." |
|
) |
|
|
|
except Exception as e: |
|
logging.error(f"Fatal error during training: {e}", exc_info=True) |
|
|
|
if wandb.run is not None: |
|
wandb.finish(exit_code=1) |
|
exit(1) |
|
|
|
|
|
best_model_path_to_load = checkpoint_callback.best_model_path |
|
logging.info( |
|
f"\nLoading best model from {best_model_path_to_load} for translation examples..." |
|
) |
|
final_model = None |
|
if best_model_path_to_load and os.path.exists(best_model_path_to_load): |
|
try: |
|
|
|
|
|
final_model = SmilesIupacLitModule.load_from_checkpoint( |
|
best_model_path_to_load, |
|
|
|
|
|
src_vocab_size=ACTUAL_SRC_VOCAB_SIZE, |
|
tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE, |
|
hparams_dict=hparams, |
|
) |
|
|
|
            inference_device = torch.device(
                "cuda:0"  # Lightning's "gpu" accelerator maps to CUDA for torch.device
                if ACCELERATOR == "gpu" and torch.cuda.is_available()
                else "cpu"
            )
|
final_model = final_model.to(inference_device) |
|
final_model.eval() |
|
final_model.freeze() |
|
logging.info( |
|
f"Best model loaded successfully to {inference_device} for final translation." |
|
) |
|
except Exception as e: |
|
logging.error( |
|
f"Error loading saved model from {best_model_path_to_load}: {e}", |
|
exc_info=True, |
|
) |
|
final_model = None |
|
else: |
|
logging.error( |
|
f"Error: Best model checkpoint path not found or invalid: '{best_model_path_to_load}'. Cannot perform final translation." |
|
) |
|
|
|
|
|
if final_model: |
|
logging.info("\n--- Example Translations (using validation data) ---") |
|
num_examples = 20 |
|
try: |
|
|
|
val_smi_examples = [] |
|
val_iupac_examples = [] |
|
if os.path.exists(VAL_SMILES_FILE) and os.path.exists(VAL_IUPAC_FILE): |
|
with ( |
|
open(VAL_SMILES_FILE, "r", encoding="utf-8") as f_smi, |
|
open(VAL_IUPAC_FILE, "r", encoding="utf-8") as f_iupac, |
|
): |
|
for i, (smi_line, iupac_line) in enumerate(zip(f_smi, f_iupac)): |
|
if i >= num_examples: |
|
break |
|
val_smi_examples.append(smi_line.strip()) |
|
val_iupac_examples.append(iupac_line.strip()) |
|
else: |
|
logging.warning( |
|
f"Validation files ({VAL_SMILES_FILE}, {VAL_IUPAC_FILE}) not found. Cannot show examples." |
|
) |
|
|
|
if len(val_smi_examples) > 0: |
|
print("\n" + "=" * 40) |
|
print( |
|
f"Example Translations (First {len(val_smi_examples)} Validation Samples)" |
|
) |
|
print("=" * 40) |
|
|
|
inference_device = next(final_model.parameters()).device |
|
translation_examples = [] |
|
for i in range(len(val_smi_examples)): |
|
smi = val_smi_examples[i] |
|
true_iupac = val_iupac_examples[i] |
|
predicted_iupac = translate( |
|
model=final_model, |
|
src_sentence=smi, |
|
smiles_tokenizer=smiles_tokenizer, |
|
iupac_tokenizer=iupac_tokenizer, |
|
device=inference_device, |
|
max_len=hparams["max_len"], |
|
sos_idx=SOS_IDX, |
|
eos_idx=EOS_IDX, |
|
pad_idx=PAD_IDX, |
|
) |
|
print(f"\nExample {i + 1}:") |
|
print(f" SMILES: {smi}") |
|
print(f" True IUPAC: {true_iupac}") |
|
print(f" Predicted IUPAC: {predicted_iupac}") |
|
print("-" * 30) |
|
|
|
translation_examples.append([smi, true_iupac, predicted_iupac]) |
|
|
|
print("=" * 40 + "\n") |
|
|
|
|
|
try: |
|
columns = ["SMILES", "True IUPAC", "Predicted IUPAC"] |
|
wandb_table = wandb.Table( |
|
data=translation_examples, columns=columns |
|
) |
|
wandb_logger.experiment.log( |
|
{"validation_translations": wandb_table} |
|
) |
|
logging.info("Logged translation examples to WandB Table.") |
|
except Exception as wb_err: |
|
logging.error( |
|
f"Failed to log translation examples to WandB: {wb_err}" |
|
) |
|
|
|
else: |
|
logging.warning("Could not load validation samples for examples.") |
|
except Exception as e: |
|
logging.error(f"Error during example translation phase: {e}", exc_info=True) |
|
else: |
|
logging.warning( |
|
"Skipping final translation examples as the best model could not be loaded." |
|
) |
|
|
|
|
|
if wandb.run is not None: |
|
wandb.finish() |
|
logging.info("WandB run finished.") |
|
else: |
|
logging.info("No active WandB run to finish.") |
|
|
|
logging.info("Script finished.") |
|
|