File size: 36,430 Bytes
aaafea4
 
 
fe32912
 
aaafea4
 
 
 
 
fe32912
 
aaafea4
 
028e0b0
fe32912
 
 
 
aaafea4
 
fe32912
 
 
aaafea4
 
 
59543a5
 
 
aaafea4
028e0b0
aaafea4
028e0b0
 
aaafea4
fe32912
aaafea4
 
028e0b0
 
aaafea4
 
59543a5
028e0b0
 
59543a5
028e0b0
 
59543a5
 
aaafea4
59543a5
 
 
028e0b0
59543a5
 
aaafea4
 
028e0b0
aaafea4
 
 
 
 
fe32912
028e0b0
aaafea4
 
 
 
 
 
 
028e0b0
aaafea4
 
028e0b0
 
aaafea4
 
 
028e0b0
aaafea4
fe32912
 
028e0b0
aaafea4
 
 
 
59543a5
 
fe32912
aaafea4
fe32912
aaafea4
 
59543a5
 
fe32912
 
aaafea4
 
 
 
 
 
 
 
 
 
028e0b0
aaafea4
fe32912
aaafea4
 
 
028e0b0
fe32912
aaafea4
59543a5
 
fe32912
028e0b0
59543a5
 
fe32912
aaafea4
028e0b0
aaafea4
 
 
 
 
59543a5
fe32912
aaafea4
028e0b0
59543a5
fe32912
 
 
 
028e0b0
 
59543a5
 
fe32912
 
 
 
aaafea4
028e0b0
59543a5
028e0b0
fe32912
aaafea4
028e0b0
aaafea4
 
028e0b0
fe32912
 
 
59543a5
 
fe32912
59543a5
 
fe32912
aaafea4
 
028e0b0
 
aaafea4
 
028e0b0
aaafea4
fe32912
aaafea4
fe32912
 
 
 
 
028e0b0
59543a5
 
fe32912
 
 
028e0b0
fe32912
 
028e0b0
 
 
 
 
fe32912
 
 
028e0b0
 
 
aaafea4
 
 
028e0b0
aaafea4
 
028e0b0
 
aaafea4
 
028e0b0
fe32912
 
 
028e0b0
 
 
aaafea4
fe32912
 
 
aaafea4
028e0b0
59543a5
fe32912
 
 
 
aaafea4
 
 
028e0b0
fe32912
59543a5
 
fe32912
aaafea4
 
 
 
fe32912
028e0b0
aaafea4
 
 
 
 
 
 
 
 
 
 
 
59543a5
aaafea4
 
 
028e0b0
aaafea4
fe32912
aaafea4
fe32912
aaafea4
 
 
028e0b0
 
aaafea4
 
 
 
028e0b0
 
 
aaafea4
028e0b0
aaafea4
 
 
59543a5
 
fe32912
028e0b0
fe32912
aaafea4
 
028e0b0
 
fe32912
 
 
aaafea4
 
 
 
fe32912
aaafea4
 
 
 
 
 
59543a5
fe32912
aaafea4
 
 
59543a5
028e0b0
59543a5
aaafea4
028e0b0
 
aaafea4
 
028e0b0
59543a5
 
 
aaafea4
 
fe32912
 
 
 
aaafea4
 
fe32912
 
 
aaafea4
 
 
 
 
 
 
 
fe32912
028e0b0
aaafea4
 
 
fe32912
aaafea4
 
 
 
 
028e0b0
 
 
fe32912
 
 
 
 
 
 
028e0b0
fe32912
 
aaafea4
 
 
 
028e0b0
fe32912
 
 
028e0b0
 
 
59543a5
028e0b0
59543a5
 
fe32912
 
 
59543a5
 
fe32912
 
 
59543a5
 
028e0b0
59543a5
aaafea4
 
59543a5
028e0b0
59543a5
 
 
 
 
aaafea4
 
 
 
59543a5
aaafea4
 
 
028e0b0
 
 
aaafea4
028e0b0
fe32912
 
028e0b0
59543a5
 
 
 
 
 
fe32912
028e0b0
 
fe32912
 
 
aaafea4
028e0b0
 
fe32912
 
 
 
 
 
028e0b0
 
 
 
 
 
 
fe32912
 
 
 
 
 
 
 
 
028e0b0
 
 
 
 
aaafea4
028e0b0
 
 
 
 
fe32912
028e0b0
 
fe32912
028e0b0
fe32912
 
 
028e0b0
 
fe32912
 
 
 
 
028e0b0
aaafea4
fe32912
 
 
 
 
 
aaafea4
 
fe32912
 
 
 
aaafea4
 
fe32912
 
 
 
 
 
 
aaafea4
 
 
 
 
 
 
028e0b0
fe32912
028e0b0
 
 
 
 
 
 
fe32912
 
 
028e0b0
 
 
 
fe32912
 
 
028e0b0
fe32912
 
 
028e0b0
fe32912
 
 
028e0b0
 
 
 
fe32912
 
 
 
 
028e0b0
aaafea4
59543a5
 
 
 
 
 
 
aaafea4
 
 
 
028e0b0
 
aaafea4
 
028e0b0
 
 
 
fe32912
 
 
028e0b0
aaafea4
028e0b0
 
aaafea4
 
fe32912
59543a5
028e0b0
59543a5
aaafea4
 
fe32912
 
 
 
 
 
aaafea4
59543a5
fe32912
 
59543a5
028e0b0
 
fe32912
 
 
 
 
 
028e0b0
 
aaafea4
fe32912
 
 
 
 
028e0b0
fe32912
 
 
aaafea4
fe32912
59543a5
fe32912
 
 
 
 
 
 
aaafea4
 
028e0b0
fbe0ff9
aaafea4
 
028e0b0
aaafea4
 
 
 
028e0b0
 
 
fe32912
 
 
 
59543a5
aaafea4
 
 
fe32912
 
 
 
59543a5
aaafea4
 
028e0b0
 
aaafea4
028e0b0
 
 
fe32912
 
 
aaafea4
028e0b0
 
 
 
aaafea4
 
028e0b0
 
fe32912
 
 
 
028e0b0
aaafea4
 
 
 
 
 
fe32912
028e0b0
 
 
aaafea4
 
fbe0ff9
aaafea4
 
 
028e0b0
aaafea4
028e0b0
aaafea4
028e0b0
 
fe32912
 
 
 
59543a5
028e0b0
59543a5
028e0b0
 
fe32912
aaafea4
 
 
 
 
 
 
 
fe32912
 
028e0b0
 
59543a5
aaafea4
 
 
 
59543a5
aaafea4
 
028e0b0
 
 
aaafea4
 
028e0b0
fe32912
 
 
 
aaafea4
028e0b0
fe32912
 
 
028e0b0
 
aaafea4
 
028e0b0
aaafea4
 
028e0b0
 
 
aaafea4
 
028e0b0
aaafea4
fe32912
 
 
 
028e0b0
 
fe32912
aaafea4
 
028e0b0
aaafea4
 
 
59543a5
aaafea4
028e0b0
aaafea4
fe32912
 
 
 
 
 
aaafea4
 
fe32912
 
 
 
 
 
aaafea4
fbe0ff9
aaafea4
59543a5
aaafea4
 
028e0b0
aaafea4
fe32912
 
028e0b0
 
 
 
fe32912
aaafea4
 
fe32912
 
 
028e0b0
 
 
 
 
 
 
aaafea4
 
028e0b0
aaafea4
fbe0ff9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
# app.py
import gradio as gr
import torch
import torch.nn.functional as F  # Needed for beam search log_softmax
import pytorch_lightning as pl  # Needed for LightningModule and loading
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc  # For garbage collection on potential OOM
import math  # Potentially needed by imported classes

# --- Configuration ---
# Ensure these match the files uploaded to your Hugging Face Hub repository
MODEL_REPO_ID = (
    "AdrianM0/smiles-to-iupac-translator"  # <-- Make sure this is your repo ID
)
CHECKPOINT_FILENAME = "last.ckpt"  # Or "best_model.ckpt" or whatever you uploaded
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = (
    "config.json"  # Assumes you saved hparams to config.json during/after training
)
# --- End Configuration ---

# --- Logging ---
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
# The LightningModule class and the causal-mask helper come from
# enhanced_trainer.py, which must sit in the root of the HF repo / Space.
# Failures are surfaced as gr.Error (visible in the UI and logs) with the
# original exception chained so the real traceback is not lost.
try:
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    raise gr.Error(
        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
    ) from e
except Exception as e:
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    raise gr.Error(
        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
    ) from e

# --- Global Variables (Load Model Once) ---
# All start as None; load_model_and_tokenizers() assigns them via `global`.
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None
iupac_tokenizer: Tokenizer | None = None
device: torch.device | None = None
config: dict | None = None


# --- Beam Search Decoding Logic (Locally defined) ---
def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    device: torch.device,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[torch.Tensor]:
    """
    Decode a single source sequence with beam search.

    Args:
        model: LightningModule whose ``.model`` attribute provides ``encode``,
            ``decode`` and ``generator`` (the only members used here).
        src: Source token ids for one example (presumably ``[1, src_len]``,
            as built by ``translate`` — confirm for other callers).
        src_padding_mask: Boolean mask, True at padding positions of ``src``.
        max_len: Maximum hypothesis length, counting the leading SOS token.
        sos_idx: Target start-of-sequence id used to seed every beam.
        eos_idx: Target end-of-sequence id; a beam ending in it is finished.
        pad_idx: Target padding id. Unused; kept for interface compatibility.
        device: Device on which new tensors are created.
        beam_width: Number of live hypotheses kept per step.
        n_best: Number of hypotheses to return (clamped to ``beam_width``).
        length_penalty: Exponent ``alpha`` of the penalty
            ``((5 + len) / 6) ** alpha`` applied when ranking; ``<= 0``
            disables it.

    Returns:
        Up to ``n_best`` tensors of shape ``[1, len]``, best first, with the SOS
        token stripped (a trailing EOS is kept if the hypothesis finished).
        Returns ``[]`` if decoding raised an error.
    """
    model.eval()  # Ensure model is in evaluation mode
    transformer_model = model.model  # Underlying seq2seq transformer
    n_best = min(n_best, beam_width)

    def get_score_with_penalty(beam_tuple):
        """Cumulative log-prob of a beam divided by its length penalty."""
        seq, score = beam_tuple
        seq_len = seq.shape[1]
        # No penalty when disabled or for the bare SOS sequence.
        if length_penalty <= 0.0 or seq_len <= 1:
            return score.item()
        penalty = ((5.0 + float(seq_len)) / 6.0) ** length_penalty
        # Scores are negative log-probs; dividing by penalty > 1 favours longer beams.
        return score.item() / penalty

    try:
        with torch.no_grad():
            # --- Encode source once; reused by every decoder step ---
            memory = transformer_model.encode(src, src_padding_mask).to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)

            # --- Initialise with a single SOS-only beam of score 0 ---
            initial_beam_seq = torch.full(
                (1, 1), sos_idx, dtype=torch.long, device=device
            )  # [1, 1]
            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device)  # [1]
            active_beams = [(initial_beam_seq, initial_beam_score)]
            finished_beams = []

            # --- Decoding loop: each step extends every live beam by one token ---
            for _ in range(max_len - 1):
                if not active_beams:
                    break

                potential_next_beams = []
                for current_seq, current_score in active_beams:
                    # Defensive: a beam already ending in EOS is complete.
                    if current_seq[0, -1].item() == eos_idx:
                        finished_beams.append((current_seq, current_score))
                        continue

                    tgt_seq_len = current_seq.shape[1]
                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                        device
                    )  # [curr_len, curr_len]
                    # Generated prefixes never contain padding.
                    tgt_padding_mask = torch.zeros(
                        current_seq.shape, dtype=torch.bool, device=device
                    )  # [1, curr_len]

                    decoder_output = transformer_model.decode(
                        tgt=current_seq,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )
                    # Only the last position predicts the next token.
                    next_token_logits = transformer_model.generator(
                        decoder_output[:, -1, :]
                    )  # [1, tgt_vocab_size]

                    log_probs = F.log_softmax(next_token_logits, dim=-1)
                    combined_scores = log_probs + current_score  # cumulative log-prob

                    # topk raises if k exceeds the vocabulary size, so clamp it.
                    k = min(beam_width, combined_scores.shape[-1])
                    topk_log_probs, topk_indices = torch.topk(
                        combined_scores, k, dim=-1
                    )  # [1, k], [1, k]

                    for i in range(k):
                        next_token_id = topk_indices[0, i].item()
                        next_score = topk_log_probs[0, i].reshape(1)  # tensor [1]
                        next_token_tensor = torch.tensor(
                            [[next_token_id]], dtype=torch.long, device=device
                        )  # [1, 1]
                        new_seq = torch.cat([current_seq, next_token_tensor], dim=1)
                        potential_next_beams.append((new_seq, next_score))

                # --- Prune: best candidates first ---
                potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)

                active_beams = []
                newly_finished = []  # beams that emitted EOS at this step
                for seq, score in potential_next_beams:
                    if (
                        len(active_beams) >= beam_width
                        and len(newly_finished) >= beam_width
                    ):
                        break  # both lists are full; remaining candidates are worse
                    if seq[0, -1].item() == eos_idx:
                        if len(newly_finished) < beam_width:
                            newly_finished.append((seq, score))
                    elif len(active_beams) < beam_width:
                        active_beams.append((seq, score))

                finished_beams.extend(newly_finished)
                # Prune with the same length-penalised score used for the final
                # ranking, so a longer hypothesis that would win after the
                # penalty is not discarded here on raw log-prob alone.
                finished_beams.sort(key=get_score_with_penalty, reverse=True)
                finished_beams = finished_beams[: beam_width * 2]

            # --- Final selection: unfinished beams compete with finished ones ---
            finished_beams.extend(active_beams)
            finished_beams.sort(key=get_score_with_penalty, reverse=True)

            # Drop SOS-only beams first, then take the top n_best (SOS stripped).
            return [
                seq[:, 1:] for seq, _ in finished_beams if seq.shape[1] > 1
            ][:n_best]

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return []  # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []


# --- Translation Function (Locally defined) ---
def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    beam_width: int = 5,
    n_best: int = 5,
    length_penalty: float = 0.6,
) -> list[str]:
    """
    Translate one SMILES string into candidate IUPAC names via beam search.

    Args:
        model: Loaded LightningModule, forwarded to ``beam_search_decode``.
        src_sentence: SMILES string to translate.
        smiles_tokenizer: Source tokenizer. NOTE: ``enable_truncation`` is
            called on this (shared) object on every call.
        iupac_tokenizer: Target tokenizer used to decode generated ids.
        device: Device for the input tensors.
        max_len: Source truncation length; also the generation limit when the
            module-level ``config`` is unset or has no ``max_len``.
        sos_idx: Target start-of-sequence id.
        eos_idx: Target end-of-sequence id.
        pad_idx: Padding id used to build the source padding mask.
        beam_width: Beam width for decoding.
        n_best: Number of results wanted (clamped to ``beam_width``).
        length_penalty: Length-penalty exponent passed to beam search.

    Returns:
        One string per generated hypothesis, padded with error markers up to
        ``n_best``. Failures are reported in-band as bracketed markers such as
        ``"[Encoding Error]"`` rather than raised.
    """
    model.eval()  # Ensure model is in eval mode
    n_best = min(n_best, beam_width)  # Can't return more than beam width

    # --- Tokenize source (truncated to max_len tokens) ---
    try:
        smiles_tokenizer.enable_truncation(max_length=max_len)
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return ["[Encoding Error]"] * n_best
        src_ids = src_encoded.ids
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
        return ["[Encoding Error]"] * n_best

    # --- Input tensor and padding mask (single unpadded example) ---
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, src_len]
    src_padding_mask = src == pad_idx  # [1, src_len]; True only at pad ids

    # --- Beam search ---
    # Generation limit comes from the loaded config, as before. If the global
    # config is unset or lacks max_len, use the max_len argument instead of
    # crashing on None.
    loaded_config = config if isinstance(config, dict) else {}
    generation_max_len = loaded_config.get("max_len", max_len)
    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=generation_max_len,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
        device=device,
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty,
    )  # list of [1, len] tensors, best first

    if not tgt_tokens_list:
        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
        return ["[Decoding Error - Empty Output]"] * n_best

    # --- Decode generated ids to text ---
    translations = []
    for i, tgt_tokens_tensor in enumerate(tgt_tokens_list):
        if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
            logging.warning(
                f"Beam {i} result was empty or None for SMILES: {src_sentence}"
            )
            translations.append("[Decoding Error - Empty Tensor]")
            continue

        # Tensor.tolist() already yields Python ints; no numpy round-trip needed.
        tgt_tokens = tgt_tokens_tensor.flatten().tolist()
        try:
            # skip_special_tokens drops the trailing EOS (and any other specials).
            translations.append(
                iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
            )
        except Exception as e:
            logging.error(
                f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}",
                exc_info=True,
            )
            translations.append("[Decoding Error]")

    # Pad with error markers if fewer than n_best results were generated.
    missing = n_best - len(translations)
    translations.extend(["[Decoding Error - Fewer Results]"] * max(missing, 0))
    return translations


# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
    """
    Download and load config, tokenizers and model into the module globals.

    Assigns, in order, the globals ``device``, ``config``,
    ``smiles_tokenizer``, ``iupac_tokenizer`` and ``model``. Returns
    immediately if ``model`` is already set, so repeated calls are cheap.

    Files are fetched from ``MODEL_REPO_ID`` with ``hf_hub_download`` into the
    directory named by the ``GRADIO_CACHE`` environment variable, or
    ``./hf_cache`` when it is unset.

    Raises:
        gr.Error: On any failure (download, config parsing/validation,
            tokenizer loading, checkpoint loading). The underlying exception
            is chained as ``__cause__`` so the full traceback is preserved.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:  # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        # CPU is used unconditionally: it is the safe choice on free-tier
        # Spaces. Set device = torch.device("cuda") here if a GPU is intended.
        device = torch.device("cpu")
        if torch.cuda.is_available():
            logging.warning(
                "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
            )
        else:
            logging.info("CUDA not available, using CPU.")

        # --- Download files from HF Hub ---
        logging.info("Downloading files from Hugging Face Hub...")
        try:
            cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
            os.makedirs(cache_dir, exist_ok=True)
            logging.info(f"Using cache directory: {cache_dir}")

            checkpoint_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
            )
            smiles_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=SMILES_TOKENIZER_FILENAME,
                cache_dir=cache_dir,
            )
            iupac_tokenizer_path = hf_hub_download(
                repo_id=MODEL_REPO_ID,
                filename=IUPAC_TOKENIZER_FILENAME,
                cache_dir=cache_dir,
            )
            config_path = hf_hub_download(
                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
            )
            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(
                f"Failed to download files from {MODEL_REPO_ID}. Check filenames ({CHECKPOINT_FILENAME}, {SMILES_TOKENIZER_FILENAME}, etc.) and repo status. Error: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
            ) from e

        # --- Load and validate config ---
        logging.info("Loading configuration...")
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            logging.info("Configuration loaded.")

            # Keys needed to rebuild the model and to drive generation.
            required_keys = [
                # Vocab sizes used during *training* (must match the checkpoint)
                "actual_src_vocab_size",
                "actual_tgt_vocab_size",
                # Model architecture params
                "emb_size",
                "nhead",
                "ffn_hid_dim",
                "num_encoder_layers",
                "num_decoder_layers",
                "dropout",
                "max_len",  # generation limit and tokenizer truncation
                # Special token IDs needed for generation
                "pad_token_id",
                "bos_token_id",  # used as SOS
                "eos_token_id",
            ]
            # Accept the plain training names as fallbacks for the vocab sizes.
            for actual_key, fallback_key in (
                ("actual_src_vocab_size", "src_vocab_size"),
                ("actual_tgt_vocab_size", "tgt_vocab_size"),
            ):
                config[actual_key] = config.get(actual_key, config.get(fallback_key))

            # No defaults are substituted: a missing or null key is fatal.
            missing_keys = [key for key in required_keys if config.get(key) is None]
            if missing_keys:
                raise ValueError(
                    f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
                    f"Ensure these were saved in the hyperparameters during training."
                )

            logging.info(
                f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
                f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
                f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
            )

        except FileNotFoundError as e:
            logging.error(
                f"Config file not found locally after download attempt: {config_path}"
            )
            raise gr.Error(
                f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
            ) from e
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from config file {config_path}: {e}")
            raise gr.Error(
                f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
            ) from e
        except ValueError as e:  # Our own validation error above
            logging.error(f"Config validation error: {e}")
            raise gr.Error(f"Config Error: {e}") from e
        except Exception as e:
            logging.error(
                f"Unexpected error loading or validating config: {e}", exc_info=True
            )
            raise gr.Error(
                f"Config Error: Unexpected error processing config. Check logs. Error: {e}"
            ) from e

        # --- Load tokenizers and cross-check special token IDs ---
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            logging.info("Tokenizers loaded.")

            pad_token = "<pad>"
            sos_token = "<sos>"
            eos_token = "<eos>"
            unk_token = "<unk>"

            issues = []
            if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
                issues.append(
                    f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
                )
            if smiles_tokenizer.token_to_id(unk_token) is None:
                issues.append("SMILES UNK token not found")

            if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
                issues.append(
                    f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
                )
            if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
                issues.append(
                    f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})"
                )
            if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
                issues.append(
                    f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})"
                )
            if iupac_tokenizer.token_to_id(unk_token) is None:
                issues.append("IUPAC UNK token not found")

            if issues:
                # Deliberately non-fatal; raise gr.Error here instead if the
                # tokenizer IDs must match the config exactly.
                logging.warning(
                    "Tokenizer validation issues detected: " + "; ".join(issues)
                )

        except Exception as e:
            logging.error(
                f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
                exc_info=True,
            )
            raise gr.Error(
                f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}"
            ) from e

        # --- Load model from checkpoint ---
        logging.info("Loading model from checkpoint...")
        try:
            # Vocab sizes and hparams come from the validated config.
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                src_vocab_size=config["actual_src_vocab_size"],
                tgt_vocab_size=config["actual_tgt_vocab_size"],
                hparams_dict=config,
                map_location=device,
                strict=False,  # tolerate minor key differences across versions
            )

            model.to(device)
            model.eval()
            model.freeze()  # Disables gradient calculations
            logging.info(
                f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
            )

        except FileNotFoundError as e:
            logging.error(
                f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
            )
            raise gr.Error(
                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
            ) from e
        except Exception as e:
            logging.error(
                f"Error loading model from checkpoint {checkpoint_path}: {e}",
                exc_info=True,
            )
            if "size mismatch" in str(e):
                error_detail = (
                    f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
                    f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers."
                )
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail} Original error: {e}") from e
            if "memory" in str(e).lower():
                logging.warning("Potential Out-of-Memory error during model loading.")
                gc.collect()
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                raise gr.Error(
                    f"Model Error: Out of memory loading model. Check Space resources. Error: {e}"
                ) from e
            raise gr.Error(
                f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
            ) from e

    except gr.Error:  # Already user-facing; propagate unchanged
        raise
    except Exception as e:  # Catch any other unexpected errors
        logging.error(
            f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
        )
        raise gr.Error(
            f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
        ) from e


# --- Inference Function for Gradio ---
def _error_line_count(n_best_str):
    """Return how many numbered error lines to emit for a raw n_best value.

    Falls back to 1 when the value cannot be converted with int(). The result
    is never below 1, so an error message is always visible.
    """
    try:
        return max(1, int(n_best_str))
    except (TypeError, ValueError):
        return 1


def _numbered_lines(message, count):
    """Return ``message`` repeated on ``count`` lines prefixed "1. ", "2. ", ..."""
    return "\n".join(f"{i}. {message}" for i in range(1, count + 1))


def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty=0.0):
    """
    Translate a SMILES string into IUPAC name candidates using beam search.

    Called by the Gradio interface with the raw values of the SMILES textbox
    and the beam-width / n_best sliders. Numeric inputs are converted with int().

    Args:
        smiles_string: SMILES text entered by the user. Surrounding whitespace
            is stripped.
        beam_width_str: Beam width; must convert to an int >= 1.
        n_best_str: Number of candidates to return; must convert to an int in
            the range [1, beam width].
        length_penalty: Length penalty forwarded to ``translate`` and shown in
            the output header. Defaults to 0.0.

    Returns:
        A human-readable string. It is either a header followed by numbered
        predictions, or numbered error line(s). Loading, validation and
        translation errors are reported in this text rather than raised.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config

    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        # Mirror the requested number of results so the output layout is stable.
        return _numbered_lines(error_msg, _error_line_count(n_best_str))

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        return _numbered_lines(error_msg, _error_line_count(n_best_str))

    smiles_input = smiles_string.strip()

    # --- Safely parse numerical inputs ---
    try:
        beam_width = int(beam_width_str)
        n_best = int(n_best_str)
        length_penalty = float(length_penalty)
        if beam_width < 1 or n_best < 1 or n_best > beam_width:
            raise ValueError(
                "Beam width and n_best must be >= 1, and n_best <= beam width."
            )
    except (TypeError, ValueError) as e:
        # TypeError covers None / non-numeric objects coming from the UI.
        error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
        logging.error(error_msg)
        # Cannot determine n_best if its input was invalid, default to 1 error line
        return f"1. {error_msg}"

    try:
        # --- Call the core translation logic ---
        # Retrieve necessary IDs from the loaded config
        sos_idx = config["bos_token_id"]
        eos_idx = config["eos_token_id"]
        pad_idx = config["pad_token_id"]
        gen_max_len = config["max_len"]  # Max length for generation

        predicted_names = translate(
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=gen_max_len,  # Pass generation length limit
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            beam_width=beam_width,
            n_best=n_best,
            length_penalty=length_penalty,
        )
        logging.info(f"Predictions returned: {predicted_names}")

        # --- Format Output ---
        if not predicted_names:
            return f"Input SMILES: {smiles_input}\n\nNo predictions generated (beam search might have failed)."

        # Display at most n_best results, even if translate returned more.
        display_names = predicted_names[:n_best]
        output_text = (
            f"Input SMILES: {smiles_input}\n\n"
            f"Top {len(display_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
        )
        output_text += _numbered_lines_from(display_names)
        # Add a note if fewer results than requested were generated
        if len(display_names) < n_best:
            output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."
        return output_text

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            # Free cached memory so the next request has a chance to succeed.
            gc.collect()
            if device.type == "cuda":
                torch.cuda.empty_cache()
            error_msg += " (Potential OOM - try reducing beam width or input length)"
        return _numbered_lines(error_msg, n_best)

    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        error_msg = f"Unexpected Error during translation: {e}"
        return _numbered_lines(error_msg, n_best)


def _numbered_lines_from(items):
    """Return ``items`` as lines prefixed "1. ", "2. ", ... (one item per line)."""
    return "\n".join(f"{i}. {item}" for i, item in enumerate(items, start=1))


# --- Load Model on App Start ---
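# Wrap in try/except so a loading failure is logged instead of crashing the app at import time.
# Errors caught here are only logged, not shown in the UI; predict_iupac checks that the model
# and tokenizers are set before translating and reports the problem to the user at request time.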
try:
    load_model_and_tokenizers()
except gr.Error as init_error:
    # Not re-raised: a failed load is recorded in the logs and the app is
    # still allowed to start.
    logging.error(f"Gradio Initialization Error: {init_error}")
except Exception as unexpected_error:
    # Any non-Gradio failure from the loader is logged with its traceback;
    # startup continues as well.
    logging.error(
        f"Critical error during initial model loading sequence: {unexpected_error}",
        exc_info=True,
    )


# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator"
description = f"""
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
Translation uses beam search decoding. Adjust parameters below.
**Note:** Model loaded on **{str(device).upper()}**. Performance may vary. Check `config.json` in the repo for model details.
"""

# One value per interface input, in the same order as the input components:
# [SMILES string, beam width, n_best]. n_best must not exceed beam width.
examples = [
    ["CCO", 5, 3],  # Ethanol
    ["C1=CC=CC=C1", 5, 3],  # Benzene
    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3],  # Aspirin
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3],  # Ibuprofen
    # Very complex example - might take time or fail on CPU/low memory
    # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1], # Gleevec (Imatinib) - simplified SMILES structure
    ["INVALID_SMILES", 3, 1],  # Example of invalid input
]

# Interface components. The three inputs are passed positionally to
# predict_iupac, so they must stay in the order (SMILES, beam width, n_best).
smiles_input = gr.Textbox(
    lines=1,
    label="SMILES String",
    placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
)

# Integer-stepped sliders; predict_iupac converts both values with int().
beam_width_input = gr.Slider(
    label="Beam Width (k)",
    info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage.",
    minimum=1,
    maximum=10,
    step=1,
    value=5,
)
n_best_input = gr.Slider(
    label="Number of Results (n_best)",
    info="How many top sequences to return (must be <= Beam Width).",
    minimum=1,
    maximum=10,
    step=1,
    value=3,
)

# Multi-line, copyable box that shows predict_iupac's returned text.
output_text = gr.Textbox(
    show_copy_button=True,
    lines=5,
    label="Predicted IUPAC Name(s)",
)

# Assemble the Gradio app around predict_iupac; the inputs list order is the
# positional argument order predict_iupac receives.
iface = gr.Interface(
    fn=predict_iupac,
    inputs=[smiles_input, beam_width_input, n_best_input],
    outputs=output_text,
    title=title,
    description=description,
    article="""
    **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
    Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.
    Beam search parameters (width, penalty) significantly impact results and performance.
    """,
    examples=examples,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
    allow_flagging="never",  # no flagging UI
)

# --- Launch the App ---
# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()