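"""Gradio app: translate SMILES strings into IUPAC names.

Downloads a trained Transformer checkpoint, its tokenizers, and its config from
the Hugging Face Hub, then serves beam-search decoding behind a simple UI.
"""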
import gc
import json
import logging
import math
import os

import gradio as gr
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# Artifacts downloaded from the Hugging Face model repo.
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# The model class and mask helper live in enhanced_trainer.py, which must sit in
# the root of the repo alongside this app.
try:
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    # gr.Error is an exception meant to be raised inside a Gradio event handler;
    # at import time we can only log and abort.
    exit(1)
except Exception as e:
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    exit(1)

# Globals populated once by load_model_and_tokenizers() at startup.
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None
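
# Beam search: keep `beam_width` partial hypotheses per step, extend each with its
# top-`beam_width` next tokens by cumulative log-probability, then re-rank finished
# hypotheses with GNMT-style length normalization:
#     score(seq) = sum_t log P(token_t) / len(seq) ** length_penalty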
def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    device: torch.device,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[torch.Tensor]:
    """
    Performs beam search decoding using the LightningModule's model.
    (Adapted from test_ckpt.py.)
    """
    model.eval()
    transformer_model = model.model
    n_best = min(n_best, beam_width)

    try:
        with torch.no_grad():
            # Encode the source once; all beams share the same memory.
            memory = transformer_model.encode(src, src_padding_mask)
            memory = memory.to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)

            # Each beam is a (sequence, cumulative log-prob) pair, seeded with <sos>.
            initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx)
            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device)
            active_beams = [(initial_beam_seq, initial_beam_score)]
            finished_beams = []

            for _ in range(max_len - 1):
                if not active_beams:
                    break

                potential_next_beams = []
                for current_seq, current_score in active_beams:
                    # Beams that already ended with <eos> move to the finished pool.
                    if current_seq[0, -1].item() == eos_idx:
                        finished_beams.append((current_seq, current_score))
                        continue

                    tgt_input = current_seq
                    tgt_seq_len = tgt_input.shape[1]
                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device)
                    tgt_padding_mask = torch.zeros(
                        tgt_input.shape, dtype=torch.bool, device=device
                    )

                    decoder_output = transformer_model.decode(
                        tgt=tgt_input,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )

                    # Only the final position is needed to score the next token.
                    next_token_logits = transformer_model.generator(decoder_output[:, -1, :])
                    log_probs = F.log_softmax(next_token_logits, dim=-1)

                    # Adding the running beam score ranks candidates by cumulative log-prob.
                    topk_log_probs, topk_indices = torch.topk(
                        log_probs + current_score, beam_width, dim=-1
                    )

                    for i in range(beam_width):
                        next_token_id = topk_indices[0, i].item()
                        next_score = topk_log_probs[0, i].reshape(1)
                        next_token_tensor = torch.tensor(
                            [[next_token_id]], dtype=torch.long, device=device
                        )
                        new_seq = torch.cat([current_seq, next_token_tensor], dim=1)
                        potential_next_beams.append((new_seq, next_score))

                # Keep only the best `beam_width` candidates for the next step.
                potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)

                active_beams = []
                added_count = 0
                for seq, score in potential_next_beams:
                    if seq[0, -1].item() == eos_idx:
                        finished_beams.append((seq, score))
                    elif added_count < beam_width:
                        active_beams.append((seq, score))
                        added_count += 1
                    else:
                        break

        # Beams still active at max_len count as finished.
        finished_beams.extend(active_beams)

        def get_score(beam_tuple):
            seq, score = beam_tuple
            seq_len = seq.shape[1]
            if length_penalty == 0.0 or seq_len <= 1:
                return score.item()
            # GNMT-style length normalization.
            return score.item() / (float(seq_len) ** length_penalty)

        finished_beams.sort(key=get_score, reverse=True)

        # Strip the leading <sos> token from each returned sequence.
        top_sequences = [seq[:, 1:] for seq, _ in finished_beams[:n_best]]
        return top_sequences

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}")
        if "CUDA out of memory" in str(e):
            gc.collect()
            torch.cuda.empty_cache()
        return []
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []

def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[str]:
    """
    Translates a single SMILES string using beam search.
    (Adapted from test_ckpt.py.)
    """
    model.eval()
    translations = []

    # Tokenize the source SMILES, truncating to the model's maximum length.
    try:
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return ["[Encoding Error]"] * n_best
        src_ids = src_encoded.ids[:max_len]
        if not src_ids:
            logging.warning(f"Source empty after truncation: {src_sentence}")
            return ["[Encoding Error - Empty Src]"] * n_best
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
        return ["[Encoding Error]"] * n_best

    # Shape (1, src_len); the padding mask marks positions equal to <pad>.
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    src_padding_mask = (src == pad_idx).to(device)

    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=max_len,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
        device=device,
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty,
    )

    if not tgt_tokens_list:
        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
        return ["[Decoding Error - Empty Output]"] * n_best

    # Decode each hypothesis back into an IUPAC name string.
    for tgt_tokens_tensor in tgt_tokens_list:
        if tgt_tokens_tensor.numel() > 0:
            tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
            try:
                translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
                translations.append(translation)
            except Exception as e:
                logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
                translations.append("[Decoding Error]")
        else:
            translations.append("[Decoding Error - Empty Tensor]")

    # Pad the list so callers always receive n_best entries.
    while len(translations) < n_best:
        translations.append("[Decoding Error - Fewer Results]")

    return translations
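
# Illustrative direct use of translate() (assumes the globals below are loaded):
#   names = translate(model, "CCO", smiles_tokenizer, iupac_tokenizer, device,
#                     max_len=config["max_len"], sos_idx=config["bos_token_id"],
#                     eos_idx=config["eos_token_id"], pad_idx=config["pad_token_id"])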

def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        # Inference runs on CPU.
        device = torch.device("cpu")
        logging.info(f"Using device: {device}")

        logging.info("Downloading files from Hugging Face Hub...")
        try:
            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME
            )
            smiles_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME
            )
            iupac_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME
            )
            config_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME
            )
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(
                f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
            )

        logging.info("Loading configuration...")
        try:
            with open(config_path, "r") as f:
                config = json.load(f)
            logging.info("Configuration loaded.")

            # Fail fast if the config is missing anything the model or decoding needs.
            required_keys = [
                "src_vocab_size",
                "tgt_vocab_size",
                "emb_size",
                "nhead",
                "ffn_hid_dim",
                "num_encoder_layers",
                "num_decoder_layers",
                "dropout",
                "max_len",
                "bos_token_id",
                "eos_token_id",
                "pad_token_id",
            ]
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ValueError(
                    f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}"
                )

        except FileNotFoundError:
            logging.error(
                f"Config file not found locally after download attempt: {config_path}"
            )
            raise gr.Error(
                f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check that the file exists in the repo."
            )
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(
                f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
            )
        except ValueError as e:
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}")

        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")

            # Sanity-check that the special-token ids match the config.
            if (
                smiles_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
                or smiles_tokenizer.token_to_id("<unk>") is None
            ):
                logging.warning(
                    "SMILES tokenizer special tokens might not match config or are missing."
                )
            if (
                iupac_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
                or iupac_tokenizer.token_to_id("<sos>") != config["bos_token_id"]
                or iupac_tokenizer.token_to_id("<eos>") != config["eos_token_id"]
                or iupac_tokenizer.token_to_id("<unk>") is None
            ):
                logging.warning(
                    "IUPAC tokenizer special tokens might not match config or are missing."
                )

        except Exception as e:
            logging.error(
                f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}"
            )

        logging.info("Loading model from checkpoint...")
        try:
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                src_vocab_size=config["src_vocab_size"],
                tgt_vocab_size=config["tgt_vocab_size"],
                map_location=device,
                hparams_dict=config,
                strict=False,
                device="cpu",
            )
            model.to(device)
            model.eval()
            model.freeze()
            logging.info(
                "Model loaded successfully, set to eval mode, frozen, and moved to device."
            )

        except FileNotFoundError:
            logging.error(
                f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
            )
            raise gr.Error(
                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
            )
        except Exception as e:
            logging.error(
                f"Error loading model from checkpoint {checkpoint_path}: {e}",
                exc_info=True,
            )
            if "memory" in str(e).lower():
                gc.collect()
                if device == torch.device("cuda"):
                    torch.cuda.empty_cache()
            raise gr.Error(
                f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
            )

    except gr.Error:
        raise
    except Exception as e:
        logging.error(
            f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
        )
        raise gr.Error(
            f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
        )
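
# predict_iupac() is the Gradio callback: it validates inputs, runs translate(),
# and formats the n_best hypotheses as a numbered list.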

def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
    """
    Performs SMILES-to-IUPAC translation using the loaded model and beam search.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
        try:
            n_best_int = int(n_best)
        except (TypeError, ValueError):
            n_best_int = 1
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        try:
            n_best_int = int(n_best)
        except (TypeError, ValueError):
            n_best_int = 1
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])

    smiles_input = smiles_string.strip()
    try:
        beam_width = int(beam_width)
        n_best = int(n_best)
        length_penalty = float(length_penalty)
    except ValueError as e:
        error_msg = f"Error: Invalid input parameter type ({e})."
        return f"1. {error_msg}"

    logging.info(
        f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})"
    )

    try:
        predicted_names = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=config["max_len"],
            sos_idx=config["bos_token_id"],
            eos_idx=config["eos_token_id"],
            pad_idx=config["pad_token_id"],
            beam_width=beam_width,
            n_best=n_best,
            length_penalty=length_penalty,
        )
        logging.info(f"Predictions returned: {predicted_names}")

        if not predicted_names:
            output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
        else:
            output_text = (
                f"Input SMILES: {smiles_input}\n\n"
                f"Top {len(predicted_names)} Predictions "
                f"(Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
            )
            output_text += "\n".join(
                [f"{i + 1}. {name}" for i, name in enumerate(predicted_names)]
            )

        return output_text

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            gc.collect()
            if device == torch.device("cuda"):
                torch.cuda.empty_cache()
            error_msg += " (Potential OOM)"
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])

    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        error_msg = f"Unexpected Error during translation: {e}"
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])

# Load everything once at import time so the Space is ready before the UI starts.
try:
    load_model_and_tokenizers()
except gr.Error:
    # Already logged; predict_iupac() will report the failure to users.
    pass
except Exception as e:
    logging.error(
        f"Critical error during initial model loading sequence: {e}", exc_info=True
    )

title = "SMILES to IUPAC Name Translator"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
Model repository: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>{MODEL_REPO_ID}</a>.
Adjust the beam search parameters below. A higher beam width explores more candidates but is slower; the length penalty controls the preference for shorter or longer names.
"""

examples = [
    ["CCO", 5, 3, 0.6],  # ethanol
    ["C1=CC=CC=C1", 5, 3, 0.6],  # benzene
    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6],  # aspirin
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6],  # ibuprofen
    [
        "CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C",
        5,
        1,
        0.6,
    ],
    ["INVALID_SMILES", 5, 1, 0.6],
]

smiles_input = gr.Textbox(
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for ethanol)",
    lines=1,
)
beam_width_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=5,
    step=1,
    label="Beam Width (k)",
    info="Number of sequences kept at each decoding step (higher = more exploration, slower).",
)
n_best_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Number of Results (n_best)",
    info="How many top-scoring sequences to return (capped at the beam width).",
)
length_penalty_input = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.6,
    step=0.1,
    label="Length Penalty (alpha)",
    info="Length-normalization exponent: 0 disables the penalty; higher values favour longer names.",
)
output_text = gr.Textbox(
    label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
)

iface = gr.Interface(
    fn=predict_iupac,
    inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
    outputs=output_text,
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
    article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results.",
)


if __name__ == "__main__":
    iface.launch(share=True)