# smi2iupac / app.py
# Uploaded by AdrianM0 using huggingface_hub (commit 59543a5, verified; 23.7 kB).
# app.py
import gradio as gr
import torch
import torch.nn.functional as F # <--- Added import
import pytorch_lightning as pl # <--- Added import (needed for type hints, model access)
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc # For garbage collection on potential OOM
import math # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)
# --- Configuration ---
# Hugging Face Hub repository that hosts every artifact fetched at startup.
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
# Artifact filenames inside MODEL_REPO_ID.
CHECKPOINT_FILENAME = "last.ckpt"
CONFIG_FILENAME = "config.json"
SMILES_TOKENIZER_FILENAME, IUPAC_TOKENIZER_FILENAME = (
    "smiles_bytelevel_bpe_tokenizer_scaled.json",
    "iupac_unigram_tokenizer_scaled.json",
)
# --- End Configuration ---
# --- Logging ---
# Root-logger configuration: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# --- Load Helper Code (Only Model Definition Needed) ---
# Only the LightningModule definition and the causal-mask helper come from
# enhanced_trainer.py; beam_search_decode and translate are defined below in
# this file. The app cannot run without them, so any import failure is logged
# and the process terminates with a non-zero exit status (SystemExit is a
# builtin, unlike the site-provided exit(), and exit() with no argument would
# report success).
try:
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    raise SystemExit(1) from e
except Exception as e:
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    raise SystemExit(1) from e
else:
    logging.info("Successfully imported from enhanced_trainer.py.")
# --- Global State (populated once by load_model_and_tokenizers) ---
# Every entry stays None until loading succeeds; predict_iupac checks them
# before running inference.
config: dict | None = None  # Parsed config.json contents
device: torch.device | None = None  # Device the model runs on
iupac_tokenizer: Tokenizer | None = None  # Target-side (IUPAC) tokenizer
smiles_tokenizer: Tokenizer | None = None  # Source-side (SMILES) tokenizer
model: pl.LightningModule | None = None  # Loaded LightningModule in eval mode
# --- Beam Search Decoding ---
def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,  # Accepted for interface compatibility; not used here
    device: torch.device,
    beam_width: int = 5,
    n_best: int = 5,  # Number of top sequences to return
    length_penalty: float = 0.6,  # Alpha for length normalization (0=no penalty)
) -> list[torch.Tensor]:
    """
    Decode a single source sequence with beam search.

    Args:
        model: LightningModule whose ``.model`` attribute exposes ``encode``,
            ``decode`` and ``generator`` (presumably the Seq2SeqTransformer
            from enhanced_trainer — confirm there).
        src: Source token ids, shape [1, src_len].
        src_padding_mask: Boolean mask for ``src`` (True at padding), [1, src_len].
        max_len: Maximum hypothesis length, counting the leading SOS token.
        sos_idx: Start-of-sequence id used to seed decoding.
        eos_idx: End-of-sequence id; a hypothesis ending in it is finished.
        pad_idx: Unused by the search; kept so callers need not change.
        device: Device on which new tensors are created.
        beam_width: Number of live hypotheses kept after each step.
        n_best: Number of hypotheses returned; clamped to ``beam_width``.
        length_penalty: Alpha in ``score / len ** alpha`` used for final
            ranking; 0 ranks by raw summed log-probability.

    Returns:
        Up to ``n_best`` tensors of shape [1, len] with the SOS token removed,
        best first. An empty list is returned if decoding raised an error.
    """
    # Ensure model is in eval mode (redundant if called after model.eval(), but safe)
    model.eval()
    transformer_model = model.model  # Access the underlying Seq2SeqTransformer
    n_best = min(n_best, beam_width)  # Cannot return more than beam_width sequences

    def normalized_score(beam):
        """Ranking key: summed log-prob, divided by len**length_penalty if enabled."""
        seq, score = beam
        seq_len = seq.shape[1]
        if length_penalty == 0.0 or seq_len <= 1:
            return score.item()
        return score.item() / (float(seq_len) ** length_penalty)

    try:
        with torch.no_grad():
            # --- Encode the source once; every hypothesis attends to it ---
            memory = transformer_model.encode(src, src_padding_mask).to(
                device
            )  # [1, src_len, emb_size]
            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]

            # --- Start from a single hypothesis holding only SOS ---
            initial_seq = torch.full(
                (1, 1), sos_idx, dtype=torch.long, device=device
            )  # [1, 1]
            initial_score = torch.zeros(1, dtype=torch.float, device=device)  # [1]
            active_beams = [(initial_seq, initial_score)]
            finished_beams = []

            # --- Decoding loop: extend every live hypothesis by one token ---
            for _ in range(max_len - 1):
                if not active_beams:
                    break
                candidates = []
                for current_seq, current_score in active_beams:
                    if current_seq[0, -1].item() == eos_idx:
                        finished_beams.append((current_seq, current_score))
                        continue
                    tgt_mask = generate_square_subsequent_mask(
                        current_seq.shape[1], device
                    ).to(device)  # [curr_len, curr_len]
                    tgt_padding_mask = torch.zeros(
                        current_seq.shape, dtype=torch.bool, device=device
                    )  # [1, curr_len]
                    decoder_output = transformer_model.decode(
                        tgt=current_seq,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )  # [1, curr_len, emb_size]
                    next_token_logits = transformer_model.generator(
                        decoder_output[:, -1, :]
                    )  # [1, tgt_vocab_size]
                    log_probs = F.log_softmax(next_token_logits, dim=-1)
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(beam_width, log_probs.shape[-1])
                    topk_log_probs, topk_indices = torch.topk(
                        log_probs + current_score, k, dim=-1
                    )
                    for i in range(k):
                        next_token_tensor = torch.tensor(
                            [[topk_indices[0, i].item()]], dtype=torch.long, device=device
                        )  # [1, 1]
                        next_score = topk_log_probs[0, i].reshape(1)  # Keep as tensor [1]
                        candidates.append(
                            (torch.cat([current_seq, next_token_tensor], dim=1), next_score)
                        )

                # --- Prune: every finished candidate above the cut-off is kept;
                # at most beam_width unfinished candidates stay live ---
                candidates.sort(key=lambda beam: beam[1].item(), reverse=True)
                active_beams = []
                for seq, score in candidates:
                    if seq[0, -1].item() == eos_idx:
                        finished_beams.append((seq, score))
                    elif len(active_beams) < beam_width:
                        active_beams.append((seq, score))
                    else:
                        break

            # Hypotheses still live when max_len is reached are ranked too.
            finished_beams.extend(active_beams)
            finished_beams.sort(key=normalized_score, reverse=True)  # Higher is better
            # Strip the leading SOS token: [1, len] -> [1, len - 1]
            return [seq[:, 1:] for seq, _ in finished_beams[:n_best]]
    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
        if "CUDA out of memory" in str(e):
            gc.collect()
            torch.cuda.empty_cache()
        return []  # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []
# --- Translation Function ---
def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[str]:
    """
    Translate a single SMILES string into candidate IUPAC names via beam search.

    Args:
        model: LightningModule passed through to ``beam_search_decode``.
        src_sentence: SMILES string to translate.
        smiles_tokenizer: Source tokenizer; its ids are truncated to ``max_len``.
        iupac_tokenizer: Target tokenizer used to decode generated ids
            (special tokens skipped).
        device: Device on which the source tensors are created.
        max_len: Maximum source length and maximum decoding length.
        sos_idx, eos_idx: Target start/end-of-sequence ids.
        pad_idx: Padding id; source positions equal to it are masked.
        beam_width, n_best, length_penalty: Beam search settings. ``n_best``
            is clamped to ``beam_width``, as ``beam_search_decode`` does.

    Returns:
        ``min(n_best, beam_width)`` strings, best first. Failures are reported
        in-band as bracketed placeholders (e.g. ``"[Encoding Error]"``)
        rather than raised.
    """
    model.eval()  # Ensure model is in eval mode
    # beam_search_decode never yields more than beam_width hypotheses. Clamping
    # here keeps the result from being padded with "Fewer Results" errors that
    # are caused purely by the parameter choice.
    n_best = min(n_best, beam_width)

    # --- Tokenize source ---
    try:
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return ["[Encoding Error]"] * n_best
        src_ids = src_encoded.ids[:max_len]  # Truncate source
        if not src_ids:
            logging.warning(f"Source empty after truncation: {src_sentence}")
            return ["[Encoding Error - Empty Src]"] * n_best
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
        return ["[Encoding Error]"] * n_best

    # --- Prepare input tensor and mask (both created on `device`) ---
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, src_len]
    src_padding_mask = src == pad_idx  # [1, src_len], True at padding positions

    # --- Beam search ---
    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=max_len,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
        device=device,
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty,
    )  # list of [1, len] tensors; empty on failure
    if not tgt_tokens_list:
        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
        return ["[Decoding Error - Empty Output]"] * n_best

    # --- Decode generated tokens to text ---
    translations = []
    for tgt_tokens_tensor in tgt_tokens_list:
        if tgt_tokens_tensor.numel() == 0:
            translations.append("[Decoding Error - Empty Tensor]")
            continue
        # Tensor.tolist() yields plain Python ints (copying to host if needed).
        tgt_tokens = tgt_tokens_tensor.flatten().tolist()
        try:
            translations.append(
                iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
            )
        except Exception as e:
            logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
            translations.append("[Decoding Error]")

    # Pad with placeholders if beam search produced fewer than n_best results.
    translations.extend(
        ["[Decoding Error - Fewer Results]"] * (n_best - len(translations))
    )
    return translations
# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
    """Load tokenizers, config, and model from the Hugging Face Hub.

    Downloads the checkpoint, both tokenizer files and the config from
    ``MODEL_REPO_ID``, validates required config keys, warns if tokenizer
    special-token ids disagree with the config, and stores the results in the
    module-level ``model``, ``smiles_tokenizer``, ``iupac_tokenizer``,
    ``device`` and ``config``. Does nothing if ``model`` is already set.

    Raises:
        gr.Error: if any download, parse, validation, or load step fails.
            The triggering exception is chained as ``__cause__`` so its
            traceback is preserved.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:  # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return
    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        device = torch.device("cpu")
        logging.info(f"Using device: {device}")

        # --- Download files from HF Hub ---
        logging.info("Downloading files from Hugging Face Hub...")
        try:
            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME
            )
            smiles_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME
            )
            iupac_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME
            )
            config_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME
            )
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(
                f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
            ) from e

        # --- Load and validate config ---
        logging.info("Loading configuration...")
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            logging.info("Configuration loaded.")
            required_keys = [
                "src_vocab_size",
                "tgt_vocab_size",
                "emb_size",
                "nhead",
                "ffn_hid_dim",
                "num_encoder_layers",
                "num_decoder_layers",
                "dropout",
                "max_len",
                "bos_token_id",
                "eos_token_id",
                "pad_token_id",
            ]
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ValueError(
                    f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}"
                )
        except FileNotFoundError as e:
            logging.error(
                f"Config file not found locally after download attempt: {config_path}"
            )
            raise gr.Error(
                f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
            ) from e
        except json.JSONDecodeError as e:
            # Must precede `except ValueError`: JSONDecodeError subclasses it.
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(
                f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
            ) from e
        except ValueError as e:
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}") from e

        # --- Load tokenizers and sanity-check special-token ids ---
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")
            # Mismatches only warn: decoding uses the ids from config, so a
            # mismatch degrades output rather than preventing startup.
            if (
                smiles_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
                or smiles_tokenizer.token_to_id("<unk>") is None
            ):
                logging.warning(
                    "SMILES tokenizer special tokens might not match config or are missing."
                )
            if (
                iupac_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
                or iupac_tokenizer.token_to_id("<sos>") != config["bos_token_id"]
                or iupac_tokenizer.token_to_id("<eos>") != config["eos_token_id"]
                or iupac_tokenizer.token_to_id("<unk>") is None
            ):
                logging.warning(
                    "IUPAC tokenizer special tokens might not match config or are missing."
                )
        except Exception as e:
            logging.error(
                f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}"
            ) from e

        # --- Load model ---
        logging.info("Loading model from checkpoint...")
        try:
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                src_vocab_size=config["src_vocab_size"],
                tgt_vocab_size=config["tgt_vocab_size"],
                map_location=device,
                hparams_dict=config,
                strict=False,
                device="cpu",
            )
            model.to(device)
            model.eval()
            model.freeze()
            logging.info(
                "Model loaded successfully, set to eval mode, frozen, and moved to device."
            )
        except FileNotFoundError as e:
            logging.error(
                f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
            )
            raise gr.Error(
                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
            ) from e
        except Exception as e:
            logging.error(
                f"Error loading model from checkpoint {checkpoint_path}: {e}",
                exc_info=True,
            )
            if "memory" in str(e).lower():
                gc.collect()
                # Compare the device *type*: torch.device("cuda:0") is not
                # equal to torch.device("cuda").
                if device.type == "cuda":
                    torch.cuda.empty_cache()
            raise gr.Error(
                f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
            ) from e
    except gr.Error:
        raise  # Already logged and user-facing; do not re-wrap
    except Exception as e:
        logging.error(
            f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
        )
        raise gr.Error(
            f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
        ) from e
# --- Inference Function for Gradio ---
def _numbered_lines(message: str, count: int) -> str:
    """Return ``message`` repeated on ``count`` lines numbered "1. ", "2. ", ..."""
    return "\n".join(f"{i}. {message}" for i in range(1, count + 1))


def _int_or_default(value, default: int = 1) -> int:
    """Return ``int(value)``, or ``default`` when the conversion fails."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
    """
    Translate a SMILES string to IUPAC name(s) for the Gradio interface.

    Args:
        smiles_string: SMILES text entered by the user (whitespace-stripped).
        beam_width: Beam width; converted with ``int``.
        n_best: Number of names requested; converted with ``int``.
        length_penalty: Length-normalization alpha; converted with ``float``.

    Returns:
        A display string. On success: a header echoing the input and settings,
        followed by numbered predictions. Otherwise numbered error lines when
        the model is not loaded, the input is blank, or translation raises;
        a single error line when a parameter cannot be converted.
    """
    # Read-only access to the module-level state set by load_model_and_tokenizers.
    if any(
        obj is None
        for obj in (model, smiles_tokenizer, iupac_tokenizer, device, config)
    ):
        return _numbered_lines(
            "Error: Model or tokenizers not loaded properly. Check Space logs.",
            _int_or_default(n_best),
        )
    if not smiles_string or not smiles_string.strip():
        return _numbered_lines(
            "Error: Please enter a valid SMILES string.", _int_or_default(n_best)
        )
    smiles_input = smiles_string.strip()
    try:
        beam_width = int(beam_width)
        n_best = int(n_best)
        length_penalty = float(length_penalty)
    except (TypeError, ValueError) as e:
        # n_best may itself be the invalid value, so report a single line.
        return f"1. Error: Invalid input parameter type ({e})."
    logging.info(
        f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})"
    )
    try:
        predicted_names = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=config["max_len"],
            sos_idx=config["bos_token_id"],
            eos_idx=config["eos_token_id"],
            pad_idx=config["pad_token_id"],
            beam_width=beam_width,
            n_best=n_best,
            length_penalty=length_penalty,
        )
        logging.info(f"Predictions returned: {predicted_names}")
        if not predicted_names:
            return f"Input SMILES: {smiles_input}\n\nNo predictions generated."
        header = (
            f"Input SMILES: {smiles_input}\n\n"
            f"Top {len(predicted_names)} Predictions (Beam Width={beam_width}, "
            f"Length Penalty={length_penalty:.2f}):\n"
        )
        return header + "\n".join(
            f"{i}. {name}" for i, name in enumerate(predicted_names, start=1)
        )
    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            gc.collect()
            # Compare the device *type*: torch.device("cuda:0") != torch.device("cuda").
            if device.type == "cuda":
                torch.cuda.empty_cache()
            error_msg += " (Potential OOM)"
        return _numbered_lines(error_msg, n_best)
    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        return _numbered_lines(f"Unexpected Error during translation: {e}", n_best)
# --- Load Model on App Start ---
# A loading failure is deliberately not fatal: the interface still starts and
# predict_iupac reports that the model is not loaded.
try:
    load_model_and_tokenizers()
except gr.Error as e:
    # load_model_and_tokenizers() already logged the root cause with details.
    logging.warning(f"Model loading failed; the app will start without a model: {e}")
except Exception as e:
    logging.error(
        f"Critical error during initial model loading sequence: {e}", exc_info=True
    )
# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
Model repository: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>{MODEL_REPO_ID}</a>.
Adjust beam search parameters below. Higher beam width explores more possibilities but is slower. Length penalty influences the preference for shorter/longer names.
"""

# Example rows follow the input order: [SMILES, beam width, n_best, length penalty].
# All examples share beam width 5 and length penalty 0.6; only n_best varies.
examples = [
    [example_smiles, 5, example_n_best, 0.6]
    for example_smiles, example_n_best in (
        ("CCO", 3),
        ("C1=CC=CC=C1", 3),
        ("CC(=O)Oc1ccccc1C(=O)O", 3),  # Aspirin
        ("CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 3),  # Ibuprofen
        (
            "CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C",
            1,
        ),  # Complex example
        ("INVALID_SMILES", 1),  # Deliberately invalid input
    )
]

# Input components, in the positional order predict_iupac receives them.
smiles_input = gr.Textbox(
    lines=1,
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
)
beam_width_input = gr.Slider(
    label="Beam Width (k)",
    info="Number of sequences to keep at each decoding step (higher = more exploration, slower).",
    minimum=1,
    maximum=10,
    step=1,
    value=5,
)
n_best_input = gr.Slider(
    label="Number of Results (n_best)",
    info="How many top-scoring sequences to return (must be <= Beam Width).",
    minimum=1,
    maximum=10,
    step=1,
    value=3,
)
length_penalty_input = gr.Slider(
    label="Length Penalty (alpha)",
    info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty.",
    minimum=0.0,
    maximum=2.0,
    step=0.1,
    value=0.6,
)

# Output component.
output_text = gr.Textbox(
    show_copy_button=True,
    label="Predicted IUPAC Name(s)",
    lines=5,
)

iface = gr.Interface(
    fn=predict_iupac,
    inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
    outputs=output_text,
    examples=examples,
    title=title,
    description=description,
    article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results.",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
    allow_flagging="never",
)
# --- Launch the App ---
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    # share=True requests a public Gradio share link.
    # NOTE(review): confirm whether this is honored when running on HF Spaces.
    iface.launch(share=True)