# app.py
import gradio as gr
import torch
import torch.nn.functional as F  # <--- Added import
import pytorch_lightning as pl  # <--- Added import (needed for type hints, model access)
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc  # For garbage collection on potential OOM
import math  # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)

# --- Configuration ---
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"
# --- End Configuration ---

# --- Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Load Helper Code (Only Model Definition Needed) ---
try:
    # We only need the LightningModule definition and the mask function now
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")
    # We define beam_search_decode and translate locally in this file
    # REMOVED: from test_ckpt import beam_search_decode, translate
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    gr.Error(f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}")
    exit()
except Exception as e:
    logging.error(f"An unexpected error occurred during helper code import: {e}", exc_info=True)
    gr.Error(f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}")
    exit()

# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None  # Added type hint
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None


# --- Beam Search Decoding Logic (Moved from test_ckpt.py) ---
def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,  # Needed for padding mask check if src has padding
    device: torch.device,
    beam_width: int = 5,
    n_best: int = 5,  # Number of top sequences to return
    length_penalty: float = 0.6,  # Alpha for length normalization (0 = no penalty, 1 = full penalty)
) -> list[torch.Tensor]:
    """Performs beam search decoding using the LightningModule's model.
    (Code copied and pasted from test_ckpt.py.)
    """
    # Ensure model is in eval mode (redundant if called after model.eval(), but safe)
    model.eval()
    transformer_model = model.model  # Access the underlying Seq2SeqTransformer
    n_best = min(n_best, beam_width)  # Cannot return more than beam_width sequences

    try:
        with torch.no_grad():
            # --- Encode Source ---
            memory = transformer_model.encode(src, src_padding_mask)  # [1, src_len, emb_size]
            memory = memory.to(device)
            # Ensure memory_key_padding_mask is also on the correct device for decode
            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]

            # --- Initialize Beams ---
            initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx)  # [1, 1]
            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device)  # [1]
            active_beams = [(initial_beam_seq, initial_beam_score)]
            finished_beams = []

            # --- Decoding Loop ---
            for step in range(max_len - 1):
                if not active_beams:
                    break

                potential_next_beams = []
                for current_seq, current_score in active_beams:
                    if current_seq[0, -1].item() == eos_idx:
                        finished_beams.append((current_seq, current_score))
                        continue

                    tgt_input = current_seq  # [1, current_len]
                    tgt_seq_len = tgt_input.shape[1]
                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device)  # [curr_len, curr_len]
                    tgt_padding_mask = torch.zeros(tgt_input.shape, dtype=torch.bool, device=device)  # [1, curr_len]

                    decoder_output = transformer_model.decode(
                        tgt=tgt_input,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )  # [1, curr_len, emb_size]
                    next_token_logits = transformer_model.generator(decoder_output[:, -1, :])  # [1, tgt_vocab_size]
                    log_probs = F.log_softmax(next_token_logits, dim=-1)  # [1, tgt_vocab_size]

                    topk_log_probs, topk_indices = torch.topk(log_probs + current_score, beam_width, dim=-1)

                    for i in range(beam_width):
                        next_token_id = topk_indices[0, i].item()
                        next_score = topk_log_probs[0, i].reshape(1)  # Keep as tensor [1]
                        next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=device)  # [1, 1]
                        new_seq = torch.cat([current_seq, next_token_tensor], dim=1)  # [1, current_len + 1]
                        potential_next_beams.append((new_seq, next_score))

                potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)

                active_beams = []
                added_count = 0
                for seq, score in potential_next_beams:
                    is_finished = seq[0, -1].item() == eos_idx
                    if is_finished:
                        finished_beams.append((seq, score))
                    elif added_count < beam_width:
                        active_beams.append((seq, score))
                        added_count += 1
                    elif added_count >= beam_width:
                        break

            finished_beams.extend(active_beams)

            # Apply length penalty and sort
            # Handle potential division by zero if sequence length is 0 or 1
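            # Length normalization used below (assumption: this matches the intended
            # GNMT-style convention of the original script):
            #   normalized_score = summed_log_prob / (seq_len ** length_penalty)
            # With length_penalty == 0 the raw summed log-probability is compared directly.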
            def get_score(beam_tuple):
                seq, score = beam_tuple
                seq_len = seq.shape[1]
                if length_penalty == 0.0 or seq_len <= 1:
                    return score.item()
                else:
                    # Ensure seq_len is float for pow
                    return score.item() / (float(seq_len) ** length_penalty)

            finished_beams.sort(key=get_score, reverse=True)  # Higher score is better

            top_sequences = [seq[:, 1:] for seq, score in finished_beams[:n_best]]  # seq shape [1, len] -> [1, len-1]
            return top_sequences

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}")
        if "CUDA out of memory" in str(e):
            gc.collect()
            torch.cuda.empty_cache()
        return []  # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []


# --- Translation Function (Moved from test_ckpt.py) ---
def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[str]:
    """Translates a single SMILES string using beam search.

    (Code copied and pasted from test_ckpt.py.)
    """
    model.eval()  # Ensure model is in eval mode
    translations = []

    # --- Tokenize Source ---
    try:
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return ["[Encoding Error]"] * n_best
        src_ids = src_encoded.ids[:max_len]  # Truncate source
        if not src_ids:
            logging.warning(f"Source empty after truncation: {src_sentence}")
            return ["[Encoding Error - Empty Src]"] * n_best
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
        return ["[Encoding Error]"] * n_best

    # --- Prepare Input Tensor and Mask ---
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)  # [1, src_len]
    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]

    # --- Perform Beam Search Decoding ---
    # Calls the beam_search_decode function defined above in this file
    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=max_len,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
        device=device,
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty,
    )  # Returns list of tensors

    # --- Decode Generated Tokens ---
    if not tgt_tokens_list:
        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
        return ["[Decoding Error - Empty Output]"] * n_best

    for tgt_tokens_tensor in tgt_tokens_list:
        if tgt_tokens_tensor.numel() > 0:
            tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
            try:
                translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
                translations.append(translation)
            except Exception as e:
                logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
                translations.append("[Decoding Error]")
        else:
            translations.append("[Decoding Error - Empty Tensor]")

    # Pad with error messages if fewer than n_best results were generated
    while len(translations) < n_best:
        translations.append("[Decoding Error - Fewer Results]")

    return translations
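# Example usage of translate() (illustrative only -- assumes load_model_and_tokenizers()
# below has populated the module-level globals; the max_len and special-token ids shown
# here are placeholders, the real values come from config.json):
#
#   names = translate(
#       model=model, src_sentence="CCO",
#       smiles_tokenizer=smiles_tokenizer, iupac_tokenizer=iupac_tokenizer,
#       device=device, max_len=256, sos_idx=1, eos_idx=2, pad_idx=0,
#       beam_width=5, n_best=3, length_penalty=0.6,
#   )
#   # names -> list of up to 3 candidate IUPAC name strings, best-scoring first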
# --- Model/Tokenizer Loading Function (Unchanged) ---
def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if model is not None:  # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        device = torch.device("cpu")
        logging.info(f"Using device: {device}")

        # Download files from HF Hub
        logging.info("Downloading files from Hugging Face Hub...")
        try:
            checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
            smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME)
            iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME)
            config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME)
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
            raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")

        # Load config
        logging.info("Loading configuration...")
        try:
            with open(config_path, 'r') as f:
                config = json.load(f)
            logging.info("Configuration loaded.")

            # --- Validate essential config keys ---
            required_keys = [
                'src_vocab_size', 'tgt_vocab_size', 'emb_size', 'nhead', 'ffn_hid_dim',
                'num_encoder_layers', 'num_decoder_layers', 'dropout', 'max_len',
                'bos_token_id', 'eos_token_id', 'pad_token_id',
            ]
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}")
            # --- End Validation ---
        except FileNotFoundError:
            logging.error(f"Config file not found locally after download attempt: {config_path}")
            raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
        except ValueError as e:
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}")

        # Load tokenizers
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")

            # --- Validate Tokenizer Special Tokens ---
            # NOTE: the special-token strings below ("<pad>", "<sos>", "<eos>", "<unk>")
            # are assumed names; adjust them to match the tokenizer files if they differ.
            # Add more robust checks if necessary.
            if smiles_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
               smiles_tokenizer.token_to_id("<unk>") is None:
                logging.warning("SMILES tokenizer special tokens might not match config or are missing.")
            if iupac_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
               iupac_tokenizer.token_to_id("<sos>") != config['bos_token_id'] or \
               iupac_tokenizer.token_to_id("<eos>") != config['eos_token_id'] or \
               iupac_tokenizer.token_to_id("<unk>") is None:
                logging.warning("IUPAC tokenizer special tokens might not match config or are missing.")
            # --- End Validation ---
        except Exception as e:
            logging.error(f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}", exc_info=True)
            raise gr.Error(f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}")
        # Load model
        logging.info("Loading model from checkpoint...")
        try:
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                src_vocab_size=config['src_vocab_size'],
                tgt_vocab_size=config['tgt_vocab_size'],
                map_location=device,
                hparams_dict=config,
                strict=False,
                device="cpu",
            )
            model.to(device)
            model.eval()
            model.freeze()
            logging.info("Model loaded successfully, set to eval mode, frozen, and moved to device.")
        except FileNotFoundError:
            logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
            raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
        except Exception as e:
            logging.error(f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True)
            if "memory" in str(e).lower():
                gc.collect()
                if device == torch.device("cuda"):
                    torch.cuda.empty_cache()
            raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")

    except gr.Error:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
        raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")


# --- Inference Function for Gradio (Unchanged, calls local translate) ---
def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
    """Performs SMILES to IUPAC translation using the loaded model and beam search."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
        # Ensure n_best is int for range, default to 1 if conversion fails early
        try:
            n_best_int = int(n_best)
        except (TypeError, ValueError):
            n_best_int = 1
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        try:
            n_best_int = int(n_best)
        except (TypeError, ValueError):
            n_best_int = 1
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])

    smiles_input = smiles_string.strip()

    try:
        beam_width = int(beam_width)
        n_best = int(n_best)
        length_penalty = float(length_penalty)
    except ValueError as e:
        error_msg = f"Error: Invalid input parameter type ({e})."
        return f"1. {error_msg}"  # Cannot determine n_best here

    logging.info(f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})")

    try:
        # Calls the translate function defined *above in this file*
        predicted_names = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=config['max_len'],
            sos_idx=config['bos_token_id'],
            eos_idx=config['eos_token_id'],
            pad_idx=config['pad_token_id'],
            beam_width=beam_width,
            n_best=n_best,
            length_penalty=length_penalty,
        )
        logging.info(f"Predictions returned: {predicted_names}")

        if not predicted_names:
            output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
        else:
            output_text = (
                f"Input SMILES: {smiles_input}\n\n"
                f"Top {len(predicted_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
            )
            output_text += "\n".join([f"{i+1}. {name}" for i, name in enumerate(predicted_names)])
        return output_text

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            gc.collect()
            if device == torch.device("cuda"):
                torch.cuda.empty_cache()
            error_msg += " (Potential OOM)"
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        error_msg = f"Unexpected Error during translation: {e}"
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])


# --- Load Model on App Start (Unchanged) ---
try:
    load_model_and_tokenizers()
except gr.Error:
    pass  # Error already raised for Gradio UI
except Exception as e:
    logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
    gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")

# --- Create Gradio Interface (Unchanged) ---
title = "SMILES to IUPAC Name Translator"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
Model repository: {MODEL_REPO_ID}.
Adjust the beam search parameters below. A higher beam width explores more possibilities but is slower.
The length penalty influences the preference for shorter or longer names.
"""

examples = [
    ["CCO", 5, 3, 0.6],
    ["C1=CC=CC=C1", 5, 3, 0.6],
    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6],  # Aspirin
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6],  # Ibuprofen
    ["CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C", 5, 1, 0.6],  # Complex example
    ["INVALID_SMILES", 5, 1, 0.6],
]

smiles_input = gr.Textbox(
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
    lines=1,
)
beam_width_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=5,
    step=1,
    label="Beam Width (k)",
    info="Number of sequences kept at each decoding step (higher = more exploration, slower).",
)
n_best_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Number of Results (n_best)",
    info="How many top-scoring sequences to return (must be <= Beam Width).",
)
length_penalty_input = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.6,
    step=0.1,
    label="Length Penalty (alpha)",
    info="Length-normalization exponent: 0 disables the penalty; larger values favor longer names.",
)
output_text = gr.Textbox(
    label="Predicted IUPAC Name(s)",
    lines=5,
    show_copy_button=True,
)

iface = gr.Interface(
    fn=predict_iupac,
    inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
    outputs=output_text,
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
    article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results.",
)

# --- Launch the App (Unchanged) ---
if __name__ == "__main__":
    iface.launch(share=True)
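# Quick manual check (illustrative only; assumes the checkpoint and tokenizers download
# successfully -- run from a Python shell rather than launching the UI):
#
#   >>> from app import predict_iupac
#   >>> print(predict_iupac("CCO", beam_width=5, n_best=3, length_penalty=0.6))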