# app.py
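"""Gradio Space app: SMILES -> IUPAC name translation.

Loads a Transformer checkpoint, tokenizers, and config from the Hugging Face Hub
repo given by MODEL_REPO_ID below, and serves single-SMILES greedy-decoding
inference through a simple Gradio interface.
"""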
import gradio as gr
import torch

# import torch.nn.functional as F # No longer needed for greedy decode directly
import pytorch_lightning as pl
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc
from rdkit.Chem import CanonSmiles


# --- Configuration ---
MODEL_REPO_ID = (
    "AdrianM0/smiles-to-iupac-translator"  # <-- Make sure this is your repo ID
)
CHECKPOINT_FILENAME = "last.ckpt"
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = "config.json"
# --- End Configuration ---

# --- Logging ---
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
try:
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    raise gr.Error(
        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
    )
except Exception as e:
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    raise gr.Error(
        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
    )

# --- Global Variables (Load Model Once) ---
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None


# --- Greedy Decoding Logic (Locally defined) ---
def greedy_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Performs greedy decoding using the LightningModule's model.
    """
    model.eval()  # Ensure model is in evaluation mode
    transformer_model = model.model  # Access the underlying Seq2SeqTransformer

    try:
        with torch.no_grad():
            # --- Encode Source ---
            memory = transformer_model.encode(
                src, src_padding_mask
            )  # [1, src_len, emb_size]
            memory = memory.to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]

            # --- Initialize Target Sequence ---
            # Start with the SOS token
            ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                sos_idx
            )  # [1, 1]

            # --- Decoding Loop ---
            for _ in range(max_len - 1):  # Max length limit
                tgt_seq_len = ys.shape[1]
                tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                    device
                )  # [curr_len, curr_len]
                # No padding in target during generation yet
                tgt_padding_mask = torch.zeros(
                    ys.shape, dtype=torch.bool, device=device
                )  # [1, curr_len]

                # Decode one step
                decoder_output = transformer_model.decode(
                    tgt=ys,
                    memory=memory,
                    tgt_mask=tgt_mask,
                    tgt_padding_mask=tgt_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                )  # [1, curr_len, emb_size]

                # Get logits for the *next* token prediction
                next_token_logits = transformer_model.generator(
                    decoder_output[
                        :, -1, :
                    ]  # Use output corresponding to the last input token
                )  # [1, tgt_vocab_size]

                # Find the most likely next token (greedy choice)
                # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
                # _, next_word_id_tensor = torch.max(prob, dim=1)
                next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # [1]
                next_word_id = next_word_id_tensor.item()

                # Append the chosen token to the sequence
                ys = torch.cat(
                    [
                        ys,
                        torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                            next_word_id
                        ),
                    ],
                    dim=1,
                )  # [1, current_len + 1]

                # Stop if EOS token is generated
                if next_word_id == eos_idx:
                    break

            # Return the generated sequence (excluding the initial SOS token)
            return ys[:, 1:]  # Shape [1, generated_len]

    except RuntimeError as e:
        logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return torch.empty(
            (1, 0), dtype=torch.long, device=device
        )  # Return empty tensor on error
    except Exception as e:
        logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
        return torch.empty((1, 0), dtype=torch.long, device=device)


# --- Translation Function (Using Greedy Decode) ---
def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
) -> str:  # Returns a single string
    """
    Translates a single SMILES string using greedy decoding.
    """
    model.eval()  # Ensure model is in eval mode

    # --- Tokenize Source ---
    try:
        # Ensure tokenizer has truncation/padding configured if needed, or handle manually
        smiles_tokenizer.enable_truncation(
            max_length=max_len
        )  # Use max_len for source truncation too
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return "[Encoding Error]"
        # Use the truncated IDs directly
        src_ids = src_encoded.ids
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
        return "[Encoding Error]"

    # --- Prepare Input Tensor and Mask ---
    src = (
        torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    )  # [1, src_len]
    # Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]

    # --- Perform Greedy Decoding ---
    # Calls the greedy_decode function defined above in this file.
    # The max_len passed by the caller (config["max_len"]) limits both the source
    # truncation above and the generated sequence length here.
    tgt_tokens_tensor = greedy_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=max_len,  # generation length limit
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        # pad_idx=pad_idx, # Not needed by greedy_decode internal loop
        device=device,
    )  # Returns a single tensor [1, generated_len]

    # --- Decode Generated Tokens ---
    if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
        logging.warning(
            f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
        )
        return "[Decoding Error - Empty Output]"

    tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
    try:
        # Decode using the target tokenizer, skipping special tokens
        translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
        return translation
    except Exception as e:
        logging.error(
            f"Error decoding target tokens {tgt_tokens}: {e}",
            exc_info=True,
        )
        return "[Decoding Error]"


# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:  # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        # Determine device
        if torch.cuda.is_available():
            logging.warning(
                "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
            )
            device = torch.device("cpu")
            # device = torch.device("cuda")
            # logging.info("CUDA available, using GPU.")
        else:
            device = torch.device("cpu")
            logging.info("CUDA not available, using CPU.")

        # Download files
        logging.info("Downloading files from Hugging Face Hub...")
        try:
            cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
            os.makedirs(cache_dir, exist_ok=True)
            logging.info(f"Using cache directory: {cache_dir}")

            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
            )
            smiles_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=SMILES_TOKENIZER_FILENAME,
                cache_dir=cache_dir,
            )
            iupac_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=IUPAC_TOKENIZER_FILENAME,
                cache_dir=cache_dir,
            )
            config_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
            )
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(
                f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
            )

        # Load config
        logging.info("Loading configuration...")
        try:
            with open(config_path, "r") as f:
                config = json.load(f)
            logging.info("Configuration loaded.")
            # --- Validate essential config keys ---
            required_keys = [
                "src_vocab_size",  # Use the key saved in config
                "tgt_vocab_size",  # Use the key saved in config
                "emb_size",
                "nhead",
                "ffn_hid_dim",
                "num_encoder_layers",
                "num_decoder_layers",
                "dropout",
                "max_len",
                "pad_token_id",
                "bos_token_id",
                "eos_token_id",
            ]
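            # For reference, a matching config.json might look roughly like this
            # (illustrative values only; the real file is whatever the training run saved):
            # {
            #   "src_vocab_size": 30000, "tgt_vocab_size": 30000,
            #   "emb_size": 512, "nhead": 8, "ffn_hid_dim": 2048,
            #   "num_encoder_layers": 6, "num_decoder_layers": 6, "dropout": 0.1,
            #   "max_len": 256, "pad_token_id": 0, "bos_token_id": 1, "eos_token_id": 2
            # }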
            # If the training run saved the vocab sizes under different key names,
            # remap them onto "src_vocab_size" / "tgt_vocab_size" here before the
            # validation below. The current config already uses these names, so no
            # remapping should be needed.

            missing_keys = [key for key in required_keys if config.get(key) is None]
            if missing_keys:
                raise ValueError(
                    f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
                    f"Ensure these were saved in the hyperparameters during training."
                )

            logging.info(
                f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
                f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
                f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
            )

        except FileNotFoundError:
            logging.error(f"Config file not found: {config_path}")
            raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(
                f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}"
            )
        except ValueError as e:
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}")
        except Exception as e:
            logging.error(f"Unexpected error loading config: {e}", exc_info=True)
            raise gr.Error(f"Config Error: Unexpected error. Check logs. Error: {e}")

        # Load tokenizers
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")
            # --- Optional: Validate Tokenizer Special Tokens Against Config ---
            pad_token = "<pad>"
            sos_token = "<sos>"
            eos_token = "<eos>"
            unk_token = "<unk>"
            issues = []
            if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
                issues.append("SMILES PAD ID mismatch")
            if smiles_tokenizer.token_to_id(unk_token) is None:
                issues.append("SMILES UNK token not found")
            if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
                issues.append("IUPAC PAD ID mismatch")
            if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
                issues.append("IUPAC SOS ID mismatch")
            if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
                issues.append("IUPAC EOS ID mismatch")
            if iupac_tokenizer.token_to_id(unk_token) is None:
                issues.append("IUPAC UNK token not found")
            if issues:
                logging.warning("Tokenizer validation issues: " + "; ".join(issues))

        except Exception as e:
            logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
            raise gr.Error(
                f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
            )

        # Load model
        logging.info("Loading model from checkpoint...")
        try:
            # Pass the whole config dict as keyword args; load_from_checkpoint picks
            # what it needs when the keys match the __init__ args or saved hparams.
            # (config already contains src_vocab_size / tgt_vocab_size, so they must
            # not also be passed explicitly, or Python raises "got multiple values
            # for keyword argument".)
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                **config,
                map_location=device,
                strict=True,  # Start strict, set to False if encountering key errors
            )

            model.to(device)
            model.eval()
            model.freeze()
            logging.info(
                f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
            )

        except FileNotFoundError:
            logging.error(f"Checkpoint file not found: {checkpoint_path}")
            raise gr.Error(
                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
            )
        except Exception as e:
            logging.error(
                f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
            )
            if "size mismatch" in str(e):
                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
            elif "memory" in str(e).lower():
                logging.warning("Potential OOM error during model loading.")
                gc.collect()
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                raise gr.Error(
                    f"Model Error: OOM loading model. Check Space resources. Error: {e}"
                )
            else:
                raise gr.Error(
                    f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
                )

    except gr.Error:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during loading: {e}", exc_info=True)
        raise gr.Error(
            f"Initialization Error: Unexpected error. Check logs. Error: {e}"
        )


# --- Inference Function for Gradio (Simplified) ---
def predict_iupac(smiles_string):
    """
    Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        return f"Error: {error_msg}"  # Return single error string

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        return f"Error: {error_msg}"  # Return single error string

    smiles_input = smiles_string.strip()

    # Canonicalize with RDKit so the model sees a normalized SMILES; reject unparsable input.
    try:
        smiles_input = CanonSmiles(smiles_input)
    except Exception:
        return f"Error: '{smiles_input}' is not a valid SMILES string (RDKit could not parse it)."

    try:
        # --- Call the core translation logic (greedy) ---
        sos_idx = config["bos_token_id"]
        eos_idx = config["eos_token_id"]
        pad_idx = config["pad_token_id"]
        gen_max_len = config["max_len"]

        predicted_name = translate(  # Returns a single string now
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=gen_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
        )
        logging.info(f"Prediction returned: {predicted_name}")

        # --- Format Output ---
        if "[Error]" in predicted_name:  # Check for error messages from translate
            output_text = (
                f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
            )
        elif not predicted_name:
            output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
        else:
            output_text = (
                f"Input SMILES: {smiles_input}\n\n"
                f"Predicted IUPAC Name (Greedy Decode):\n"
                f"{predicted_name}"
            )
        return output_text

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        error_msg = f"Runtime error during translation (possibly out of memory): {e}"
        return f"Error: {error_msg}"  # Return single error string

    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        error_msg = f"Unexpected Error during translation: {e}"
        return f"Error: {error_msg}"  # Return single error string


# --- Load Model on App Start ---
try:
    load_model_and_tokenizers()
except gr.Error as ge:
    logging.error(f"Gradio Initialization Error: {ge}")
    pass  # Allow Gradio to potentially start with an error message
except Exception as e:
    logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
    # Optionally raise gr.Error here too


# --- Create Gradio Interface (Simplified) ---
title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
Translation uses **greedy decoding** (the most likely next token is picked at each step).
**Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
"""

# Input component
smiles_input = gr.Textbox(
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
    lines=1,
)
# Output component
output_text = gr.Textbox(
    label="Predicted IUPAC Name",
    lines=3,
    show_copy_button=True,
)

# Create the interface instance
iface = gr.Interface(
    fn=predict_iupac,  # The function to call
    inputs=smiles_input,  # Single input component
    outputs=output_text,  # Output component
    title=title,
    description=description,
    allow_flagging="never",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
    article="""
    **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
    Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations. Greedy decoding can sometimes get stuck in repetitive loops or produce suboptimal results compared to beam search.
    """,
)

# --- Launch the App ---
if __name__ == "__main__":
    iface.launch(share=True)