# app.py import gradio as gr import torch import torch.nn.functional as F # Needed for log_softmax in beam search import pytorch_lightning as pl import os import json import logging from tokenizers import Tokenizer from huggingface_hub import hf_hub_download import gc from rdkit.Chem import CanonSmiles, MolFromSmiles # Added MolFromSmiles for validation import spaces import heapq # For beam search priority queue import math # For log probabilities # --- Configuration --- MODEL_REPO_ID = ( "AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID ) CHECKPOINT_FILENAME = "last.ckpt" SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json" IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json" CONFIG_FILENAME = "config.json" # --- End Configuration --- # --- Logging --- logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) # --- Load Helper Code (Only Model Definition and Mask Function Needed) --- try: # Ensure enhanced_trainer.py is present in the repository root from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask logging.info("Successfully imported from enhanced_trainer.py.") except ImportError as e: logging.error( f"Failed to import helper code from enhanced_trainer.py: {e}. " f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'." ) raise gr.Error( f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}" ) except Exception as e: logging.error( f"An unexpected error occurred during helper code import: {e}", exc_info=True ) raise gr.Error( f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}" ) # --- Global Variables (Load Model Once) --- model: pl.LightningModule | None = None smiles_tokenizer: Tokenizer | None = None iupac_tokenizer: Tokenizer | None = None device: torch.device | None = None config: dict | None = None # --- Greedy Decoding Logic (Unchanged) --- def greedy_decode( model: pl.LightningModule, src: torch.Tensor, src_padding_mask: torch.Tensor, max_len: int, sos_idx: int, eos_idx: int, device: torch.device, ) -> torch.Tensor: """ Performs greedy decoding using the LightningModule's model. Returns a tensor of shape [1, sequence_length]. """ model.eval() transformer_model = model.model try: with torch.no_grad(): memory = transformer_model.encode(src, src_padding_mask) memory = memory.to(device) memory_key_padding_mask = src_padding_mask.to(memory.device) ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) for _ in range(max_len - 1): tgt_seq_len = ys.shape[1] tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device) decoder_output = transformer_model.decode( tgt=ys, memory=memory, tgt_mask=tgt_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask, ) next_token_logits = transformer_model.generator(decoder_output[:, -1, :]) next_word_id = torch.argmax(next_token_logits, dim=1).item() ys = torch.cat( [ ys, torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id), ], dim=1, ) if next_word_id == eos_idx: break return ys[:, 1:] # Exclude SOS except RuntimeError as e: logging.error(f"Runtime error during greedy decode: {e}", exc_info=True) if "CUDA out of memory" in str(e) and device.type == "cuda": gc.collect() torch.cuda.empty_cache() return torch.empty((1, 0), dtype=torch.long, device=device) except Exception as e: logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True) return torch.empty((1, 0), dtype=torch.long, device=device) # --- Beam Search Decoding Logic --- def beam_search_decode( model: pl.LightningModule, src: torch.Tensor, src_padding_mask: torch.Tensor, max_len: int, sos_idx: int, eos_idx: int, pad_idx: int, # Needed for padding shorter beams if batching device: torch.device, beam_width: int, num_return_sequences: int = 1, length_penalty_alpha: float = 0.6, # Add length penalty ) -> list[tuple[torch.Tensor, float]]: """ Performs beam search decoding. Returns a list of tuples: (sequence_tensor [1, seq_len], score) """ model.eval() transformer_model = model.model num_return_sequences = min(beam_width, num_return_sequences) # Cannot return more than beam width try: with torch.no_grad(): # --- Encode Source (Once) --- memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size] memory = memory.to(device) memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len] # --- Initialize Beams --- # Each beam: (sequence_tensor [1, current_len], score (log_prob)) initial_beam = ( torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx), 0.0, # Initial score (log probability) ) beams = [initial_beam] finished_hypotheses = [] # Store finished sequences: (score, sequence_tensor) # --- Decoding Loop --- for step in range(max_len - 1): if not beams: # Stop if no active beams left break # Use a min-heap to keep track of candidates for the *next* step # Store (-score, next_token_id, beam_index) - use negative score for max-heap behavior candidates = [] # Process current beams (can be batched for efficiency, but simpler loop shown here) # For batching: stack ys, expand memory, create masks for the batch for beam_idx, (current_seq, current_score) in enumerate(beams): if current_seq[0, -1].item() == eos_idx: # Beam already finished # Add length penalty before storing penalty = ((current_seq.shape[1]) ** length_penalty_alpha) final_score = current_score / penalty if penalty > 0 else current_score heapq.heappush(finished_hypotheses, (final_score, current_seq)) # Prune finished hypotheses if we have enough while len(finished_hypotheses) > beam_width: heapq.heappop(finished_hypotheses) # Remove lowest score continue # Don't expand finished beams # --- Prepare input for this beam --- ys = current_seq # [1, current_len] tgt_seq_len = ys.shape[1] tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device) # --- Decode one step --- # Note: memory and memory_key_padding_mask are reused decoder_output = transformer_model.decode( tgt=ys, memory=memory, # Needs expansion if batching beams tgt_mask=tgt_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask, # Needs expansion if batching ) # [1, current_len, emb_size] # Get logits for the *next* token next_token_logits = transformer_model.generator( decoder_output[:, -1, :] ) # [1, tgt_vocab_size] # Calculate log probabilities log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size] # Get top K candidates for *this* beam # Adding current_score makes it the total path score top_k_log_probs, top_k_indices = torch.topk(log_probs + current_score, beam_width, dim=1) # Add candidates to the list for selection across all beams for i in range(beam_width): token_id = top_k_indices[0, i].item() score = top_k_log_probs[0, i].item() # Store (-score, token_id, beam_idx) for heap heapq.heappush(candidates, (-score, token_id, beam_idx)) # Prune candidates heap if it exceeds beam_width * beam_width (can optimize) # A simpler pruning: keep only top N overall candidates later # --- Select Top K Beams for Next Step --- new_beams = [] # Ensure we don't exceed beam_width overall candidates num_candidates_to_consider = min(len(candidates), beam_width * len(beams)) # Rough upper bound # Use heap to efficiently get top k candidates overall top_candidates = heapq.nsmallest(beam_width, candidates) # Get k smallest (-score) -> largest score added_sequences = set() # Prevent duplicate sequences if paths converge for neg_score, token_id, beam_idx in top_candidates: original_seq, _ = beams[beam_idx] new_seq = torch.cat( [ original_seq, torch.ones(1, 1, dtype=torch.long, device=device).fill_(token_id), ], dim=1, ) # [1, current_len + 1] # Avoid adding duplicates (optional, but good practice) seq_tuple = tuple(new_seq.flatten().tolist()) if seq_tuple not in added_sequences: new_beams.append((new_seq, -neg_score)) # Store positive score added_sequences.add(seq_tuple) beams = new_beams # Update active beams # Early stopping: If top beam is finished and we have enough results if finished_hypotheses: # Check if the best possible score from active beams is worse than the worst finished beam best_active_score = -heapq.nsmallest(1, candidates)[0][0] if candidates else -float('inf') worst_finished_score = finished_hypotheses[0][0] # Smallest score in min-heap if len(finished_hypotheses) >= num_return_sequences and best_active_score < worst_finished_score: logging.debug(f"Beam search early stopping at step {step}") break # --- Final Selection --- # Add any remaining active beams to finished list (if they didn't end with EOS) for seq, score in beams: if seq[0, -1].item() != eos_idx: penalty = ((seq.shape[1]) ** length_penalty_alpha) final_score = score / penalty if penalty > 0 else score heapq.heappush(finished_hypotheses, (final_score, seq)) while len(finished_hypotheses) > beam_width: heapq.heappop(finished_hypotheses) # Sort finished hypotheses by score (descending) and select top N # heapq is min-heap, so nlargest gets the best scores top_hypotheses = heapq.nlargest(num_return_sequences, finished_hypotheses) # Return list of (sequence_tensor [1, seq_len], score) excluding SOS return [(seq[:, 1:], score) for score, seq in top_hypotheses] except RuntimeError as e: logging.error(f"Runtime error during beam search: {e}", exc_info=True) if "CUDA out of memory" in str(e) and device.type == "cuda": gc.collect() torch.cuda.empty_cache() return [] # Return empty list on error except Exception as e: logging.error(f"Unexpected error during beam search: {e}", exc_info=True) return [] # --- Translation Function (Handles both Greedy and Beam Search) --- def translate( model: pl.LightningModule, src_sentence: str, smiles_tokenizer: Tokenizer, iupac_tokenizer: Tokenizer, device: torch.device, max_len: int, sos_idx: int, eos_idx: int, pad_idx: int, decoding_strategy: str = "Greedy", beam_width: int = 5, num_return_sequences: int = 1, length_penalty_alpha: float = 0.6, ) -> list[tuple[str, float]]: # Returns list of (translation_string, score) """ Translates a single SMILES string using the specified decoding strategy. """ model.eval() # --- Tokenize Source --- try: smiles_tokenizer.enable_truncation(max_length=max_len) src_encoded = smiles_tokenizer.encode(src_sentence) if not src_encoded or not src_encoded.ids: logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}") return [("[Encoding Error]", 0.0)] src_ids = src_encoded.ids except Exception as e: logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True) return [("[Encoding Error]", 0.0)] # --- Prepare Input Tensor and Mask --- src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len] src_padding_mask = (src == pad_idx).to(device) # [1, src_len] # --- Perform Decoding --- generation_max_len = config.get("max_len", 256) results = [] # List to store (tensor, score) tuples if decoding_strategy == "Greedy": tgt_tokens_tensor = greedy_decode( model=model, src=src, src_padding_mask=src_padding_mask, max_len=generation_max_len, sos_idx=sos_idx, eos_idx=eos_idx, device=device, ) # Returns tensor [1, generated_len] if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0: results = [(tgt_tokens_tensor, 0.0)] # Assign dummy score 0.0 for greedy else: logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}") return [("[Decoding Error - Empty Output]", 0.0)] elif decoding_strategy == "Beam Search": results = beam_search_decode( model=model, src=src, src_padding_mask=src_padding_mask, max_len=generation_max_len, sos_idx=sos_idx, eos_idx=eos_idx, pad_idx=pad_idx, device=device, beam_width=beam_width, num_return_sequences=num_return_sequences, length_penalty_alpha=length_penalty_alpha, ) # Returns list of (tensor, score) if not results: logging.warning(f"Beam search returned no results for SMILES: {src_sentence}") return [("[Decoding Error - Empty Output]", 0.0)] else: logging.error(f"Unknown decoding strategy: {decoding_strategy}") return [("[Error: Unknown Strategy]", 0.0)] # --- Decode Generated Tokens --- translations = [] for tgt_tokens_tensor, score in results: if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0: translations.append(("[Decoding Error - Empty Sequence]", score)) continue tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist() try: # Decode using the target tokenizer, skipping special tokens translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True) translations.append((translation, score)) except Exception as e: logging.error( f"Error decoding target tokens {tgt_tokens}: {e}", exc_info=True, ) translations.append(("[Decoding Error]", score)) return translations # --- Model/Tokenizer Loading Function (Unchanged from previous version) --- def load_model_and_tokenizers(): """Loads tokenizers, config, and model from Hugging Face Hub.""" global model, smiles_tokenizer, iupac_tokenizer, device, config if model is not None: logging.info("Model and tokenizers already loaded.") return logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...") try: # Determine device (Force CPU for stability in typical Space envs, uncomment cuda if needed) # if torch.cuda.is_available(): # device = torch.device("cuda") # logging.info("CUDA available, using GPU.") # else: device = torch.device("cpu") logging.info("Using CPU. Modify code to enable GPU if available and desired.") # Download files logging.info("Downloading files from Hugging Face Hub...") cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache") os.makedirs(cache_dir, exist_ok=True) logging.info(f"Using cache directory: {cache_dir}") try: checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir) smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir) iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir) config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir) # Ensure enhanced_trainer.py is downloaded or present try: hf_hub_download(repo_id=MODEL_REPO_ID, filename="enhanced_trainer.py", cache_dir=cache_dir, local_dir=".") # Download to current dir logging.info("Downloaded enhanced_trainer.py") except Exception as download_err: if os.path.exists("enhanced_trainer.py"): logging.warning(f"Could not download enhanced_trainer.py (maybe private?), but found local file. Using local. Error: {download_err}") else: raise download_err # Re-raise if not found locally either logging.info("Files downloaded successfully.") except Exception as e: logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True) raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}") # Load config logging.info("Loading configuration...") try: with open(config_path, "r") as f: config = json.load(f) logging.info("Configuration loaded.") required_keys = ["src_vocab_size", "tgt_vocab_size", "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers", "num_decoder_layers", "dropout", "max_len", "pad_token_id", "bos_token_id", "eos_token_id"] missing_keys = [key for key in required_keys if config.get(key) is None] if missing_keys: raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}.") logging.info(f"Using config: { {k: config.get(k) for k in required_keys} }") # Log key values except Exception as e: logging.error(f"Error loading or validating config: {e}", exc_info=True) raise gr.Error(f"Config Error: {e}") # Load tokenizers logging.info("Loading tokenizers...") try: smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path) iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path) # Basic validation (can add more checks as before) if smiles_tokenizer.get_vocab_size() != config['src_vocab_size']: logging.warning(f"SMILES vocab size mismatch: Tokenizer={smiles_tokenizer.get_vocab_size()}, Config={config['src_vocab_size']}") if iupac_tokenizer.get_vocab_size() != config['tgt_vocab_size']: logging.warning(f"IUPAC vocab size mismatch: Tokenizer={iupac_tokenizer.get_vocab_size()}, Config={config['tgt_vocab_size']}") logging.info("Tokenizers loaded.") except Exception as e: logging.error(f"Failed to load tokenizers: {e}", exc_info=True) raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}") # Load model logging.info("Loading model from checkpoint...") try: # Ensure config keys match expected arguments of SmilesIupacLitModule.__init__ # Map config keys if necessary, e.g., if config uses 'vocab_size_src' but class expects 'src_vocab_size' model_hparams = config.copy() # Start with all config params # Example remapping (adjust if your config/class names differ): # model_hparams['src_vocab_size'] = model_hparams.pop('vocab_size_src', config['src_vocab_size']) # model_hparams['tgt_vocab_size'] = model_hparams.pop('vocab_size_tgt', config['tgt_vocab_size']) # model_hparams['bos_idx'] = model_hparams.pop('bos_token_id', config['bos_token_id']) # model_hparams['eos_idx'] = model_hparams.pop('eos_token_id', config['eos_token_id']) # model_hparams['padding_idx'] = model_hparams.pop('pad_token_id', config['pad_token_id']) # Remove keys from hparams that are not expected by the LitModule's __init__ # This depends on the exact signature of SmilesIupacLitModule # Common ones to potentially remove if not direct args: max_len (often used elsewhere) # Check the __init__ signature of SmilesIupacLitModule in enhanced_trainer.py expected_args = SmilesIupacLitModule.__init__.__code__.co_varnames hparams_to_pass = {k: v for k, v in model_hparams.items() if k in expected_args} logging.info(f"Passing hparams to LitModule: {hparams_to_pass.keys()}") model = SmilesIupacLitModule.load_from_checkpoint( checkpoint_path, map_location=device, # devices=1, # Often not needed for inference loading strict=False, # Set to False initially if encountering key errors **hparams_to_pass # Pass relevant hparams from config ) model.to(device) model.eval() model.freeze() logging.info(f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'.") except FileNotFoundError: logging.error(f"Checkpoint file not found: {checkpoint_path}") raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.") except Exception as e: logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True) if "size mismatch" in str(e): error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint." logging.error(error_detail) raise gr.Error(f"Model Error: {error_detail} Original error: {e}") elif "unexpected keyword argument" in str(e) or "missing 1 required positional argument" in str(e): error_detail = f"Mismatch between config.json keys and SmilesIupacLitModule constructor arguments. Check enhanced_trainer.py and config.json. Error: {e}" logging.error(error_detail) raise gr.Error(f"Model Error: {error_detail}") elif "memory" in str(e).lower(): logging.warning("Potential OOM error during model loading.") gc.collect() torch.cuda.empty_cache() if device.type == "cuda" else None raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}") else: raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}") except gr.Error: raise # Propagate Gradio errors directly except Exception as e: logging.error(f"Unexpected error during loading: {e}", exc_info=True) raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}") # --- Inference Function for Gradio --- @spaces.GPU # Uncomment if using GPU and have appropriate hardware tier def predict_iupac(smiles_string, decoding_strategy, num_beams, num_return_sequences): """ Performs SMILES to IUPAC translation using the loaded model and selected strategy. """ global model, smiles_tokenizer, iupac_tokenizer, device, config # --- Input Validation --- if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]): error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs." logging.error(error_msg) return f"Initialization Error: {error_msg}" if not smiles_string or not smiles_string.strip(): return "Error: Please enter a valid SMILES string." smiles_input = smiles_string.strip() # Validate SMILES using RDKit try: mol = MolFromSmiles(smiles_input) if mol is None: return f"Error: Invalid SMILES string provided: '{smiles_input}'" smiles_input = CanonSmiles(smiles_input) # Use canonical form logging.info(f"Canonical SMILES: {smiles_input}") except Exception as e: logging.error(f"Error during SMILES validation/canonicalization: {e}", exc_info=True) return f"Error: Could not process SMILES string '{smiles_input}'. RDKit error: {e}" # Validate beam search parameters if decoding_strategy == "Beam Search": if not isinstance(num_beams, int) or num_beams <= 0: return "Error: Beam width must be a positive integer." if not isinstance(num_return_sequences, int) or num_return_sequences <= 0: return "Error: Number of return sequences must be a positive integer." if num_return_sequences > num_beams: return f"Error: Number of return sequences ({num_return_sequences}) cannot exceed beam width ({num_beams})." else: # Ensure defaults are used for greedy num_beams = 1 num_return_sequences = 1 try: # --- Call the core translation logic --- sos_idx = config["bos_token_id"] eos_idx = config["eos_token_id"] pad_idx = config["pad_token_id"] gen_max_len = config["max_len"] # Use fixed length penalty for now, could be another slider length_penalty = 0.6 predicted_results = translate( # Returns list of (name, score) model=model, src_sentence=smiles_input, smiles_tokenizer=smiles_tokenizer, iupac_tokenizer=iupac_tokenizer, device=device, max_len=gen_max_len, sos_idx=sos_idx, eos_idx=eos_idx, pad_idx=pad_idx, decoding_strategy=decoding_strategy, beam_width=num_beams, num_return_sequences=num_return_sequences, length_penalty_alpha=length_penalty, ) logging.info(f"Prediction returned {len(predicted_results)} result(s). Strategy: {decoding_strategy}, Beams: {num_beams}, Return: {num_return_sequences}") # --- Format Output --- output_lines = [] output_lines.append(f"Input SMILES: {smiles_input}") output_lines.append(f"Decoding Strategy: {decoding_strategy}") if decoding_strategy == "Beam Search": output_lines.append(f"Beam Width: {num_beams}") output_lines.append(f"Returned Sequences: {len(predicted_results)}") output_lines.append(f"Length Penalty Alpha: {length_penalty:.2f}") output_lines.append("\n--- Predictions ---") if not predicted_results: output_lines.append("No predictions generated.") else: for i, (name, score) in enumerate(predicted_results): if "[Error]" in name or not name: output_lines.append(f"{i+1}. Prediction Failed: {name}") else: score_info = f"(Score: {score:.4f})" if decoding_strategy == "Beam Search" else "" output_lines.append(f"{i+1}. {name} {score_info}") return "\n".join(output_lines) except RuntimeError as e: logging.error(f"Runtime error during translation: {e}", exc_info=True) gc.collect() if device.type == 'cuda': torch.cuda.empty_cache() return f"Runtime Error during translation: {e}. Check logs." except Exception as e: logging.error(f"Unexpected error during translation: {e}", exc_info=True) return f"Unexpected Error during translation: {e}. Check logs." # --- Load Model on App Start --- try: load_model_and_tokenizers() except gr.Error as ge: # Log the Gradio error but allow interface to load potentially showing the error message logging.error(f"Gradio Initialization Error during load: {ge}") # Display error in the UI if possible? Hard to do before UI is built. # We rely on the predict function checking for loaded components. except Exception as e: logging.error(f"Critical error during initial model loading: {e}", exc_info=True) # This might prevent the app from starting correctly. # --- Create Gradio Interface --- title = "SMILES to IUPAC Name Translator" description = f""" Translate a SMILES string into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}). Choose between **Greedy Decoding** (fastest, picks the most likely next word) and **Beam Search Decoding** (explores multiple possibilities, potentially better results, slower). **Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Beam search can be slow, especially with larger beam widths. Check `config.json` in the repo for model details. SMILES input will be canonicalized using RDKit. """ # Use gr.Blocks for more layout control with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface: gr.Markdown(f"# {title}") gr.Markdown(description) with gr.Row(): with gr.Column(scale=1): # Input column smiles_input = gr.Textbox( label="SMILES String", placeholder="Enter SMILES string (e.g., CCO, c1ccccc1)", lines=2, ) with gr.Accordion("Decoding Options", open=False): # Options collapsible decode_strategy = gr.Radio( ["Greedy", "Beam Search"], label="Decoding Strategy", value="Greedy", info="Greedy is faster, Beam Search may be more accurate." ) beam_width_slider = gr.Slider( minimum=1, maximum=20, # Keep max reasonable for performance step=1, value=5, label="Beam Width", info="Number of beams to explore (Beam Search only)", visible=False # Initially hidden ) num_seq_slider = gr.Slider( minimum=1, maximum=5, # Keep max reasonable step=1, value=1, label="Number of Results", info="How many sequences to return (Beam Search only)", visible=False # Initially hidden ) submit_btn = gr.Button("Translate", variant="primary") # --- Logic to show/hide beam search options --- def update_beam_options(strategy): is_beam = strategy == "Beam Search" return { beam_width_slider: gr.update(visible=is_beam), num_seq_slider: gr.update(visible=is_beam) } decode_strategy.change( fn=update_beam_options, inputs=decode_strategy, outputs=[beam_width_slider, num_seq_slider] ) with gr.Column(scale=2): # Output column output_text = gr.Textbox( label="Translation Results", lines=10, # More lines for potentially multiple results show_copy_button=True, # interactive=False # Output shouldn't be user-editable ) # --- Define Event Listeners --- submit_btn.click( fn=predict_iupac, inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], outputs=output_text, api_name="translate_smiles" ) # Trigger on Enter press in the SMILES box smiles_input.submit( fn=predict_iupac, inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], outputs=output_text ) # Add examples gr.Examples( examples=[ ["CCO", "Greedy", 1, 1], ["c1ccccc1", "Greedy", 1, 1], ["CC(C)Br", "Beam Search", 5, 3], ["C[C@H](O)c1ccccc1", "Beam Search", 10, 5], ["INVALID_SMILES", "Greedy", 1, 1], # Example of invalid input ["N#CC(C)(C)OC(=O)C(C)=C", "Beam Search", 8, 2] ], inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], # Match inputs order outputs=output_text, # Output component fn=predict_iupac, # Function to run for examples cache_examples=False, # Caching might be tricky with model state label="Example SMILES & Settings" ) # --- Launch the App --- if __name__ == "__main__": iface.launch()