# app.py
import gradio as gr
import torch
# import torch.nn.functional as F # No longer needed for greedy decode directly
import pytorch_lightning as pl
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc
try:
from rdkit import Chem
from rdkit import RDLogger # Optional: To suppress RDKit logs
RDLogger.DisableLog('rdApp.*') # Suppress RDKit warnings/errors if desired
except ImportError:
logging.warning("RDKit not found. SMILES canonicalization will be skipped. Install with 'pip install rdkit'")
Chem = None # Set Chem to None if RDKit is not available
# --- Configuration ---
MODEL_REPO_ID = (
"AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
)
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"
# --- End Configuration ---
# --- Logging ---
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
try:
# Ensure enhanced_trainer.py is in the root directory of your space repo
from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
logging.error(
f"Failed to import helper code from enhanced_trainer.py: {e}. "
f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
)
# We raise gr.Error later during loading if the class isn't found
SmilesIupacLitModule = None
generate_square_subsequent_mask = None
except Exception as e:
logging.error(
f"An unexpected error occurred during helper code import: {e}", exc_info=True
)
SmilesIupacLitModule = None
generate_square_subsequent_mask = None
# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None
# --- Greedy Decoding Logic (Locally defined) ---
def greedy_decode(
model: pl.LightningModule,
src: torch.Tensor,
src_padding_mask: torch.Tensor,
max_len: int,
sos_idx: int,
eos_idx: int,
device: torch.device,
) -> torch.Tensor:
"""
Performs greedy decoding using the LightningModule's model.
Assumes model has 'model.encode', 'model.decode', 'model.generator' attributes.
"""
if not hasattr(model, 'model') or not hasattr(model.model, 'encode') or \
not hasattr(model.model, 'decode') or not hasattr(model.model, 'generator'):
logging.error("Model object does not have the expected 'model.encode/decode/generator' structure.")
raise AttributeError("Model structure mismatch for greedy decoding.")
if generate_square_subsequent_mask is None:
logging.error("generate_square_subsequent_mask function not imported.")
raise ImportError("generate_square_subsequent_mask is required for greedy_decode.")
model.eval() # Ensure model is in evaluation mode
transformer_model = model.model # Access the underlying Seq2SeqTransformer
try:
with torch.no_grad():
# --- Encode Source ---
# The mask should be True where the input *is* padding.
memory = transformer_model.encode(
src, src_mask=None # Standard transformer encoder doesn't usually use src_mask
) # [1, src_len, emb_size] if batch_first=True in TransformerEncoder
# If batch_first=False (default): [src_len, 1, emb_size] -> adjust usage below
# Assuming batch_first=False for standard nn.Transformer
memory = memory.to(device)
# Memory key padding mask needs to be [batch_size, src_len] -> [1, src_len]
memory_key_padding_mask = src_padding_mask.to(device) # [1, src_len]
# --- Initialize Target Sequence ---
# Start with the SOS token -> Shape [1, 1] (batch, seq)
ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
sos_idx
)
# --- Decoding Loop ---
for i in range(max_len - 1): # Max length limit
# Target processing depends on whether decoder expects batch_first
# Standard nn.TransformerDecoder expects [tgt_len, batch_size, emb_size]
# Standard nn.TransformerDecoder expects tgt_mask [tgt_len, tgt_len]
# Standard nn.TransformerDecoder expects memory_key_padding_mask [batch_size, src_len]
# Standard nn.TransformerDecoder expects tgt_key_padding_mask [batch_size, tgt_len]
tgt_seq_len = ys.shape[1]
# Create causal mask -> [tgt_len, tgt_len]
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
device
)
# Target padding mask: False for non-pad tokens -> [1, tgt_len]
tgt_padding_mask = torch.zeros(
ys.shape, dtype=torch.bool, device=device
)
# Prepare target for decoder (assuming batch_first=False expected)
# Input ys is [1, current_len] -> need [current_len, 1]
ys_decoder_input = ys.transpose(0, 1).to(device) # [current_len, 1]
# Decode one step
decoder_output = transformer_model.decode(
tgt=ys_decoder_input, # [current_len, 1]
memory=memory, # [src_len, 1, emb_size]
tgt_mask=tgt_mask, # [current_len, current_len]
tgt_key_padding_mask=tgt_padding_mask, # [1, current_len]
memory_key_padding_mask=memory_key_padding_mask, # [1, src_len]
) # Output shape [current_len, 1, emb_size]
# Get logits for the *next* token prediction
# Use output corresponding to the last input token -> [-1, :, :]
next_token_logits = transformer_model.generator(
decoder_output[-1, :, :] # Shape [1, emb_size]
) # Output shape [1, tgt_vocab_size]
# Find the most likely next token (greedy choice)
next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # Shape [1]
next_word_id = next_word_id_tensor.item()
# Append the chosen token to the sequence (shape [1, 1])
next_word_tensor = torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)
# Concatenate along the sequence dimension (dim=1)
ys = torch.cat([ys, next_word_tensor], dim=1) # [1, current_len + 1]
# Stop if EOS token is generated
if next_word_id == eos_idx:
break
# Return the generated sequence (excluding the initial SOS token)
# Shape [1, generated_len]
return ys[:, 1:]
except RuntimeError as e:
logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
if "CUDA out of memory" in str(e) and device.type == "cuda":
gc.collect()
torch.cuda.empty_cache()
raise RuntimeError("CUDA out of memory during greedy decoding.") # Re-raise specific error
raise e # Re-raise other runtime errors
except Exception as e:
logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
raise e # Re-raise
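# Hedged sketch: the exact implementation of generate_square_subsequent_mask lives in
# enhanced_trainer.py and is not shown here. greedy_decode() above only assumes it returns
# an additive causal mask of shape [sz, sz], equivalent to the reference helper below
# (defined for illustration only; the app itself keeps using the imported function).
def _reference_causal_mask(sz: int, mask_device: torch.device) -> torch.Tensor:
    """Reference causal mask (assumption): -inf above the diagonal, 0.0 elsewhere."""
    return torch.triu(
        torch.full((sz, sz), float("-inf"), device=mask_device), diagonal=1
    )
# e.g. _reference_causal_mask(3, torch.device("cpu")) ->
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
# i.e. position i may only attend to positions <= i while decoding.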
# --- Translation Function (Using Greedy Decode) ---
def translate(
model: pl.LightningModule,
src_sentence: str,
smiles_tokenizer: Tokenizer,
iupac_tokenizer: Tokenizer,
device: torch.device,
max_len_config: int, # Max length from config (used for source truncation & generation limit)
sos_idx: int,
eos_idx: int,
pad_idx: int,
) -> str: # Returns a single string or an error message
"""
Translates a single SMILES string using greedy decoding.
"""
if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
return "[Initialization Error: Components not loaded]"
model.eval() # Ensure model is in eval mode
# --- Tokenize Source ---
try:
# Ensure tokenizer has truncation configured
smiles_tokenizer.enable_truncation(max_length=max_len_config)
smiles_tokenizer.enable_padding(pad_id=pad_idx, pad_token="<pad>", length=max_len_config) # Ensure padding for consistent input length if needed by model
src_encoded = smiles_tokenizer.encode(src_sentence)
if not src_encoded or not src_encoded.ids:
logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
return "[Encoding Error: Empty result]"
src_ids = src_encoded.ids
# Use attention mask directly for padding mask (1 for real tokens, 0 for padding)
# We need the opposite for PyTorch Transformer (True for padding, False for real)
src_attention_mask = torch.tensor(src_encoded.attention_mask, dtype=torch.long)
src_padding_mask = (src_attention_mask == 0) # True where it's padded
except Exception as e:
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
return f"[Encoding Error: {e}]"
# --- Prepare Input Tensor and Mask ---
# Input tensor shape [1, src_len]
src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
# Padding mask shape [1, src_len]
src_padding_mask = src_padding_mask.unsqueeze(0).to(device)
# --- Perform Greedy Decoding ---
try:
tgt_tokens_tensor = greedy_decode(
model=model,
src=src,
src_padding_mask=src_padding_mask,
max_len=max_len_config, # Use config max_len as generation limit
sos_idx=sos_idx,
eos_idx=eos_idx,
device=device,
) # Returns a single tensor [1, generated_len]
    except Exception as e:  # covers RuntimeError, AttributeError, and ImportError raised by greedy_decode
logging.error(f"Error during greedy_decode call: {e}", exc_info=True)
return f"[Decoding Error: {e}]"
# --- Decode Generated Tokens ---
if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
        # Check if the source was effectively empty (only special tokens once padding is ignored).
        # Note: enable_padding fills src_ids up to max_len, so count non-pad tokens rather than raw length.
        non_pad_ids = [t for t in src_ids if t != pad_idx]
        if len(non_pad_ids) <= 2 and all(t in (sos_idx, eos_idx) for t in non_pad_ids):
logging.warning(f"Input SMILES '{src_sentence}' resulted in very short/empty encoding, leading to empty decode.")
return "[Decoding Warning: Input potentially too short or invalid after tokenization]"
else:
logging.warning(
f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
)
return "[Decoding Error: Empty Output]"
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
try:
# Decode using the target tokenizer, skipping special tokens
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
return translation.strip() # Strip leading/trailing whitespace
except Exception as e:
logging.error(
f"Error decoding target tokens {tgt_tokens}: {e}",
exc_info=True,
)
return "[Decoding Error: Tokenizer failed]"
# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
"""Loads tokenizers, config, and model from Hugging Face Hub."""
global model, smiles_tokenizer, iupac_tokenizer, device, config, SmilesIupacLitModule, generate_square_subsequent_mask
if model is not None: # Already loaded
logging.info("Model and tokenizers already loaded.")
return
logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
# --- Check if helper code loaded ---
if SmilesIupacLitModule is None or generate_square_subsequent_mask is None:
error_msg = f"Initialization Error: Could not load required components from enhanced_trainer.py. Check Space logs and ensure the file exists in the repo root."
logging.error(error_msg)
raise gr.Error(error_msg)
try:
# Determine device
if torch.cuda.is_available():
device = torch.device("cuda")
logging.info("CUDA available, using GPU.")
else:
device = torch.device("cpu")
logging.info("CUDA not available, using CPU.")
# Download files
logging.info("Downloading files from Hugging Face Hub...")
try:
# Define cache directory, default to './hf_cache' if GRADIO_CACHE is not set
cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
os.makedirs(cache_dir, exist_ok=True) # Ensure cache dir exists
logging.info(f"Using cache directory: {cache_dir}")
# Download files to the specified cache directory
checkpoint_path = hf_hub_download(
repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir, force_download=False # Avoid re-download if files exist
)
smiles_tokenizer_path = hf_hub_download(
repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
)
iupac_tokenizer_path = hf_hub_download(
repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
)
config_path = hf_hub_download(
repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir, force_download=False
)
logging.info("Files downloaded (or found in cache) successfully.")
except Exception as e:
logging.error(
f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
exc_info=True,
)
raise gr.Error(
f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
)
# Load config
logging.info("Loading configuration...")
try:
with open(config_path, "r") as f:
config = json.load(f)
logging.info("Configuration loaded.")
# --- Validate essential config keys ---
required_keys = [
"src_vocab_size",
"tgt_vocab_size",
"emb_size",
"nhead",
"ffn_hid_dim",
"num_encoder_layers",
"num_decoder_layers",
"dropout",
"max_len", # Crucial for tokenization and generation limit
"pad_token_id",
"bos_token_id",
"eos_token_id",
]
            # Normalize alternative (upper-case) key names that some configs use
            key_aliases = {
                "src_vocab_size": "SRC_VOCAB_SIZE",
                "tgt_vocab_size": "TGT_VOCAB_SIZE",
                "emb_size": "EMB_SIZE",
                "nhead": "NHEAD",
                "ffn_hid_dim": "FFN_HID_DIM",
                "num_encoder_layers": "NUM_ENCODER_LAYERS",
                "num_decoder_layers": "NUM_DECODER_LAYERS",
                "dropout": "DROPOUT",
                "max_len": "MAX_LEN",
                "pad_token_id": "PAD_IDX",
                "bos_token_id": "BOS_IDX",
                "eos_token_id": "EOS_IDX",
            }
            for key, alias in key_aliases.items():
                config[key] = config.get(key, config.get(alias))
            # Add UNK if needed by your model/tokenizer setup
            # config['unk_token_id'] = config.get('unk_token_id', config.get('UNK_IDX', 0)) # Defaulting to 0 if missing is risky.
missing_keys = [key for key in required_keys if config.get(key) is None]
if missing_keys:
raise ValueError(
f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
f"Ensure these were saved in the hyperparameters during training."
)
logging.info(
f"Using config: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
)
except FileNotFoundError:
logging.error(f"Config file not found: {config_path}")
raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
except json.JSONDecodeError as e:
logging.error(f"Error decoding JSON from config file {config_path}: {e}")
raise gr.Error(
f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}"
)
except ValueError as e:
logging.error(f"Config validation error: {e}")
raise gr.Error(f"Config Error: {e}")
except Exception as e:
logging.error(f"Unexpected error loading config: {e}", exc_info=True)
raise gr.Error(f"Config Error: Unexpected error. Check logs. Error: {e}")
# Load tokenizers
logging.info("Loading tokenizers...")
try:
smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
# --- Validate Tokenizer Special Tokens Against Config ---
pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"
unk_token = "<unk>" # Assuming standard UNK token
issues = []
# SMILES Tokenizer Checks
smiles_pad_id = smiles_tokenizer.token_to_id(pad_token)
smiles_unk_id = smiles_tokenizer.token_to_id(unk_token)
if smiles_pad_id is None or smiles_pad_id != config["pad_token_id"]:
issues.append(f"SMILES PAD ID mismatch (Tokenizer: {smiles_pad_id}, Config: {config['pad_token_id']})")
if smiles_unk_id is None:
issues.append("SMILES UNK token not found in tokenizer")
# IUPAC Tokenizer Checks
iupac_pad_id = iupac_tokenizer.token_to_id(pad_token)
iupac_sos_id = iupac_tokenizer.token_to_id(sos_token)
iupac_eos_id = iupac_tokenizer.token_to_id(eos_token)
iupac_unk_id = iupac_tokenizer.token_to_id(unk_token)
if iupac_pad_id is None or iupac_pad_id != config["pad_token_id"]:
issues.append(f"IUPAC PAD ID mismatch (Tokenizer: {iupac_pad_id}, Config: {config['pad_token_id']})")
if iupac_sos_id is None or iupac_sos_id != config["bos_token_id"]:
issues.append(f"IUPAC SOS ID mismatch (Tokenizer: {iupac_sos_id}, Config: {config['bos_token_id']})")
if iupac_eos_id is None or iupac_eos_id != config["eos_token_id"]:
issues.append(f"IUPAC EOS ID mismatch (Tokenizer: {iupac_eos_id}, Config: {config['eos_token_id']})")
if iupac_unk_id is None:
issues.append("IUPAC UNK token not found in tokenizer")
if issues:
logging.warning("Tokenizer validation issues detected: \n - " + "\n - ".join(issues))
# Decide if this is critical. For inference, SOS/EOS/PAD matches are most important.
# raise gr.Error("Tokenizer Validation Error: Mismatch between config and tokenizer files. Check logs.")
else:
logging.info("Tokenizers loaded and special tokens validated against config.")
except Exception as e:
logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
raise gr.Error(
f"Tokenizer Error: Could not load tokenizers. Check logs and file paths. Error: {e}"
)
# Load model
logging.info("Loading model from checkpoint...")
try:
            # Load the LightningModule directly from the checkpoint. Hyperparameters from
            # config.json are passed as overrides, so SmilesIupacLitModule.__init__ must accept
            # these keys. If load_from_checkpoint ever fails (e.g. hparams mismatch), see the
            # manual state_dict loading sketch defined after this function.
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                map_location=device,
                **config,
                strict=True,  # relax to strict=False only if you understand what is missing/unexpected
            )
model.to(device)
model.eval()
model.freeze() # Freeze weights for inference
logging.info(
f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
)
except FileNotFoundError:
logging.error(f"Checkpoint file not found: {checkpoint_path}")
raise gr.Error(
f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found at expected path."
)
except RuntimeError as e: # Catch specific runtime errors like size mismatch, OOM
logging.error(f"Runtime error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
if "size mismatch" in str(e):
error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint structure. Or config doesn't match model definition."
logging.error(error_detail)
raise gr.Error(f"Model Load Error: {error_detail} Original error: {e}")
elif "CUDA out of memory" in str(e) or "memory" in str(e).lower():
logging.warning("Potential OOM error during model loading.")
gc.collect()
if device.type == "cuda": torch.cuda.empty_cache()
raise gr.Error(f"Model Load Error: OOM loading model. Check Space resources. Error: {e}")
else:
raise gr.Error(f"Model Load Error: Runtime error. Check logs. Error: {e}")
except Exception as e: # Catch other potential errors during loading
logging.error(
f"Unexpected error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
)
raise gr.Error(
f"Model Load Error: Failed to load checkpoint for unknown reason. Check logs. Error: {e}"
)
except gr.Error as ge: # Catch Gradio-specific errors raised above
raise ge # Re-raise to stop app launch correctly
except Exception as e: # Catch any other unexpected errors during the whole process
logging.error(f"Unexpected error during loading process: {e}", exc_info=True)
raise gr.Error(
f"Initialization Error: Unexpected failure during setup. Check logs. Error: {e}"
)
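# Hedged alternative to load_from_checkpoint, mirroring the manual route mentioned inside
# load_model_and_tokenizers(). Defined for reference only and never called by the app; it
# assumes the checkpoint stores its weights under the standard Lightning "state_dict" key.
def _load_via_state_dict(ckpt_path: str, cfg: dict, target_device: torch.device):
    """Instantiate SmilesIupacLitModule from config and load checkpoint weights manually."""
    instance = SmilesIupacLitModule(**cfg)
    state_dict = torch.load(ckpt_path, map_location=target_device)["state_dict"]
    instance.load_state_dict(state_dict, strict=True)  # relax to strict=False only if mismatches are understood
    return instance.to(target_device).eval()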
# --- Inference Function for Gradio ---
def predict_iupac(smiles_string):
"""
Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
Handles input validation, canonicalization, translation, and output formatting.
"""
global model, smiles_tokenizer, iupac_tokenizer, device, config
# --- Check Initialization ---
if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
error_msg = "Error: Model or tokenizers not loaded. App initialization failed. Check Space logs."
logging.error(error_msg)
# Return the error directly in the output box
return error_msg
# --- Input Validation ---
if not smiles_string or not smiles_string.strip():
return "Error: Please enter a valid SMILES string."
smiles_input_raw = smiles_string.strip()
# --- Canonicalize SMILES (Optional but Recommended) ---
smiles_input_canon = smiles_input_raw
if Chem:
try:
mol = Chem.MolFromSmiles(smiles_input_raw)
if mol:
smiles_input_canon = Chem.MolToSmiles(mol, canonical=True)
logging.info(f"Canonicalized SMILES: {smiles_input_raw} -> {smiles_input_canon}")
else:
# RDKit couldn't parse it, proceed with raw input but warn
logging.warning(f"Could not parse SMILES '{smiles_input_raw}' with RDKit. Using raw input.")
# Optionally return an error here if strict parsing is needed
# return f"Error: Invalid SMILES string '{smiles_input_raw}' according to RDKit."
except Exception as e:
logging.error(f"Error during RDKit canonicalization for '{smiles_input_raw}': {e}", exc_info=True)
# Proceed with raw input, maybe add note to output
# return f"Error: RDKit processing failed: {e}" # Option to fail hard
# --- Translation ---
try:
sos_idx = config["bos_token_id"]
eos_idx = config["eos_token_id"]
pad_idx = config["pad_token_id"]
gen_max_len = config["max_len"] # Use max_len from config
predicted_name = translate(
model=model,
src_sentence=smiles_input_canon, # Use canonicalized SMILES
smiles_tokenizer=smiles_tokenizer,
iupac_tokenizer=iupac_tokenizer,
device=device,
max_len_config=gen_max_len,
sos_idx=sos_idx,
eos_idx=eos_idx,
pad_idx=pad_idx,
)
logging.info(f"SMILES: '{smiles_input_canon}', Prediction: '{predicted_name}'")
# --- Format Output ---
# Check if translate returned an error message
if predicted_name.startswith("[") and predicted_name.endswith("]"):
# Assume it's an error/warning message from translate()
output_text = (
f"Input SMILES: {smiles_input_canon}\n"
f"(Raw Input: {smiles_input_raw})\n\n" # Show raw if canonicalization happened
f"Prediction Failed: {predicted_name}"
)
elif not predicted_name: # Handle empty string case
output_text = (
f"Input SMILES: {smiles_input_canon}\n"
f"(Raw Input: {smiles_input_raw})\n\n"
f"Prediction: [No name generated]"
)
else:
output_text = (
f"Input SMILES: {smiles_input_canon}\n"
f"(Raw Input: {smiles_input_raw})\n\n"
f"Predicted IUPAC Name (Greedy Decode):\n"
f"{predicted_name}"
)
# Remove the "(Raw Input...)" line if canonicalization didn't change the input
if smiles_input_raw == smiles_input_canon:
output_text = output_text.replace(f"(Raw Input: {smiles_input_raw})\n", "")
return output_text.strip()
except Exception as e:
# Catch-all for unexpected errors during the prediction process
logging.error(f"Unexpected error during prediction for '{smiles_input_canon}': {e}", exc_info=True)
error_msg = f"Error: An unexpected error occurred during translation: {e}"
return error_msg
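# Example of the formatted result returned by predict_iupac() (the predicted name itself is
# model-dependent and shown as a placeholder here):
#   predict_iupac("CCO")
#   -> "Input SMILES: CCO\n\nPredicted IUPAC Name (Greedy Decode):\n<predicted name>"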
# --- Load Model on App Start ---
# Wrap in try/except to allow Gradio UI to potentially display an error
model_load_error = None
try:
load_model_and_tokenizers()
except gr.Error as ge:
logging.error(f"Gradio Initialization Error during load: {ge}")
model_load_error = str(ge) # Store error message
except Exception as e:
logging.critical(f"CRITICAL error during initial model loading: {e}", exc_info=True)
model_load_error = f"Critical Error: {e}. Check Space logs."
# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
Translation uses **greedy decoding** (it picks the most likely next token at each step).
"""
if model_load_error:
description += f"\n\n**WARNING: Failed to load model or components.**\nReason: {model_load_error}\nFunctionality will be limited."
elif device:
description += f"\n**Note:** Model loaded on **{str(device).upper()}**. Check `config.json` in the repo for model details."
else:
description += f"\n**Note:** Device information unavailable (loading might have failed)."
# Use gr.Blocks for layout
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=1): # Input column takes less space
smiles_input = gr.Textbox(
label="SMILES String",
placeholder="Enter SMILES string (e.g., CCO or c1ccccc1)",
lines=2, # Slightly more lines for longer SMILES
)
submit_btn = gr.Button("Translate", variant="primary")
with gr.Column(scale=2): # Output column takes more space
output_text = gr.Textbox(
label="Result",
lines=5, # More lines for formatted output
show_copy_button=True,
interactive=False, # Output box is not for user input
)
# Define examples
gr.Examples(
examples=[
"CCO",
"C1=CC=C(C=C1)C(=O)O", # Benzoic acid
"CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # Ibuprofen
"INVALID_SMILES",
"ClC(Cl)(Cl)C1=CC=C(C=C1)C(C2=CC=C(Cl)C=C2)C(Cl)(Cl)Cl", # DDT
],
inputs=smiles_input,
outputs=output_text,
fn=predict_iupac, # Function to run for examples
cache_examples=False, # Re-run examples each time if needed, True might speed up demo loading
)
# Connect the button click and input change events
submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text, api_name="translate_smiles")
# Optionally trigger on text change (can be slow/resource intensive)
# smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
# --- Launch the App ---
if __name__ == "__main__":
# Set share=True to get a public link (useful for testing)
# Set debug=True for more detailed Gradio errors during development
iface.launch(share=False, debug=False)