# app.py
import gradio as gr
import torch
import torch.nn.functional as F # <--- Added import
import pytorch_lightning as pl   # <--- Added import (needed for type hints, model access)
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc # For garbage collection on potential OOM
import math # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)

# --- Configuration ---
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"
# --- End Configuration ---

# --- Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Load Helper Code (Only Model Definition Needed) ---
try:
    # We only need the LightningModule definition and the mask function now
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
    logging.info("Successfully imported from enhanced_trainer.py.")

    # We will define beam_search_decode and translate locally in this file
    # REMOVED: from test_ckpt import beam_search_decode, translate

except ImportError as e:
    logging.error(f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'.")
    gr.Error(f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}")
    exit()
except Exception as e:
    logging.error(f"An unexpected error occurred during helper code import: {e}", exc_info=True)
    gr.Error(f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}")
    exit()

# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None # Added type hint
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None

# --- Beam Search Decoding Logic (Moved from test_ckpt.py) ---

def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int, # Needed for padding mask check if src has padding
    device: torch.device,
    beam_width: int = 5,
    n_best: int = 5, # Number of top sequences to return
    length_penalty: float = 0.6 # Alpha for length normalization (0=no penalty, 1=full penalty)
) -> list[torch.Tensor]:
    """
    Performs beam search decoding using the LightningModule's model.
    (Code copied and pasted from test_ckpt.py)
    """
    # Ensure model is in eval mode (redundant if called after model.eval(), but safe)
    model.eval()
    transformer_model = model.model # Access the underlying Seq2SeqTransformer
    n_best = min(n_best, beam_width) # Cannot return more than beam_width sequences

    try:
        with torch.no_grad():
            # --- Encode Source ---
            memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size]
            memory = memory.to(device)
            # Ensure memory_key_padding_mask is also on the correct device for decode
            memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]

            # --- Initialize Beams ---
            initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) # [1, 1]
            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
            active_beams = [(initial_beam_seq, initial_beam_score)]
            finished_beams = []

            # --- Decoding Loop ---
            for step in range(max_len - 1):
                if not active_beams:
                    break

                potential_next_beams = []
                for current_seq, current_score in active_beams:
                    if current_seq[0, -1].item() == eos_idx:
                        finished_beams.append((current_seq, current_score))
                        continue

                    tgt_input = current_seq # [1, current_len]
                    tgt_seq_len = tgt_input.shape[1]
                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) # [curr_len, curr_len]
                    tgt_padding_mask = torch.zeros(tgt_input.shape, dtype=torch.bool, device=device) # [1, curr_len]

                    decoder_output = transformer_model.decode(
                        tgt=tgt_input,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask
                    ) # [1, curr_len, emb_size]

                    next_token_logits = transformer_model.generator(decoder_output[:, -1, :]) # [1, tgt_vocab_size]
                    log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size]

                    topk_log_probs, topk_indices = torch.topk(log_probs + current_score, beam_width, dim=-1)
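                    # Adding current_score broadcasts the running beam score onto every
                    # vocabulary log-prob, so topk selects the best *cumulative* continuations.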

                    for i in range(beam_width):
                        next_token_id = topk_indices[0, i].item()
                        next_score = topk_log_probs[0, i].reshape(1) # Keep as tensor [1]
                        next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=device) # [1, 1]
                        new_seq = torch.cat([current_seq, next_token_tensor], dim=1) # [1, current_len + 1]
                        potential_next_beams.append((new_seq, next_score))

                potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)

                active_beams = []
                added_count = 0
                for seq, score in potential_next_beams:
                    is_finished = seq[0, -1].item() == eos_idx
                    if is_finished:
                        finished_beams.append((seq, score))
                    elif added_count < beam_width:
                        active_beams.append((seq, score))
                        added_count += 1
                    elif added_count >= beam_width:
                        break

            finished_beams.extend(active_beams)

            # Apply length normalization and sort; skip the penalty for
            # degenerate lengths (<= 1), where scaling is meaningless.
            def get_score(beam_tuple):
                seq, score = beam_tuple
                seq_len = seq.shape[1]
                if length_penalty == 0.0 or seq_len <= 1:
                    return score.item()
                else:
                    # Ensure seq_len is float for pow
                    return score.item() / (float(seq_len) ** length_penalty)

            finished_beams.sort(key=get_score, reverse=True) # Higher score is better

            top_sequences = [seq[:, 1:] for seq, score in finished_beams[:n_best]] # seq shape [1, len] -> [1, len-1]
            return top_sequences

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}")
        if "CUDA out of memory" in str(e):
            gc.collect(); torch.cuda.empty_cache()
        return [] # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []
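
# Ranking note (illustrative numbers only, not from the model): finished beams are
# sorted by length normalization, score / len(seq) ** length_penalty.
# With length_penalty = 0.6, a 10-token hypothesis with total log-prob -6.0 scores
# -6.0 / 10**0.6 ≈ -1.51, while a 5-token hypothesis at -4.0 scores
# -4.0 / 5**0.6 ≈ -1.52, so the longer hypothesis ranks first despite its lower
# raw log-probability.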

# --- Translation Function (Moved from test_ckpt.py) ---

def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6
) -> list[str]:
    """
    Translates a single SMILES string using beam search.
    (Code copied and pasted from test_ckpt.py)
    """
    model.eval() # Ensure model is in eval mode
    translations = []

    # --- Tokenize Source ---
    try:
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return ["[Encoding Error]"] * n_best
        src_ids = src_encoded.ids[:max_len] # Truncate source
        if not src_ids:
            logging.warning(f"Source empty after truncation: {src_sentence}")
            return ["[Encoding Error - Empty Src]"] * n_best
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
        return ["[Encoding Error]"] * n_best

    # --- Prepare Input Tensor and Mask ---
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len]
    src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
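    # Note: src is a single unpadded sequence here, so this mask is all False;
    # it is built anyway because encode()/decode() expect a padding mask argument.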

    # --- Perform Beam Search Decoding ---
    # Calls the beam_search_decode function defined above in this file
    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=max_len,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
        device=device,
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty
    ) # Returns list of tensors

    # --- Decode Generated Tokens ---
    if not tgt_tokens_list:
        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
        return ["[Decoding Error - Empty Output]"] * n_best

    for tgt_tokens_tensor in tgt_tokens_list:
        if tgt_tokens_tensor.numel() > 0:
            tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
            try:
                translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
                translations.append(translation)
            except Exception as e:
                logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
                translations.append("[Decoding Error]")
        else:
            translations.append("[Decoding Error - Empty Tensor]")

    # Pad with error messages if fewer than n_best results were generated
    while len(translations) < n_best:
        translations.append("[Decoding Error - Fewer Results]")

    return translations
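
# Minimal usage sketch (illustrative; assumes load_model_and_tokenizers() below has
# already populated the module-level globals):
#
#   names = translate(
#       model=model, src_sentence="CCO",
#       smiles_tokenizer=smiles_tokenizer, iupac_tokenizer=iupac_tokenizer,
#       device=device, max_len=config["max_len"],
#       sos_idx=config["bos_token_id"], eos_idx=config["eos_token_id"],
#       pad_idx=config["pad_token_id"], beam_width=5, n_best=3,
#   )
#   # names -> a list of n_best candidate IUPAC strings; placeholder entries such as
#   # "[Decoding Error]" mark beams that failed to decode.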


# --- Model/Tokenizer Loading Function (Unchanged) ---
def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None: # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        device = torch.device("cpu")
        logging.info(f"Using device: {device}")

        # Download files from HF Hub
        logging.info("Downloading files from Hugging Face Hub...")
        try:
            checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
            smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME)
            iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME)
            config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME)
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
            raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")

        # Load config
        logging.info("Loading configuration...")
        try:
            with open(config_path, 'r') as f:
                config = json.load(f)
            logging.info("Configuration loaded.")
            # --- Validate essential config keys ---
            required_keys = [
                'src_vocab_size', 'tgt_vocab_size', 'emb_size', 'nhead',
                'ffn_hid_dim', 'num_encoder_layers', 'num_decoder_layers',
                'dropout', 'max_len', 'bos_token_id', 'eos_token_id', 'pad_token_id'
            ]
            missing_keys = [key for key in required_keys if key not in config]
            if missing_keys:
                raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}")
            # --- End Validation ---
        except FileNotFoundError:
            logging.error(f"Config file not found locally after download attempt: {config_path}")
            raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
        except ValueError as e:
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}")


        # Load tokenizers
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")
            # --- Validate Tokenizer Special Tokens ---
            # Add more robust checks if necessary
            if smiles_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
               smiles_tokenizer.token_to_id("<unk>") is None:
                 logging.warning("SMILES tokenizer special tokens might not match config or are missing.")
            if iupac_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
               iupac_tokenizer.token_to_id("<sos>") != config['bos_token_id'] or \
               iupac_tokenizer.token_to_id("<eos>") != config['eos_token_id'] or \
               iupac_tokenizer.token_to_id("<unk>") is None:
                 logging.warning("IUPAC tokenizer special tokens might not match config or are missing.")
            # --- End Validation ---
        except Exception as e:
            logging.error(f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}", exc_info=True)
            raise gr.Error(f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}")

        # Load model
        logging.info("Loading model from checkpoint...")
        try:
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                src_vocab_size=config['src_vocab_size'],
                tgt_vocab_size=config['tgt_vocab_size'],
                map_location=device,
                hparams_dict=config,
                strict=False,
                device="cpu"
            )
            model.to(device)
            model.eval()
            model.freeze()
            logging.info("Model loaded successfully, set to eval mode, frozen, and moved to device.")

        except FileNotFoundError:
            logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
            raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
        except Exception as e:
            logging.error(f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True)
            if "memory" in str(e).lower():
                gc.collect()
                if device == torch.device("cuda"):
                    torch.cuda.empty_cache()
            raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")

    except gr.Error:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
        raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")


# --- Inference Function for Gradio (Unchanged, calls local translate) ---
def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
    """
    Performs SMILES to IUPAC translation using the loaded model and beam search.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
        # Ensure n_best is int for range, default to 1 if conversion fails early
        try:
            n_best_int = int(n_best)
        except (TypeError, ValueError):
            n_best_int = 1
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        try:
            n_best_int = int(n_best)
        except (TypeError, ValueError):
            n_best_int = 1
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])

    smiles_input = smiles_string.strip()
    try:
        beam_width = int(beam_width)
        n_best = int(n_best)
        length_penalty = float(length_penalty)
    except ValueError as e:
        error_msg = f"Error: Invalid input parameter type ({e})."
        return f"1. {error_msg}" # Cannot determine n_best here

    logging.info(f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})")

    try:
        # Calls the translate function defined *above in this file*
        predicted_names = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=config['max_len'],
            sos_idx=config['bos_token_id'],
            eos_idx=config['eos_token_id'],
            pad_idx=config['pad_token_id'],
            beam_width=beam_width,
            n_best=n_best,
            length_penalty=length_penalty
        )
        logging.info(f"Predictions returned: {predicted_names}")

        if not predicted_names:
            output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
        else:
            output_text = f"Input SMILES: {smiles_input}\n\nTop {len(predicted_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
            output_text += "\n".join([f"{i+1}. {name}" for i, name in enumerate(predicted_names)])

        return output_text

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            gc.collect()
            if device == torch.device("cuda"):
                torch.cuda.empty_cache()
            error_msg += " (Potential OOM)"
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])

    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        error_msg = f"Unexpected Error during translation: {e}"
        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])


# --- Load Model on App Start (Unchanged) ---
try:
    load_model_and_tokenizers()
except gr.Error:
    pass # Error already raised for Gradio UI
except Exception as e:
    logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
    gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")


# --- Create Gradio Interface (Unchanged) ---
title = "SMILES to IUPAC Name Translator"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
Model repository: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>{MODEL_REPO_ID}</a>.
Adjust the beam search parameters below: a higher beam width explores more candidate names but is slower, and the length penalty controls how strongly longer names are favored.
"""

examples = [
    ["CCO", 5, 3, 0.6],
    ["C1=CC=CC=C1", 5, 3, 0.6],
    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
    ["CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C", 5, 1, 0.6], # Complex example
    ["INVALID_SMILES", 5, 1, 0.6],
]

smiles_input = gr.Textbox(
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
    lines=1
)
beam_width_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=5,
    step=1,
    label="Beam Width (k)",
    info="Number of sequences to keep at each decoding step (higher = more exploration, slower)."
)
n_best_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Number of Results (n_best)",
    info="How many top-scoring sequences to return (must be <= Beam Width)."
)
length_penalty_input = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.6,
    step=0.1,
    label="Length Penalty (alpha)",
    info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty."
)
output_text = gr.Textbox(
    label="Predicted IUPAC Name(s)",
    lines=5,
    show_copy_button=True
)

iface = gr.Interface(
    fn=predict_iupac,
    inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
    outputs=output_text,
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
    article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results."
)

# --- Launch the App (Unchanged) ---
if __name__ == "__main__":
    iface.launch(share=True)