# app.py
import gradio as gr
import torch
import torch.nn.functional as F # Needed for log_softmax in beam search
import pytorch_lightning as pl
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc
from rdkit.Chem import CanonSmiles, MolFromSmiles # Added MolFromSmiles for validation
import spaces
import heapq  # Priority queues for beam search hypotheses
import inspect  # Used to filter config keys against the LitModule __init__ signature
# --- Configuration ---
MODEL_REPO_ID = (
"AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
)
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"
# --- End Configuration ---
# --- Logging ---
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
try:
# Ensure enhanced_trainer.py is present in the repository root
from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
logging.error(
f"Failed to import helper code from enhanced_trainer.py: {e}. "
f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
)
raise gr.Error(
f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
)
except Exception as e:
logging.error(
f"An unexpected error occurred during helper code import: {e}", exc_info=True
)
raise gr.Error(
f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
)
# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None
# --- Greedy Decoding Logic ---
def greedy_decode(
model: pl.LightningModule,
src: torch.Tensor,
src_padding_mask: torch.Tensor,
max_len: int,
sos_idx: int,
eos_idx: int,
device: torch.device,
) -> torch.Tensor:
"""
Performs greedy decoding using the LightningModule's model.
Returns a tensor of shape [1, sequence_length].
"""
model.eval()
transformer_model = model.model
try:
with torch.no_grad():
memory = transformer_model.encode(src, src_padding_mask)
memory = memory.to(device)
memory_key_padding_mask = src_padding_mask.to(memory.device)
ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx)
for _ in range(max_len - 1):
tgt_seq_len = ys.shape[1]
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device)
tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device)
decoder_output = transformer_model.decode(
tgt=ys,
memory=memory,
tgt_mask=tgt_mask,
tgt_padding_mask=tgt_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
next_token_logits = transformer_model.generator(decoder_output[:, -1, :])
next_word_id = torch.argmax(next_token_logits, dim=1).item()
ys = torch.cat(
[
ys,
torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id),
],
dim=1,
)
if next_word_id == eos_idx:
break
return ys[:, 1:] # Exclude SOS
except RuntimeError as e:
logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
if "CUDA out of memory" in str(e) and device.type == "cuda":
gc.collect()
torch.cuda.empty_cache()
return torch.empty((1, 0), dtype=torch.long, device=device)
except Exception as e:
logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
return torch.empty((1, 0), dtype=torch.long, device=device)
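# Usage sketch for greedy_decode (hypothetical token ids; assumes the model, config and
# tokenizers below have already been loaded):
#   src = torch.tensor([[12, 57, 8]], dtype=torch.long, device=device)   # [1, src_len]
#   src_padding_mask = torch.zeros_like(src, dtype=torch.bool)           # no PAD positions
#   out = greedy_decode(model, src, src_padding_mask, max_len=256,
#                       sos_idx=config["bos_token_id"], eos_idx=config["eos_token_id"],
#                       device=device)
#   # out is a [1, generated_len] tensor of target token ids with the SOS token stripped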
# --- Beam Search Decoding Logic ---
def beam_search_decode(
model: pl.LightningModule,
src: torch.Tensor,
src_padding_mask: torch.Tensor,
max_len: int,
sos_idx: int,
eos_idx: int,
    pad_idx: int,  # Unused in this per-beam loop; kept for a possible batched implementation
device: torch.device,
beam_width: int,
num_return_sequences: int = 1,
length_penalty_alpha: float = 0.6, # Add length penalty
) -> list[tuple[torch.Tensor, float]]:
"""
Performs beam search decoding.
Returns a list of tuples: (sequence_tensor [1, seq_len], score)
"""
model.eval()
transformer_model = model.model
num_return_sequences = min(beam_width, num_return_sequences) # Cannot return more than beam width
try:
with torch.no_grad():
# --- Encode Source (Once) ---
memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size]
memory = memory.to(device)
memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
# --- Initialize Beams ---
# Each beam: (sequence_tensor [1, current_len], score (log_prob))
initial_beam = (
torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx),
0.0, # Initial score (log probability)
)
            beams = [initial_beam]
            finished_hypotheses = []  # Min-heap of (score, tie_breaker, sequence_tensor)
            hyp_count = 0  # Monotonic tie-breaker so equal scores never force tensor comparisons in the heap
# --- Decoding Loop ---
for step in range(max_len - 1):
if not beams: # Stop if no active beams left
break
# Use a min-heap to keep track of candidates for the *next* step
# Store (-score, next_token_id, beam_index) - use negative score for max-heap behavior
candidates = []
# Process current beams (can be batched for efficiency, but simpler loop shown here)
# For batching: stack ys, expand memory, create masks for the batch
for beam_idx, (current_seq, current_score) in enumerate(beams):
                    if current_seq[0, -1].item() == eos_idx:  # Beam already finished
                        # Normalize by length before storing: score / len(seq) ** alpha
                        penalty = current_seq.shape[1] ** length_penalty_alpha
                        final_score = current_score / penalty if penalty > 0 else current_score
                        heapq.heappush(finished_hypotheses, (final_score, hyp_count, current_seq))
                        hyp_count += 1
                        # Keep only the best `beam_width` finished hypotheses
                        while len(finished_hypotheses) > beam_width:
                            heapq.heappop(finished_hypotheses)  # Drop the lowest-scoring one
                        continue  # Do not expand finished beams
# --- Prepare input for this beam ---
ys = current_seq # [1, current_len]
tgt_seq_len = ys.shape[1]
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device)
tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device)
# --- Decode one step ---
# Note: memory and memory_key_padding_mask are reused
decoder_output = transformer_model.decode(
tgt=ys,
memory=memory, # Needs expansion if batching beams
tgt_mask=tgt_mask,
tgt_padding_mask=tgt_padding_mask,
memory_key_padding_mask=memory_key_padding_mask, # Needs expansion if batching
) # [1, current_len, emb_size]
# Get logits for the *next* token
next_token_logits = transformer_model.generator(
decoder_output[:, -1, :]
) # [1, tgt_vocab_size]
# Calculate log probabilities
log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size]
# Get top K candidates for *this* beam
# Adding current_score makes it the total path score
top_k_log_probs, top_k_indices = torch.topk(log_probs + current_score, beam_width, dim=1)
# Add candidates to the list for selection across all beams
for i in range(beam_width):
token_id = top_k_indices[0, i].item()
score = top_k_log_probs[0, i].item()
# Store (-score, token_id, beam_idx) for heap
heapq.heappush(candidates, (-score, token_id, beam_idx))
                # --- Select Top K Beams for Next Step ---
                # `candidates` holds up to beam_width * len(beams) entries; keep only the overall
                # top `beam_width` of them (smallest -score, i.e. largest score).
                new_beams = []
                top_candidates = heapq.nsmallest(beam_width, candidates)
                added_sequences = set()  # Prevent duplicate sequences if paths converge
for neg_score, token_id, beam_idx in top_candidates:
original_seq, _ = beams[beam_idx]
new_seq = torch.cat(
[
original_seq,
torch.ones(1, 1, dtype=torch.long, device=device).fill_(token_id),
],
dim=1,
) # [1, current_len + 1]
# Avoid adding duplicates (optional, but good practice)
seq_tuple = tuple(new_seq.flatten().tolist())
if seq_tuple not in added_sequences:
new_beams.append((new_seq, -neg_score)) # Store positive score
added_sequences.add(seq_tuple)
beams = new_beams # Update active beams
                # Early stopping: active beams carry raw (unnormalized) path scores, while finished
                # hypotheses are length-normalized, so this comparison is only a heuristic cut-off.
                if finished_hypotheses:
                    best_active_score = -heapq.nsmallest(1, candidates)[0][0] if candidates else -float("inf")
                    worst_finished_score = finished_hypotheses[0][0]  # Smallest normalized score in the min-heap
if len(finished_hypotheses) >= num_return_sequences and best_active_score < worst_finished_score:
logging.debug(f"Beam search early stopping at step {step}")
break
            # --- Final Selection ---
            # Add any remaining active beams that never emitted EOS
            for seq, score in beams:
                if seq[0, -1].item() != eos_idx:
                    penalty = seq.shape[1] ** length_penalty_alpha
                    final_score = score / penalty if penalty > 0 else score
                    heapq.heappush(finished_hypotheses, (final_score, hyp_count, seq))
                    hyp_count += 1
            while len(finished_hypotheses) > beam_width:
                heapq.heappop(finished_hypotheses)
            # Select the top N hypotheses by normalized score (heapq.nlargest returns best first)
            top_hypotheses = heapq.nlargest(num_return_sequences, finished_hypotheses)
            # Return list of (sequence_tensor [1, seq_len] without SOS, score)
            return [(seq[:, 1:], score) for score, _, seq in top_hypotheses]
except RuntimeError as e:
logging.error(f"Runtime error during beam search: {e}", exc_info=True)
if "CUDA out of memory" in str(e) and device.type == "cuda":
gc.collect()
torch.cuda.empty_cache()
return [] # Return empty list on error
except Exception as e:
logging.error(f"Unexpected error during beam search: {e}", exc_info=True)
return []
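# Return-format sketch for beam_search_decode (hypothetical ids and scores): with
# num_return_sequences=2 it might yield
#   [(tensor([[412, 87, 3]]), -1.92), (tensor([[412, 90, 3]]), -2.31)]
# i.e. token-id tensors with the SOS token stripped, paired with length-normalized
# log-probability scores, best hypothesis first.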
# --- Translation Function (Handles both Greedy and Beam Search) ---
def translate(
model: pl.LightningModule,
src_sentence: str,
smiles_tokenizer: Tokenizer,
iupac_tokenizer: Tokenizer,
device: torch.device,
max_len: int,
sos_idx: int,
eos_idx: int,
pad_idx: int,
decoding_strategy: str = "Greedy",
beam_width: int = 5,
num_return_sequences: int = 1,
length_penalty_alpha: float = 0.6,
) -> list[tuple[str, float]]: # Returns list of (translation_string, score)
"""
Translates a single SMILES string using the specified decoding strategy.
"""
model.eval()
# --- Tokenize Source ---
try:
smiles_tokenizer.enable_truncation(max_length=max_len)
src_encoded = smiles_tokenizer.encode(src_sentence)
if not src_encoded or not src_encoded.ids:
logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
return [("[Encoding Error]", 0.0)]
src_ids = src_encoded.ids
except Exception as e:
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
return [("[Encoding Error]", 0.0)]
# --- Prepare Input Tensor and Mask ---
src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len]
src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
    # --- Perform Decoding ---
    generation_max_len = max_len  # Upper bound on generated sequence length (caller passes config["max_len"])
    results = []  # List of (token_tensor, score) tuples
if decoding_strategy == "Greedy":
tgt_tokens_tensor = greedy_decode(
model=model,
src=src,
src_padding_mask=src_padding_mask,
max_len=generation_max_len,
sos_idx=sos_idx,
eos_idx=eos_idx,
device=device,
) # Returns tensor [1, generated_len]
if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0:
results = [(tgt_tokens_tensor, 0.0)] # Assign dummy score 0.0 for greedy
else:
logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
return [("[Decoding Error - Empty Output]", 0.0)]
elif decoding_strategy == "Beam Search":
results = beam_search_decode(
model=model,
src=src,
src_padding_mask=src_padding_mask,
max_len=generation_max_len,
sos_idx=sos_idx,
eos_idx=eos_idx,
pad_idx=pad_idx,
device=device,
beam_width=beam_width,
num_return_sequences=num_return_sequences,
length_penalty_alpha=length_penalty_alpha,
) # Returns list of (tensor, score)
if not results:
logging.warning(f"Beam search returned no results for SMILES: {src_sentence}")
return [("[Decoding Error - Empty Output]", 0.0)]
else:
logging.error(f"Unknown decoding strategy: {decoding_strategy}")
return [("[Error: Unknown Strategy]", 0.0)]
# --- Decode Generated Tokens ---
translations = []
for tgt_tokens_tensor, score in results:
if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
translations.append(("[Decoding Error - Empty Sequence]", score))
continue
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
try:
# Decode using the target tokenizer, skipping special tokens
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
translations.append((translation, score))
except Exception as e:
logging.error(
f"Error decoding target tokens {tgt_tokens}: {e}",
exc_info=True,
)
translations.append(("[Decoding Error]", score))
return translations
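# Return-format sketch for translate (hypothetical names and scores): beam search with
# num_return_sequences=2 might return [("ethanol", -0.42), ("ethan-1-ol", -0.57)],
# while greedy decoding returns a single (name, 0.0) pair; failures are reported as
# "[... Error ...]" placeholder strings.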
# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
"""Loads tokenizers, config, and model from Hugging Face Hub."""
global model, smiles_tokenizer, iupac_tokenizer, device, config
if model is not None:
logging.info("Model and tokenizers already loaded.")
return
logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
try:
# Determine device (Force CPU for stability in typical Space envs, uncomment cuda if needed)
# if torch.cuda.is_available():
# device = torch.device("cuda")
# logging.info("CUDA available, using GPU.")
# else:
device = torch.device("cpu")
logging.info("Using CPU. Modify code to enable GPU if available and desired.")
# Download files
logging.info("Downloading files from Hugging Face Hub...")
cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
os.makedirs(cache_dir, exist_ok=True)
logging.info(f"Using cache directory: {cache_dir}")
try:
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir)
smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir)
iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir)
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir)
# Ensure enhanced_trainer.py is downloaded or present
try:
hf_hub_download(repo_id=MODEL_REPO_ID, filename="enhanced_trainer.py", cache_dir=cache_dir, local_dir=".") # Download to current dir
logging.info("Downloaded enhanced_trainer.py")
except Exception as download_err:
if os.path.exists("enhanced_trainer.py"):
logging.warning(f"Could not download enhanced_trainer.py (maybe private?), but found local file. Using local. Error: {download_err}")
else:
raise download_err # Re-raise if not found locally either
logging.info("Files downloaded successfully.")
except Exception as e:
logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")
# Load config
logging.info("Loading configuration...")
try:
with open(config_path, "r") as f:
config = json.load(f)
logging.info("Configuration loaded.")
required_keys = ["src_vocab_size", "tgt_vocab_size", "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers", "num_decoder_layers", "dropout", "max_len", "pad_token_id", "bos_token_id", "eos_token_id"]
missing_keys = [key for key in required_keys if config.get(key) is None]
if missing_keys:
raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}.")
logging.info(f"Using config: { {k: config.get(k) for k in required_keys} }") # Log key values
except Exception as e:
logging.error(f"Error loading or validating config: {e}", exc_info=True)
raise gr.Error(f"Config Error: {e}")
# Load tokenizers
logging.info("Loading tokenizers...")
try:
smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
# Basic validation (can add more checks as before)
if smiles_tokenizer.get_vocab_size() != config['src_vocab_size']:
logging.warning(f"SMILES vocab size mismatch: Tokenizer={smiles_tokenizer.get_vocab_size()}, Config={config['src_vocab_size']}")
if iupac_tokenizer.get_vocab_size() != config['tgt_vocab_size']:
logging.warning(f"IUPAC vocab size mismatch: Tokenizer={iupac_tokenizer.get_vocab_size()}, Config={config['tgt_vocab_size']}")
logging.info("Tokenizers loaded.")
except Exception as e:
logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}")
# Load model
logging.info("Loading model from checkpoint...")
try:
# Ensure config keys match expected arguments of SmilesIupacLitModule.__init__
# Map config keys if necessary, e.g., if config uses 'vocab_size_src' but class expects 'src_vocab_size'
model_hparams = config.copy() # Start with all config params
# Example remapping (adjust if your config/class names differ):
# model_hparams['src_vocab_size'] = model_hparams.pop('vocab_size_src', config['src_vocab_size'])
# model_hparams['tgt_vocab_size'] = model_hparams.pop('vocab_size_tgt', config['tgt_vocab_size'])
# model_hparams['bos_idx'] = model_hparams.pop('bos_token_id', config['bos_token_id'])
# model_hparams['eos_idx'] = model_hparams.pop('eos_token_id', config['eos_token_id'])
# model_hparams['padding_idx'] = model_hparams.pop('pad_token_id', config['pad_token_id'])
# Remove keys from hparams that are not expected by the LitModule's __init__
# This depends on the exact signature of SmilesIupacLitModule
# Common ones to potentially remove if not direct args: max_len (often used elsewhere)
# Check the __init__ signature of SmilesIupacLitModule in enhanced_trainer.py
            expected_args = set(inspect.signature(SmilesIupacLitModule.__init__).parameters)
            hparams_to_pass = {k: v for k, v in model_hparams.items() if k in expected_args}
logging.info(f"Passing hparams to LitModule: {hparams_to_pass.keys()}")
model = SmilesIupacLitModule.load_from_checkpoint(
checkpoint_path,
map_location=device,
# devices=1, # Often not needed for inference loading
strict=False, # Set to False initially if encountering key errors
**hparams_to_pass # Pass relevant hparams from config
)
model.to(device)
model.eval()
model.freeze()
logging.info(f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'.")
except FileNotFoundError:
logging.error(f"Checkpoint file not found: {checkpoint_path}")
raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
except Exception as e:
logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
if "size mismatch" in str(e):
error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
logging.error(error_detail)
raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
elif "unexpected keyword argument" in str(e) or "missing 1 required positional argument" in str(e):
error_detail = f"Mismatch between config.json keys and SmilesIupacLitModule constructor arguments. Check enhanced_trainer.py and config.json. Error: {e}"
logging.error(error_detail)
raise gr.Error(f"Model Error: {error_detail}")
elif "memory" in str(e).lower():
logging.warning("Potential OOM error during model loading.")
gc.collect()
                if device.type == "cuda":
                    torch.cuda.empty_cache()
raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}")
else:
raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}")
except gr.Error:
raise # Propagate Gradio errors directly
except Exception as e:
logging.error(f"Unexpected error during loading: {e}", exc_info=True)
raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}")
# --- Inference Function for Gradio ---
@spaces.GPU  # Requests a GPU allocation on ZeroGPU Spaces; a no-op on CPU-only hardware
def predict_iupac(smiles_string, decoding_strategy, num_beams, num_return_sequences):
"""
Performs SMILES to IUPAC translation using the loaded model and selected strategy.
"""
global model, smiles_tokenizer, iupac_tokenizer, device, config
# --- Input Validation ---
if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
logging.error(error_msg)
return f"Initialization Error: {error_msg}"
if not smiles_string or not smiles_string.strip():
return "Error: Please enter a valid SMILES string."
smiles_input = smiles_string.strip()
# Validate SMILES using RDKit
try:
mol = MolFromSmiles(smiles_input)
if mol is None:
return f"Error: Invalid SMILES string provided: '{smiles_input}'"
smiles_input = CanonSmiles(smiles_input) # Use canonical form
logging.info(f"Canonical SMILES: {smiles_input}")
except Exception as e:
logging.error(f"Error during SMILES validation/canonicalization: {e}", exc_info=True)
return f"Error: Could not process SMILES string '{smiles_input}'. RDKit error: {e}"
    # Validate beam search parameters (Gradio sliders may deliver floats, so coerce first)
    if decoding_strategy == "Beam Search":
        try:
            num_beams = int(num_beams)
            num_return_sequences = int(num_return_sequences)
        except (TypeError, ValueError):
            return "Error: Beam width and number of return sequences must be integers."
        if num_beams <= 0:
            return "Error: Beam width must be a positive integer."
        if num_return_sequences <= 0:
            return "Error: Number of return sequences must be a positive integer."
        if num_return_sequences > num_beams:
            return f"Error: Number of return sequences ({num_return_sequences}) cannot exceed beam width ({num_beams})."
    else:
        # Greedy decoding ignores these settings
        num_beams = 1
        num_return_sequences = 1
try:
# --- Call the core translation logic ---
sos_idx = config["bos_token_id"]
eos_idx = config["eos_token_id"]
pad_idx = config["pad_token_id"]
gen_max_len = config["max_len"]
# Use fixed length penalty for now, could be another slider
length_penalty = 0.6
predicted_results = translate( # Returns list of (name, score)
model=model,
src_sentence=smiles_input,
smiles_tokenizer=smiles_tokenizer,
iupac_tokenizer=iupac_tokenizer,
device=device,
max_len=gen_max_len,
sos_idx=sos_idx,
eos_idx=eos_idx,
pad_idx=pad_idx,
decoding_strategy=decoding_strategy,
beam_width=num_beams,
num_return_sequences=num_return_sequences,
length_penalty_alpha=length_penalty,
)
logging.info(f"Prediction returned {len(predicted_results)} result(s). Strategy: {decoding_strategy}, Beams: {num_beams}, Return: {num_return_sequences}")
# --- Format Output ---
output_lines = []
output_lines.append(f"Input SMILES: {smiles_input}")
output_lines.append(f"Decoding Strategy: {decoding_strategy}")
if decoding_strategy == "Beam Search":
output_lines.append(f"Beam Width: {num_beams}")
output_lines.append(f"Returned Sequences: {len(predicted_results)}")
output_lines.append(f"Length Penalty Alpha: {length_penalty:.2f}")
output_lines.append("\n--- Predictions ---")
if not predicted_results:
output_lines.append("No predictions generated.")
else:
            for i, (name, score) in enumerate(predicted_results):
                # Error placeholders produced above all look like "[... Error ...]"
                if not name or (name.startswith("[") and "Error" in name):
                    output_lines.append(f"{i + 1}. Prediction Failed: {name}")
                else:
                    score_info = f"(Score: {score:.4f})" if decoding_strategy == "Beam Search" else ""
                    output_lines.append(f"{i + 1}. {name} {score_info}".rstrip())
return "\n".join(output_lines)
except RuntimeError as e:
logging.error(f"Runtime error during translation: {e}", exc_info=True)
gc.collect()
if device.type == 'cuda': torch.cuda.empty_cache()
return f"Runtime Error during translation: {e}. Check logs."
except Exception as e:
logging.error(f"Unexpected error during translation: {e}", exc_info=True)
return f"Unexpected Error during translation: {e}. Check logs."
# --- Load Model on App Start ---
try:
load_model_and_tokenizers()
except gr.Error as ge:
# Log the Gradio error but allow interface to load potentially showing the error message
logging.error(f"Gradio Initialization Error during load: {ge}")
    # The UI does not exist yet at this point, so the error cannot be surfaced there directly;
    # predict_iupac() re-checks that all components loaded before running.
except Exception as e:
logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
# This might prevent the app from starting correctly.
# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator"
description = f"""
Translate a SMILES string into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}).
Choose between **Greedy Decoding** (fastest; always picks the most likely next token) and **Beam Search Decoding** (explores several candidate sequences; often more accurate but slower).
**Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Beam search can be slow, especially with larger beam widths.
Check `config.json` in the repo for model details. SMILES input will be canonicalized using RDKit.
"""
# Use gr.Blocks for more layout control
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Row():
with gr.Column(scale=1): # Input column
smiles_input = gr.Textbox(
label="SMILES String",
placeholder="Enter SMILES string (e.g., CCO, c1ccccc1)",
lines=2,
)
with gr.Accordion("Decoding Options", open=False): # Options collapsible
decode_strategy = gr.Radio(
["Greedy", "Beam Search"],
label="Decoding Strategy",
value="Greedy",
info="Greedy is faster, Beam Search may be more accurate."
)
beam_width_slider = gr.Slider(
minimum=1,
maximum=20, # Keep max reasonable for performance
step=1,
value=5,
label="Beam Width",
info="Number of beams to explore (Beam Search only)",
visible=False # Initially hidden
)
num_seq_slider = gr.Slider(
minimum=1,
maximum=5, # Keep max reasonable
step=1,
value=1,
label="Number of Results",
info="How many sequences to return (Beam Search only)",
visible=False # Initially hidden
)
submit_btn = gr.Button("Translate", variant="primary")
# --- Logic to show/hide beam search options ---
def update_beam_options(strategy):
is_beam = strategy == "Beam Search"
return {
beam_width_slider: gr.update(visible=is_beam),
num_seq_slider: gr.update(visible=is_beam)
}
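    # Sketch of the toggle behaviour: selecting "Beam Search" returns
    # {beam_width_slider: gr.update(visible=True), num_seq_slider: gr.update(visible=True)},
    # so the two sliders appear; switching back to "Greedy" hides them again.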
decode_strategy.change(
fn=update_beam_options,
inputs=decode_strategy,
outputs=[beam_width_slider, num_seq_slider]
)
with gr.Column(scale=2): # Output column
output_text = gr.Textbox(
label="Translation Results",
lines=10, # More lines for potentially multiple results
show_copy_button=True,
# interactive=False # Output shouldn't be user-editable
)
# --- Define Event Listeners ---
submit_btn.click(
fn=predict_iupac,
inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
outputs=output_text,
api_name="translate_smiles"
)
# Trigger on Enter press in the SMILES box
smiles_input.submit(
fn=predict_iupac,
inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
outputs=output_text
)
# Add examples
gr.Examples(
examples=[
["CCO", "Greedy", 1, 1],
["c1ccccc1", "Greedy", 1, 1],
["CC(C)Br", "Beam Search", 5, 3],
["C[C@H](O)c1ccccc1", "Beam Search", 10, 5],
["INVALID_SMILES", "Greedy", 1, 1], # Example of invalid input
["N#CC(C)(C)OC(=O)C(C)=C", "Beam Search", 8, 2]
],
inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], # Match inputs order
outputs=output_text, # Output component
fn=predict_iupac, # Function to run for examples
cache_examples=False, # Caching might be tricky with model state
label="Example SMILES & Settings"
)
# --- Launch the App ---
if __name__ == "__main__":
iface.launch()