# app.py
import gradio as gr
import torch
# import torch.nn.functional as F # No longer needed for greedy decode directly
import pytorch_lightning as pl
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc
from rdkit.Chem import CanonSmiles
# --- Configuration ---
MODEL_REPO_ID = (
"AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
)
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"
# --- End Configuration ---
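# For reference, config.json is expected to provide at least the hyperparameter keys
# validated in load_model_and_tokenizers() below. The values here are illustrative
# placeholders only, not the real model's settings:
#
#   {
#     "src_vocab_size": 1000,
#     "tgt_vocab_size": 1000,
#     "emb_size": 256,
#     "nhead": 8,
#     "ffn_hid_dim": 512,
#     "num_encoder_layers": 4,
#     "num_decoder_layers": 4,
#     "dropout": 0.1,
#     "max_len": 256,
#     "pad_token_id": 0,
#     "bos_token_id": 1,
#     "eos_token_id": 2
#   }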
# --- Logging ---
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
try:
from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
logging.error(
f"Failed to import helper code from enhanced_trainer.py: {e}. "
f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
)
raise gr.Error(
f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
)
except Exception as e:
logging.error(
f"An unexpected error occurred during helper code import: {e}", exc_info=True
)
raise gr.Error(
f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
)
# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None
# --- Greedy Decoding Logic (Locally defined) ---
def greedy_decode(
model: pl.LightningModule,
src: torch.Tensor,
src_padding_mask: torch.Tensor,
max_len: int,
sos_idx: int,
eos_idx: int,
device: torch.device,
) -> torch.Tensor:
"""
Performs greedy decoding using the LightningModule's model.
"""
model.eval() # Ensure model is in evaluation mode
transformer_model = model.model # Access the underlying Seq2SeqTransformer
try:
with torch.no_grad():
# --- Encode Source ---
memory = transformer_model.encode(
src, src_padding_mask
) # [1, src_len, emb_size]
memory = memory.to(device)
memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
# --- Initialize Target Sequence ---
# Start with the SOS token
ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
sos_idx
) # [1, 1]
# --- Decoding Loop ---
for _ in range(max_len - 1): # Max length limit
tgt_seq_len = ys.shape[1]
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
device
) # [curr_len, curr_len]
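                # The causal (square subsequent) mask prevents position i from attending
                # to positions > i, so each decoding step only sees tokens generated so far.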
# No padding in target during generation yet
tgt_padding_mask = torch.zeros(
ys.shape, dtype=torch.bool, device=device
) # [1, curr_len]
# Decode one step
decoder_output = transformer_model.decode(
tgt=ys,
memory=memory,
tgt_mask=tgt_mask,
tgt_padding_mask=tgt_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # [1, curr_len, emb_size]
# Get logits for the *next* token prediction
next_token_logits = transformer_model.generator(
decoder_output[
:, -1, :
] # Use output corresponding to the last input token
) # [1, tgt_vocab_size]
# Find the most likely next token (greedy choice)
# prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
# _, next_word_id_tensor = torch.max(prob, dim=1)
next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # [1]
next_word_id = next_word_id_tensor.item()
# Append the chosen token to the sequence
ys = torch.cat(
[
ys,
torch.ones(1, 1, dtype=torch.long, device=device).fill_(
next_word_id
),
],
dim=1,
) # [1, current_len + 1]
# Stop if EOS token is generated
if next_word_id == eos_idx:
break
# Return the generated sequence (excluding the initial SOS token)
return ys[:, 1:] # Shape [1, generated_len]
except RuntimeError as e:
logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
if "CUDA out of memory" in str(e) and device.type == "cuda":
gc.collect()
torch.cuda.empty_cache()
return torch.empty(
(1, 0), dtype=torch.long, device=device
) # Return empty tensor on error
except Exception as e:
logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
return torch.empty((1, 0), dtype=torch.long, device=device)
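# Illustrative sketch (not called by the app): driving greedy_decode directly on a
# single pre-tokenized sequence. The token IDs below are made-up placeholders; real
# inputs come from the SMILES tokenizer inside translate() further down. Assumes the
# module-level globals have been populated by load_model_and_tokenizers().
def _example_greedy_decode_call() -> torch.Tensor:
    assert model is not None and config is not None and device is not None
    # One source sequence of hypothetical token IDs, shape [1, src_len].
    src = torch.tensor([[5, 23, 42]], dtype=torch.long, device=device)
    # A single un-padded sequence has no pad positions, so the mask is all False.
    src_padding_mask = torch.zeros_like(src, dtype=torch.bool)
    return greedy_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=config["max_len"],
        sos_idx=config["bos_token_id"],
        eos_idx=config["eos_token_id"],
        device=device,
    )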
# --- Translation Function (Using Greedy Decode) ---
def translate(
model: pl.LightningModule,
src_sentence: str,
smiles_tokenizer: Tokenizer,
iupac_tokenizer: Tokenizer,
device: torch.device,
max_len: int,
sos_idx: int,
eos_idx: int,
pad_idx: int,
) -> str: # Returns a single string
"""
Translates a single SMILES string using greedy decoding.
"""
model.eval() # Ensure model is in eval mode
# --- Tokenize Source ---
try:
# Ensure tokenizer has truncation/padding configured if needed, or handle manually
smiles_tokenizer.enable_truncation(
max_length=max_len
) # Use max_len for source truncation too
src_encoded = smiles_tokenizer.encode(src_sentence)
if not src_encoded or not src_encoded.ids:
logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
return "[Encoding Error]"
# Use the truncated IDs directly
src_ids = src_encoded.ids
except Exception as e:
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
return "[Encoding Error]"
# --- Prepare Input Tensor and Mask ---
src = (
torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
) # [1, src_len]
# Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
# --- Perform Greedy Decoding ---
    # Calls the greedy_decode function defined above in this file.
    # The generation length limit is supplied by the caller, which reads it from the config.
    generation_max_len = max_len
tgt_tokens_tensor = greedy_decode(
model=model,
src=src,
src_padding_mask=src_padding_mask,
max_len=generation_max_len, # Use generation limit
sos_idx=sos_idx,
eos_idx=eos_idx,
# pad_idx=pad_idx, # Not needed by greedy_decode internal loop
device=device,
) # Returns a single tensor [1, generated_len]
# --- Decode Generated Tokens ---
if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
logging.warning(
f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
)
return "[Decoding Error - Empty Output]"
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
try:
# Decode using the target tokenizer, skipping special tokens
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
return translation
except Exception as e:
logging.error(
f"Error decoding target tokens {tgt_tokens}: {e}",
exc_info=True,
)
return "[Decoding Error]"
# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
"""Loads tokenizers, config, and model from Hugging Face Hub."""
global model, smiles_tokenizer, iupac_tokenizer, device, config
if model is not None: # Already loaded
logging.info("Model and tokenizers already loaded.")
return
logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
try:
# Determine device
if torch.cuda.is_available():
logging.warning(
"CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
)
device = torch.device("cpu") # Uncomment if GPU is intended
# device = torch.device("cuda")
# logging.info("CUDA available, using GPU.")
else:
device = torch.device("cpu")
logging.info("CUDA not available, using CPU.")
# Download files
logging.info("Downloading files from Hugging Face Hub...")
try:
cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
os.makedirs(cache_dir, exist_ok=True)
logging.info(f"Using cache directory: {cache_dir}")
checkpoint_path = hf_hub_download(
repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
)
smiles_tokenizer_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=SMILES_TOKENIZER_FILENAME,
cache_dir=cache_dir,
)
iupac_tokenizer_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=IUPAC_TOKENIZER_FILENAME,
cache_dir=cache_dir,
)
config_path = hf_hub_download(
repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
)
logging.info("Files downloaded successfully.")
except Exception as e:
logging.error(
f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
exc_info=True,
)
raise gr.Error(
f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
)
# Load config
logging.info("Loading configuration...")
try:
with open(config_path, "r") as f:
config = json.load(f)
logging.info("Configuration loaded.")
# --- Validate essential config keys ---
required_keys = [
"src_vocab_size", # Use the key saved in config
"tgt_vocab_size", # Use the key saved in config
"emb_size",
"nhead",
"ffn_hid_dim",
"num_encoder_layers",
"num_decoder_layers",
"dropout",
"max_len",
"pad_token_id",
"bos_token_id",
"eos_token_id",
]
            # If your saved config uses different key names, remap them to the
            # expected names here (e.g. alias -> "src_vocab_size"). The current
            # config already uses the expected names, so no remapping is applied.
missing_keys = [key for key in required_keys if config.get(key) is None]
if missing_keys:
raise ValueError(
f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
f"Ensure these were saved in the hyperparameters during training."
)
logging.info(
f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
)
except FileNotFoundError:
logging.error(f"Config file not found: {config_path}")
raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
except json.JSONDecodeError as e:
logging.error(f"Error decoding JSON from config file {config_path}: {e}")
raise gr.Error(
f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}"
)
except ValueError as e:
logging.error(f"Config validation error: {e}")
raise gr.Error(f"Config Error: {e}")
except Exception as e:
logging.error(f"Unexpected error loading config: {e}", exc_info=True)
raise gr.Error(f"Config Error: Unexpected error. Check logs. Error: {e}")
# Load tokenizers
logging.info("Loading tokenizers...")
try:
smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
logging.info("Tokenizers loaded.")
# --- Optional: Validate Tokenizer Special Tokens Against Config ---
pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"
unk_token = "<unk>"
issues = []
if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
issues.append(f"SMILES PAD ID mismatch")
if smiles_tokenizer.token_to_id(unk_token) is None:
issues.append("SMILES UNK token not found")
if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
issues.append(f"IUPAC PAD ID mismatch")
if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
issues.append(f"IUPAC SOS ID mismatch")
if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
issues.append(f"IUPAC EOS ID mismatch")
if iupac_tokenizer.token_to_id(unk_token) is None:
issues.append("IUPAC UNK token not found")
if issues:
logging.warning("Tokenizer validation issues: " + "; ".join(issues))
except Exception as e:
logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
raise gr.Error(
f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
)
# Load model
logging.info("Loading model from checkpoint...")
try:
# Use the vocab sizes and hparams from the loaded config
model = SmilesIupacLitModule.load_from_checkpoint(
checkpoint_path,
**config, # Pass loaded config directly as keyword args
map_location=device,
devices=1,
strict=True, # Start strict, set to False if encountering key errors
)
model.to(device)
model.eval()
model.freeze()
logging.info(
f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
)
except FileNotFoundError:
logging.error(f"Checkpoint file not found: {checkpoint_path}")
raise gr.Error(
f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
)
except Exception as e:
logging.error(
f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
)
if "size mismatch" in str(e):
error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
logging.error(error_detail)
raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
elif "memory" in str(e).lower():
logging.warning("Potential OOM error during model loading.")
gc.collect()
torch.cuda.empty_cache() if device.type == "cuda" else None
raise gr.Error(
f"Model Error: OOM loading model. Check Space resources. Error: {e}"
)
else:
raise gr.Error(
f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
)
except gr.Error:
raise
except Exception as e:
logging.error(f"Unexpected error during loading: {e}", exc_info=True)
raise gr.Error(
f"Initialization Error: Unexpected error. Check logs. Error: {e}"
)
# --- Inference Function for Gradio ---
def predict_iupac(smiles_string):
"""
Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
"""
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        return error_msg
    if not smiles_string or not smiles_string.strip():
        return "Error: Please enter a valid SMILES string."
    # Canonicalize the input SMILES; RDKit raises for invalid structures.
    try:
        smiles_input = CanonSmiles(smiles_string.strip())
    except Exception as e:
        logging.error(f"Error during SMILES canonicalization: {e}", exc_info=True)
        return f"Error: Invalid SMILES string. {e}"
try:
# --- Call the core translation logic (greedy) ---
sos_idx = config["bos_token_id"]
eos_idx = config["eos_token_id"]
pad_idx = config["pad_token_id"]
gen_max_len = config["max_len"]
predicted_name = translate( # Returns a single string now
model=model,
src_sentence=smiles_input,
smiles_tokenizer=smiles_tokenizer,
iupac_tokenizer=iupac_tokenizer,
device=device,
max_len=gen_max_len,
sos_idx=sos_idx,
eos_idx=eos_idx,
pad_idx=pad_idx,
)
logging.info(f"Prediction returned: {predicted_name}")
# --- Format Output ---
if "[Error]" in predicted_name: # Check for error messages from translate
output_text = (
f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
)
elif not predicted_name:
output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
else:
output_text = (
f"Input SMILES: {smiles_input}\n\n"
f"Predicted IUPAC Name (Greedy Decode):\n"
f"{predicted_name}"
)
return output_text
    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        return f"Error: Runtime error during translation: {e}"
    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        return f"Error: Unexpected error during translation: {e}"
# --- Load Model on App Start ---
try:
load_model_and_tokenizers()
except gr.Error as ge:
logging.error(f"Gradio Initialization Error: {ge}")
pass # Allow Gradio to potentially start with an error message
except Exception as e:
logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
# Optionally raise gr.Error here too
# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
Translation uses **greedy decoding** (picks the most likely next word at each step).
**Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
"""
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Row():
with gr.Column():
smiles_input = gr.Textbox(
label="SMILES String",
placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
lines=1,
)
# Add a button for manual triggering
submit_btn = gr.Button("Translate")
with gr.Column():
output_text = gr.Textbox(
label="Predicted IUPAC Name",
lines=3,
show_copy_button=True,
)
    # Wire up the events: translate on button click, and also whenever the input
    # changes (auto-translation while the user types).
    submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
    smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
# --- Launch the App ---
if __name__ == "__main__":
iface.launch()