File size: 34,655 Bytes
aaafea4
 
 
f071a4c
9bd549c
aaafea4
 
 
 
 
9bd549c
f071a4c
fa90b40
f071a4c
 
aaafea4
 
fe32912
 
 
9bd549c
aaafea4
 
9bd549c
aaafea4
 
 
59543a5
 
 
aaafea4
028e0b0
aaafea4
f071a4c
aaafea4
b3f7b39
aaafea4
 
59543a5
028e0b0
 
59543a5
b3f7b39
 
 
aaafea4
59543a5
 
 
b3f7b39
 
 
aaafea4
 
028e0b0
aaafea4
 
 
 
 
fe32912
f071a4c
9bd549c
aaafea4
 
 
 
 
 
 
9bd549c
aaafea4
9bd549c
f071a4c
aaafea4
f071a4c
 
aaafea4
 
 
f071a4c
aaafea4
f071a4c
aaafea4
f071a4c
aaafea4
f071a4c
9bd549c
f071a4c
 
 
9bd549c
b3f7b39
 
 
 
 
f071a4c
 
 
 
 
b3f7b39
 
 
f071a4c
b3f7b39
 
f071a4c
9bd549c
 
aaafea4
 
f071a4c
aaafea4
 
9bd549c
fe32912
59543a5
 
f071a4c
aaafea4
9bd549c
b3f7b39
aaafea4
fe32912
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaafea4
 
 
 
 
 
b3f7b39
aaafea4
 
 
f071a4c
 
 
 
 
aaafea4
f071a4c
aaafea4
f071a4c
aaafea4
 
 
f071a4c
aaafea4
 
 
f071a4c
028e0b0
aaafea4
028e0b0
f071a4c
aaafea4
 
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaafea4
 
f071a4c
 
 
 
 
aaafea4
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
aaafea4
fe32912
f071a4c
aaafea4
 
b3f7b39
f071a4c
aaafea4
 
 
 
 
f071a4c
 
 
 
 
 
 
aaafea4
9bd549c
aaafea4
f071a4c
 
 
 
aaafea4
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
028e0b0
b3f7b39
aaafea4
f071a4c
 
aaafea4
 
 
 
59543a5
aaafea4
 
f071a4c
028e0b0
aaafea4
f071a4c
 
9bd549c
f071a4c
 
 
aaafea4
 
 
 
 
 
f071a4c
 
 
 
 
b3f7b39
aaafea4
9bd549c
f071a4c
aaafea4
 
 
 
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaafea4
 
9bd549c
f071a4c
 
 
aaafea4
028e0b0
aaafea4
 
b3f7b39
f071a4c
aaafea4
 
9bd549c
f071a4c
b3f7b39
f071a4c
b3f7b39
 
 
 
f071a4c
 
 
 
b3f7b39
 
 
 
f071a4c
b3f7b39
f071a4c
aaafea4
b3f7b39
f071a4c
b3f7b39
 
f071a4c
aaafea4
 
f071a4c
d11f5bc
f071a4c
aaafea4
f071a4c
aaafea4
 
 
f071a4c
aaafea4
b3f7b39
028e0b0
f071a4c
aaafea4
 
f071a4c
aaafea4
b3f7b39
028e0b0
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaafea4
f071a4c
fe32912
 
 
b3f7b39
f071a4c
 
028e0b0
f071a4c
aaafea4
b3f7b39
aaafea4
 
 
b3f7b39
028e0b0
 
 
f071a4c
 
 
 
aaafea4
f071a4c
aaafea4
028e0b0
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
aaafea4
f071a4c
 
 
 
 
 
 
 
aaafea4
b3f7b39
 
f071a4c
 
 
aaafea4
b3f7b39
f071a4c
aaafea4
 
028e0b0
aaafea4
 
028e0b0
f071a4c
 
 
 
aaafea4
b3f7b39
f071a4c
aaafea4
 
f071a4c
 
aaafea4
f071a4c
 
 
 
aaafea4
b3f7b39
f071a4c
 
fae0efa
 
3369de4
fae0efa
f071a4c
fae0efa
 
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fae0efa
3369de4
f071a4c
 
fae0efa
f071a4c
 
fae0efa
3369de4
fae0efa
aaafea4
f071a4c
3369de4
 
f071a4c
3369de4
f071a4c
3369de4
 
f071a4c
3369de4
 
f071a4c
3369de4
 
fae0efa
f071a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
028e0b0
aaafea4
3369de4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
# app.py
import gradio as gr
import torch
import torch.nn.functional as F # Needed for log_softmax in beam search
import pytorch_lightning as pl
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc
from rdkit.Chem import CanonSmiles, MolFromSmiles # Added MolFromSmiles for validation
import spaces
import heapq # For beam search priority queue
import math # For log probabilities

# --- Configuration ---
# Hugging Face Hub repository holding the checkpoint, tokenizers, config and
# enhanced_trainer.py that load_model_and_tokenizers() downloads.
MODEL_REPO_ID = (
    "AdrianM0/smiles-to-iupac-translator"  # <-- Make sure this is your repo ID
)
# File names inside MODEL_REPO_ID, each fetched with hf_hub_download().
CHECKPOINT_FILENAME = "last.ckpt"  # Lightning checkpoint (loaded via load_from_checkpoint)
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"  # source-side tokenizer
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"  # target-side tokenizer
CONFIG_FILENAME = "config.json"  # model hyperparameters and special-token ids
# --- End Configuration ---

# --- Logging ---
# Configure the root logger once at import time so the INFO-level progress
# messages emitted throughout this module show up in the Space logs.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
# SmilesIupacLitModule is used to load the checkpoint; generate_square_subsequent_mask
# builds the tgt_mask passed to the decoder at every decoding step.
# Any failure is re-raised as gr.Error at module level, which aborts importing
# this module (and therefore app startup).
try:
    # Ensure enhanced_trainer.py is present in the repository root
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")
except ImportError as e:
    logging.error(
        f"Failed to import helper code from enhanced_trainer.py: {e}. "
        f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
    )
    raise gr.Error(
        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
    )
except Exception as e:
    # Import succeeded in finding the module but executing it failed.
    logging.error(
        f"An unexpected error occurred during helper code import: {e}", exc_info=True
    )
    raise gr.Error(
        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
    )

# --- Global Variables (Load Model Once) ---
# Populated by load_model_and_tokenizers(); predict_iupac() returns an
# initialization error while any of them is still unset (falsy).
model: pl.LightningModule | None = None
smiles_tokenizer: Tokenizer | None = None  # encodes the SMILES input
iupac_tokenizer: Tokenizer | None = None  # decodes generated ids to IUPAC text
device: torch.device | None = None
config: dict | None = None  # parsed contents of CONFIG_FILENAME


# --- Greedy Decoding Logic ---
def greedy_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    device: torch.device,
) -> torch.Tensor:
    """
    Decode one source sequence by always picking the highest-scoring next token.

    The source is encoded once. The decoder is then re-run on the whole
    generated prefix at each step, with no incremental cache. This repeats
    for at most ``max_len - 1`` steps and stops as soon as ``eos_idx`` is
    produced.

    Args:
        model: LightningModule whose ``.model`` exposes ``encode``, ``decode``
            and ``generator``.
        src: Source token ids (the caller builds them with a batch dim of 1).
        src_padding_mask: Padding mask for ``src``, forwarded to encode/decode.
        max_len: Upper bound on the generated length, counting the SOS token.
        sos_idx: Token id that seeds decoding.
        eos_idx: Token id that ends decoding.
        device: Device on which the generated ids are created.

    Returns:
        Token ids of shape [1, n] without the leading SOS. A trailing EOS is
        kept if it was generated. If any exception occurs, it is logged and an
        empty [1, 0] tensor on ``device`` is returned. On CUDA out-of-memory
        errors the CUDA cache is also cleared.
    """
    model.eval()
    net = model.model

    try:
        with torch.no_grad():
            encoded = net.encode(src, src_padding_mask).to(device)
            encoded_padding = src_padding_mask.to(encoded.device)

            generated = torch.full((1, 1), sos_idx, dtype=torch.long, device=device)
            for _step in range(max_len - 1):
                prefix_len = generated.shape[1]
                tgt_attn_mask = generate_square_subsequent_mask(prefix_len, device).to(device)
                no_padding = torch.zeros(generated.shape, dtype=torch.bool, device=device)

                hidden = net.decode(
                    tgt=generated,
                    memory=encoded,
                    tgt_mask=tgt_attn_mask,
                    tgt_padding_mask=no_padding,
                    memory_key_padding_mask=encoded_padding,
                )
                logits = net.generator(hidden[:, -1, :])
                chosen = torch.argmax(logits, dim=1).item()

                next_column = torch.full((1, 1), chosen, dtype=torch.long, device=device)
                generated = torch.cat([generated, next_column], dim=1)

                if chosen == eos_idx:
                    break

            return generated[:, 1:]  # drop the SOS column

    except RuntimeError as err:
        logging.error(f"Runtime error during greedy decode: {err}", exc_info=True)
        if "CUDA out of memory" in str(err) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return torch.empty((1, 0), dtype=torch.long, device=device)
    except Exception as err:
        logging.error(f"Unexpected error during greedy decode: {err}", exc_info=True)
        return torch.empty((1, 0), dtype=torch.long, device=device)


# --- Beam Search Decoding Logic ---
def _length_normalized_score(score: float, length: int, alpha: float) -> float:
    """Return ``score / length**alpha``, the length-penalized score used to rank hypotheses.

    ``score`` is a summed log-probability. ``length`` counts every token of the
    hypothesis, including SOS, and EOS when present.
    """
    penalty = length**alpha
    return score / penalty if penalty > 0 else score


def _append_token(seq: torch.Tensor, token_id: int) -> torch.Tensor:
    """Return a new tensor equal to ``seq`` ([1, L]) with ``token_id`` appended -> [1, L + 1]."""
    return torch.cat([seq, seq.new_full((1, 1), token_id)], dim=1)


def beam_search_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,  # Unused; kept for API compatibility (would pad beams if decoding were batched)
    device: torch.device,
    beam_width: int,
    num_return_sequences: int = 1,
    length_penalty_alpha: float = 0.6,
) -> list[tuple[torch.Tensor, float]]:
    """
    Decode one source sequence with beam search.

    At each step every active hypothesis is extended by one token. The
    ``beam_width`` best non-EOS continuations, ranked by summed
    log-probability, stay active. EOS continuations that rank inside the beam
    are moved to a pool of finished hypotheses. Finished hypotheses are ranked
    by ``score / length**length_penalty_alpha``.

    Args:
        model: LightningModule whose ``.model`` exposes ``encode``, ``decode``
            and ``generator``.
        src: Source token ids (the caller builds them with a batch dim of 1).
        src_padding_mask: Padding mask for ``src``, forwarded to encode/decode.
        max_len: Maximum hypothesis length, counting the SOS token.
        sos_idx: Token id that seeds decoding.
        eos_idx: Token id that finishes a hypothesis.
        pad_idx: Unused.
        device: Device on which the generated ids are created.
        beam_width: Number of active hypotheses kept per step (must be >= 1).
        num_return_sequences: How many hypotheses to return. Clamped to
            ``beam_width``; must be >= 1.
        length_penalty_alpha: Exponent of the length penalty; 0 disables it.

    Returns:
        Up to ``num_return_sequences`` tuples ``(token_ids [1, n], score)``,
        best first. ``token_ids`` excludes SOS and ends with EOS when the
        hypothesis finished. Returns ``[]`` for invalid beam parameters or when
        decoding raises; a CUDA OOM additionally clears the CUDA cache.
    """
    import itertools  # local import: only used for heap tie-breaking below

    if beam_width < 1 or num_return_sequences < 1:
        logging.warning(
            f"Invalid beam search parameters: beam_width={beam_width}, "
            f"num_return_sequences={num_return_sequences}"
        )
        return []

    model.eval()
    transformer_model = model.model
    num_return_sequences = min(beam_width, num_return_sequences)  # Cannot return more than beam width

    try:
        with torch.no_grad():
            # --- Encode Source (Once) ---
            memory = transformer_model.encode(src, src_padding_mask).to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)

            # Active hypotheses: (token ids [1, cur_len], summed log-prob).
            # They never end in EOS, because EOS continuations go straight to `finished`.
            beams = [(torch.full((1, 1), sos_idx, dtype=torch.long, device=device), 0.0)]

            # Min-heap of the best `beam_width` finished hypotheses, stored as
            # (normalized score, tiebreak, token ids). The unique tiebreak means
            # heapq never compares two tensors when scores tie. Such a comparison
            # raises "Boolean value of Tensor ... is ambiguous".
            finished: list[tuple[float, int, torch.Tensor]] = []
            tiebreak = itertools.count()

            def add_finished(seq: torch.Tensor, raw_score: float) -> None:
                entry = (
                    _length_normalized_score(raw_score, seq.shape[1], length_penalty_alpha),
                    next(tiebreak),
                    seq,
                )
                if len(finished) < beam_width:
                    heapq.heappush(finished, entry)
                else:
                    heapq.heappushpop(finished, entry)  # keeps the best beam_width entries

            # --- Decoding Loop ---
            for step in range(max_len - 1):
                if not beams:
                    break

                # (total log-prob, token id, index of the beam it extends)
                candidates: list[tuple[float, int, int]] = []
                for beam_idx, (seq, score) in enumerate(beams):
                    tgt_mask = generate_square_subsequent_mask(seq.shape[1], device).to(device)
                    tgt_padding_mask = torch.zeros(seq.shape, dtype=torch.bool, device=device)
                    decoder_output = transformer_model.decode(
                        tgt=seq,
                        memory=memory,
                        tgt_mask=tgt_mask,
                        tgt_padding_mask=tgt_padding_mask,
                        memory_key_padding_mask=memory_key_padding_mask,
                    )
                    next_token_logits = transformer_model.generator(decoder_output[:, -1, :])
                    log_probs = F.log_softmax(next_token_logits, dim=-1)[0]  # [tgt_vocab_size]

                    # At most one continuation of a beam is EOS, so beam_width + 1
                    # candidates per beam always leave beam_width non-EOS options.
                    # Clamp to the vocabulary size so topk cannot fail.
                    k = min(beam_width + 1, log_probs.shape[-1])
                    top_scores, top_ids = torch.topk(log_probs + score, k)
                    candidates.extend(
                        (total, token_id, beam_idx)
                        for total, token_id in zip(top_scores.tolist(), top_ids.tolist())
                    )

                # Rank continuations of all beams together, best first.
                candidates.sort(key=lambda cand: cand[0], reverse=True)

                new_beams: list[tuple[torch.Tensor, float]] = []
                for rank, (total, token_id, beam_idx) in enumerate(candidates):
                    if token_id == eos_idx:
                        # An EOS continuation only counts if it would have made the beam.
                        if rank < beam_width:
                            add_finished(_append_token(beams[beam_idx][0], token_id), total)
                        continue
                    new_beams.append((_append_token(beams[beam_idx][0], token_id), total))
                    if len(new_beams) == beam_width:
                        break
                beams = new_beams

                # Early stopping heuristic. Compare the worst kept finished score
                # with the best active beam, normalized at its *current* length so
                # both sides use the same scale. (With alpha > 0 a longer
                # continuation could in principle still score slightly higher.)
                if beams and len(finished) >= num_return_sequences:
                    cur_len = beams[0][0].shape[1]
                    best_active = max(score for _, score in beams)
                    best_attainable = _length_normalized_score(
                        best_active, cur_len, length_penalty_alpha
                    )
                    if best_attainable < finished[0][0]:
                        logging.debug(f"Beam search early stopping at step {step}")
                        break

            # --- Final Selection ---
            # Hypotheses still active at the end (max_len reached or early stop)
            # compete with the finished ones.
            for seq, score in beams:
                add_finished(seq, score)

            best = heapq.nlargest(num_return_sequences, finished)
            # Return list of (sequence_tensor [1, seq_len], score) excluding SOS
            return [(seq[:, 1:], score) for score, _, seq in best]

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return []  # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search: {e}", exc_info=True)
        return []


# --- Translation Function (Handles both Greedy and Beam Search) ---
def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer: Tokenizer,
    iupac_tokenizer: Tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
    decoding_strategy: str = "Greedy",
    beam_width: int = 5,
    num_return_sequences: int = 1,
    length_penalty_alpha: float = 0.6,
) -> list[tuple[str, float]]:  # Returns list of (translation_string, score)
    """
    Translate a single SMILES string into one or more IUPAC name candidates.

    Args:
        model: LightningModule passed through to the decoding function.
        src_sentence: SMILES string to translate.
        smiles_tokenizer: Source tokenizer. NOTE: truncation is enabled on it
            in place (``enable_truncation(max_length=max_len)``) on every call.
        iupac_tokenizer: Target tokenizer used to turn generated ids into text.
        device: Device on which the source tensor is created.
        max_len: Maximum number of source tokens kept after truncation. It is
            also the generation limit when the module-level ``config`` is not
            loaded.
        sos_idx: Start token id for decoding.
        eos_idx: End token id for decoding.
        pad_idx: Padding id; source positions equal to it are masked.
        decoding_strategy: ``"Greedy"`` or ``"Beam Search"``.
        beam_width: Beam search width; ignored for greedy decoding.
        num_return_sequences: Number of beam search results; ignored for greedy.
        length_penalty_alpha: Beam search length-penalty exponent; ignored for greedy.

    Returns:
        List of ``(text, score)`` pairs. Greedy decoding yields one entry with
        a placeholder score of 0.0. Beam search yields up to
        ``num_return_sequences`` entries, best first. Tokenization, decoding and
        detokenization failures are reported in-band as bracketed error strings
        (e.g. ``"[Encoding Error]"``) rather than raised.
    """
    model.eval()

    # --- Tokenize Source ---
    try:
        smiles_tokenizer.enable_truncation(max_length=max_len)
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or not src_encoded.ids:
            logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
            return [("[Encoding Error]", 0.0)]
        src_ids = src_encoded.ids
    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
        return [("[Encoding Error]", 0.0)]

    # --- Prepare Input Tensor and Mask ---
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, src_len]
    src_padding_mask = src == pad_idx  # [1, src_len], True at padding positions

    # --- Perform Decoding ---
    # The generation limit comes from the loaded config. If the module-level
    # config has not been loaded (e.g. translate() is called outside the app),
    # fall back to the max_len argument rather than failing with AttributeError.
    generation_max_len = config.get("max_len", max_len) if config else max_len
    results: list[tuple[torch.Tensor, float]] = []

    if decoding_strategy == "Greedy":
        tgt_tokens_tensor = greedy_decode(
            model=model,
            src=src,
            src_padding_mask=src_padding_mask,
            max_len=generation_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            device=device,
        )  # Returns tensor [1, generated_len]
        if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
            logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
            return [("[Decoding Error - Empty Output]", 0.0)]
        results = [(tgt_tokens_tensor, 0.0)]  # Greedy has no meaningful score; use 0.0

    elif decoding_strategy == "Beam Search":
        results = beam_search_decode(
            model=model,
            src=src,
            src_padding_mask=src_padding_mask,
            max_len=generation_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            device=device,
            beam_width=beam_width,
            num_return_sequences=num_return_sequences,
            length_penalty_alpha=length_penalty_alpha,
        )  # Returns list of (tensor, score)
        if not results:
            logging.warning(f"Beam search returned no results for SMILES: {src_sentence}")
            return [("[Decoding Error - Empty Output]", 0.0)]

    else:
        logging.error(f"Unknown decoding strategy: {decoding_strategy}")
        return [("[Error: Unknown Strategy]", 0.0)]

    # --- Decode Generated Tokens ---
    translations: list[tuple[str, float]] = []
    for tgt_tokens_tensor, score in results:
        if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
            translations.append(("[Decoding Error - Empty Sequence]", score))
            continue

        tgt_tokens = tgt_tokens_tensor.flatten().tolist()  # plain Python ints
        try:
            # skip_special_tokens drops SOS/EOS/PAD from the decoded text
            translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
            translations.append((translation, score))
        except Exception as e:
            logging.error(
                f"Error decoding target tokens {tgt_tokens}: {e}",
                exc_info=True,
            )
            translations.append(("[Decoding Error]", score))

    return translations


# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
    """
    Load config, tokenizers and model from the Hugging Face Hub into module globals.

    Sets the globals ``device``, ``config``, ``smiles_tokenizer``,
    ``iupac_tokenizer`` and ``model``. The global ``model`` is assigned only
    after it has been moved to the device, put in eval mode and frozen. If
    ``model`` is already set the function returns immediately.

    Raises:
        gr.Error: If downloading, config validation, tokenizer loading or
            checkpoint loading fails. The original exception is chained as
            ``__cause__``.
    """
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    import inspect  # local import: only used to introspect SmilesIupacLitModule.__init__

    if model is not None:
        logging.info("Model and tokenizers already loaded.")
        return

    logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
    try:
        # Determine device (Force CPU for stability in typical Space envs, uncomment cuda if needed)
        # if torch.cuda.is_available():
        #     device = torch.device("cuda")
        #     logging.info("CUDA available, using GPU.")
        # else:
        device = torch.device("cpu")
        logging.info("Using CPU. Modify code to enable GPU if available and desired.")

        # Download files
        logging.info("Downloading files from Hugging Face Hub...")
        cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        logging.info(f"Using cache directory: {cache_dir}")

        try:
            checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir)
            smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir)
            iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir)
            config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir)
            # NOTE(review): enhanced_trainer was already imported at module load time,
            # so this copy only takes effect for later process starts.
            try:
                hf_hub_download(repo_id=MODEL_REPO_ID, filename="enhanced_trainer.py", cache_dir=cache_dir, local_dir=".")  # Download to current dir
                logging.info("Downloaded enhanced_trainer.py")
            except Exception as download_err:
                if not os.path.exists("enhanced_trainer.py"):
                    raise  # Not available remotely or locally
                logging.warning(f"Could not download enhanced_trainer.py (maybe private?), but found local file. Using local. Error: {download_err}")

            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
            raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}") from e

        # Load config
        logging.info("Loading configuration...")
        try:
            # JSON is UTF-8; don't depend on the platform's default encoding.
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            logging.info("Configuration loaded.")
            required_keys = ["src_vocab_size", "tgt_vocab_size", "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers", "num_decoder_layers", "dropout", "max_len", "pad_token_id", "bos_token_id", "eos_token_id"]
            missing_keys = [key for key in required_keys if config.get(key) is None]
            if missing_keys:
                raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}.")
            logging.info(f"Using config: { {k: config.get(k) for k in required_keys} }")  # Log key values
        except Exception as e:
            logging.error(f"Error loading or validating config: {e}", exc_info=True)
            raise gr.Error(f"Config Error: {e}") from e

        # Load tokenizers
        logging.info("Loading tokenizers...")
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
            # Vocab mismatches are only warned about; a real mismatch usually
            # surfaces later as a size-mismatch error when loading the checkpoint.
            if smiles_tokenizer.get_vocab_size() != config['src_vocab_size']:
                logging.warning(f"SMILES vocab size mismatch: Tokenizer={smiles_tokenizer.get_vocab_size()}, Config={config['src_vocab_size']}")
            if iupac_tokenizer.get_vocab_size() != config['tgt_vocab_size']:
                logging.warning(f"IUPAC vocab size mismatch: Tokenizer={iupac_tokenizer.get_vocab_size()}, Config={config['tgt_vocab_size']}")
            logging.info("Tokenizers loaded.")
        except Exception as e:
            logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
            raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}") from e

        # Load model
        logging.info("Loading model from checkpoint...")
        try:
            # If config key names differ from SmilesIupacLitModule.__init__ argument
            # names, remap them here before filtering, e.g.:
            #   model_hparams['src_vocab_size'] = model_hparams.pop('vocab_size_src', config['src_vocab_size'])
            model_hparams = config.copy()  # Start with all config params

            # Forward only config keys that are real named parameters of
            # SmilesIupacLitModule.__init__. `__init__.__code__.co_varnames` is not
            # used because it also lists the constructor's local variables.
            init_params = inspect.signature(SmilesIupacLitModule.__init__).parameters
            expected_args = {
                name
                for name, param in init_params.items()
                if name != "self"
                and param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
            }
            hparams_to_pass = {k: v for k, v in model_hparams.items() if k in expected_args}
            logging.info(f"Passing hparams to LitModule: {hparams_to_pass.keys()}")

            lit_model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
                map_location=device,
                strict=False,  # Tolerate missing/unexpected state-dict keys
                **hparams_to_pass,  # Pass relevant hparams from config
            )
            lit_model.to(device)
            lit_model.eval()
            lit_model.freeze()
            # Publish only a fully prepared model. The early return at the top
            # keys off `model`, so a half-initialized one must never be stored.
            model = lit_model
            logging.info(f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'.")

        except FileNotFoundError as e:
            logging.error(f"Checkpoint file not found: {checkpoint_path}")
            raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.") from e
        except Exception as e:
            logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
            if "size mismatch" in str(e):
                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail} Original error: {e}") from e
            elif "unexpected keyword argument" in str(e) or "missing 1 required positional argument" in str(e):
                error_detail = f"Mismatch between config.json keys and SmilesIupacLitModule constructor arguments. Check enhanced_trainer.py and config.json. Error: {e}"
                logging.error(error_detail)
                raise gr.Error(f"Model Error: {error_detail}") from e
            elif "memory" in str(e).lower():
                logging.warning("Potential OOM error during model loading.")
                gc.collect()
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}") from e
            else:
                raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}") from e

    except gr.Error:
        raise  # Propagate Gradio errors directly
    except Exception as e:
        logging.error(f"Unexpected error during loading: {e}", exc_info=True)
        raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}") from e


# --- Inference Function for Gradio ---
def _to_positive_int(value):
    """Return ``value`` as a positive ``int``, or ``None`` if it is not one.

    Accepts ints and whole-valued floats (e.g. ``5.0``) because Gradio sliders
    and API clients may deliver numeric inputs as floats. Booleans, fractional
    or non-finite floats, non-numeric values and ``None`` are rejected.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, float):
        if not value.is_integer():  # also False for inf / nan
            return None
        value = int(value)
    if isinstance(value, int) and value > 0:
        return value
    return None


def _format_predictions(smiles_input, decoding_strategy, num_beams, length_penalty, predicted_results):
    """Build the multi-line text report shown in the output box.

    Args:
        smiles_input: Canonical SMILES that was translated.
        decoding_strategy: Strategy name as selected in the UI.
        num_beams: Beam width actually used (reported for beam search only).
        length_penalty: Length-penalty alpha passed to ``translate``.
        predicted_results: Sequence of ``(name, score)`` pairs from ``translate``.

    Returns:
        str: Newline-joined report.
    """
    is_beam = decoding_strategy == "Beam Search"
    lines = [
        f"Input SMILES: {smiles_input}",
        f"Decoding Strategy: {decoding_strategy}",
    ]
    if is_beam:
        lines.append(f"Beam Width: {num_beams}")
        lines.append(f"Returned Sequences: {len(predicted_results)}")
        lines.append(f"Length Penalty Alpha: {length_penalty:.2f}")

    lines.append("\n--- Predictions ---")

    if not predicted_results:
        lines.append("No predictions generated.")
        return "\n".join(lines)

    for rank, (name, score) in enumerate(predicted_results, start=1):
        # Check emptiness first: `"[Error]" in None` would raise TypeError.
        if not name or "[Error]" in name:
            lines.append(f"{rank}. Prediction Failed: {name}")
        elif is_beam:
            lines.append(f"{rank}. {name} (Score: {score:.4f})")
        else:
            lines.append(f"{rank}. {name}")
    return "\n".join(lines)


@spaces.GPU  # Requests a GPU per call on Hugging Face Spaces; remove if the Space has no GPU hardware tier.
def predict_iupac(smiles_string, decoding_strategy, num_beams, num_return_sequences):
    """
    Translate a SMILES string to IUPAC name candidates with the loaded model.

    The input is validated and canonicalized with RDKit, then passed to
    ``translate`` using either greedy decoding or beam search.

    Args:
        smiles_string: Raw SMILES text from the UI/API; surrounding whitespace is stripped.
        decoding_strategy: "Greedy" or "Beam Search". Any value other than
            "Beam Search" is decoded with one beam and one returned sequence.
        num_beams: Beam width (Beam Search only). An int or a whole-valued float.
        num_return_sequences: Number of candidates to return (Beam Search only).
            Same accepted types as ``num_beams``; must not exceed it.

    Returns:
        str: A multi-line prediction report, or a message beginning with
        "Error", "Initialization Error", "Runtime Error" or "Unexpected Error".
        Failures are reported in the returned text rather than raised, so they
        appear in the Gradio output box.
    """
    # --- Input Validation ---
    # Compare against None explicitly: truthiness of arbitrary objects (e.g. a
    # tokenizer or container defining __len__) is not a reliable "loaded" signal.
    # An empty config is still treated as "not loaded".
    loaded_components = (model, smiles_tokenizer, iupac_tokenizer, device, config)
    if any(component is None for component in loaded_components) or not config:
        error_msg = "Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        return f"Initialization Error: {error_msg}"

    # API callers may send non-string values; treat them like empty input.
    if not isinstance(smiles_string, str) or not smiles_string.strip():
        return "Error: Please enter a valid SMILES string."

    smiles_input = smiles_string.strip()

    # Validate with RDKit, then replace the input with its canonical form so the
    # model always receives one representation per molecule.
    try:
        mol = MolFromSmiles(smiles_input)
        if mol is None:
            return f"Error: Invalid SMILES string provided: '{smiles_input}'"
        smiles_input = CanonSmiles(smiles_input)
        logging.info(f"Canonical SMILES: {smiles_input}")
    except Exception as e:
        logging.error(f"Error during SMILES validation/canonicalization: {e}", exc_info=True)
        return f"Error: Could not process SMILES string '{smiles_input}'. RDKit error: {e}"

    # Validate beam search parameters (coercing whole-valued floats to int).
    if decoding_strategy == "Beam Search":
        beams = _to_positive_int(num_beams)
        if beams is None:
            return "Error: Beam width must be a positive integer."
        n_return = _to_positive_int(num_return_sequences)
        if n_return is None:
            return "Error: Number of return sequences must be a positive integer."
        if n_return > beams:
            return f"Error: Number of return sequences ({n_return}) cannot exceed beam width ({beams})."
        num_beams, num_return_sequences = beams, n_return
    else:
        # Greedy decoding ignores the slider values.
        num_beams = 1
        num_return_sequences = 1

    try:
        # --- Call the core translation logic ---
        sos_idx = config["bos_token_id"]
        eos_idx = config["eos_token_id"]
        pad_idx = config["pad_token_id"]
        gen_max_len = config["max_len"]
        # Fixed length penalty for now; could be exposed as another slider.
        length_penalty = 0.6

        predicted_results = translate(  # Returns list of (name, score)
            model=model,
            src_sentence=smiles_input,
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=gen_max_len,
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,
            decoding_strategy=decoding_strategy,
            beam_width=num_beams,
            num_return_sequences=num_return_sequences,
            length_penalty_alpha=length_penalty,
        )
        logging.info(f"Prediction returned {len(predicted_results)} result(s). Strategy: {decoding_strategy}, Beams: {num_beams}, Return: {num_return_sequences}")

        return _format_predictions(
            smiles_input, decoding_strategy, num_beams, length_penalty, predicted_results
        )

    except RuntimeError as e:
        logging.error(f"Runtime error during translation: {e}", exc_info=True)
        # Free memory so a failed (possibly OOM) request doesn't poison the next one.
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        return f"Runtime Error during translation: {e}. Check logs."
    except Exception as e:
        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
        return f"Unexpected Error during translation: {e}. Check logs."


# --- Load Model on App Start ---
# Failures are logged instead of raised so the UI below is still built and the
# app can start; predict_iupac() checks for loaded components before use.
try:
    load_model_and_tokenizers()
except gr.Error as load_error:
    # Already-described loader failures (e.g. missing checkpoint, size mismatch):
    # log the message only, without a traceback.
    logging.error(f"Gradio Initialization Error during load: {load_error}")
except Exception as unexpected:
    # Anything else is logged at ERROR level together with its traceback.
    logging.exception(f"Critical error during initial model loading: {unexpected}")


# --- Create Gradio Interface ---
title = "SMILES to IUPAC Name Translator"

# Markdown rendered under the page title. The device label shows where the model
# was placed during start-up loading, or 'N/A' when `device` is unset/falsy.
# Leading and trailing "" entries reproduce the surrounding blank lines.
description = "\n".join(
    [
        "",
        f"Translate a SMILES string into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}).",
        "Choose between **Greedy Decoding** (fastest, picks the most likely next word) and **Beam Search Decoding** (explores multiple possibilities, potentially better results, slower).",
        f"**Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Beam search can be slow, especially with larger beam widths.",
        "Check `config.json` in the repo for model details. SMILES input will be canonicalized using RDKit.",
        "",
    ]
)

# --- Gradio UI ---
# Build the interface as module-level `iface`, using gr.Blocks for layout control.
# Components are placed on screen in the order they are created inside each
# `with` context, so statement order here *is* the layout.
# Every handler below passes inputs in the order (SMILES, strategy, beam width,
# number of results), matching predict_iupac's parameter order.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1): # Input column
            smiles_input = gr.Textbox(
                label="SMILES String",
                placeholder="Enter SMILES string (e.g., CCO, c1ccccc1)",
                lines=2,
            )
            with gr.Accordion("Decoding Options", open=False): # Options collapsible
                 decode_strategy = gr.Radio(
                     ["Greedy", "Beam Search"],
                     label="Decoding Strategy",
                     value="Greedy",
                     info="Greedy is faster, Beam Search may be more accurate."
                 )
                 # Both beam-search sliders start hidden and are revealed by
                 # update_beam_options below. Their values are always sent to
                 # predict_iupac, which overrides them to 1 for greedy decoding.
                 beam_width_slider = gr.Slider(
                     minimum=1,
                     maximum=20, # Keep max reasonable for performance
                     step=1,
                     value=5,
                     label="Beam Width",
                     info="Number of beams to explore (Beam Search only)",
                     visible=False # Initially hidden
                 )
                 # NOTE(review): this slider can be set above the chosen beam
                 # width; predict_iupac rejects num_return_sequences > num_beams
                 # with an error message at request time.
                 num_seq_slider = gr.Slider(
                     minimum=1,
                     maximum=5, # Keep max reasonable
                     step=1,
                     value=1,
                     label="Number of Results",
                     info="How many sequences to return (Beam Search only)",
                     visible=False # Initially hidden
                 )

            submit_btn = gr.Button("Translate", variant="primary")

            # --- Logic to show/hide beam search options ---
            def update_beam_options(strategy):
                """Show both beam-search sliders only when `strategy` is "Beam Search".

                Returns a dict mapping each slider component to a gr.update that
                sets its visibility.
                """
                is_beam = strategy == "Beam Search"
                return {
                    beam_width_slider: gr.update(visible=is_beam),
                    num_seq_slider: gr.update(visible=is_beam)
                }

            decode_strategy.change(
                fn=update_beam_options,
                inputs=decode_strategy,
                outputs=[beam_width_slider, num_seq_slider]
            )


        with gr.Column(scale=2): # Output column
            output_text = gr.Textbox(
                label="Translation Results",
                lines=10, # More lines for potentially multiple results
                show_copy_button=True,
                # NOTE(review): left commented out, so editability falls back to
                # Gradio's default for an output-only component; confirm for the
                # installed Gradio version.
                # interactive=False # Output shouldn't be user-editable
            )

    # --- Define Event Listeners ---
    # The button handler is exposed under the explicit API name
    # "translate_smiles"; the Enter-key handler below calls the same function
    # without an explicit api_name.
    submit_btn.click(
        fn=predict_iupac,
        inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
        outputs=output_text,
        api_name="translate_smiles"
    )

    # Trigger on Enter press in the SMILES box
    smiles_input.submit(
        fn=predict_iupac,
        inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
        outputs=output_text
    )

    # Add examples
    # Each row supplies values for `inputs` in the same order as predict_iupac's
    # parameters. The invalid-SMILES row exercises the error path.
    # NOTE(review): confirm that loading a "Beam Search" example fires
    # decode_strategy.change so the (initially hidden) sliders become visible.
    gr.Examples(
        examples=[
            ["CCO", "Greedy", 1, 1],
            ["c1ccccc1", "Greedy", 1, 1],
            ["CC(C)Br", "Beam Search", 5, 3],
            ["C[C@H](O)c1ccccc1", "Beam Search", 10, 5],
            ["INVALID_SMILES", "Greedy", 1, 1], # Example of invalid input
            ["N#CC(C)(C)OC(=O)C(C)=C", "Beam Search", 8, 2]
        ],
        inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], # Match inputs order
        outputs=output_text, # Output component
        fn=predict_iupac, # Function to run for examples
        cache_examples=False, # Caching might be tricky with model state
        label="Example SMILES & Settings"
    )


# --- Launch the App ---
if __name__ == "__main__":
    # Script entry point: serve the Gradio interface built above.
    iface.launch()