AdrianM0 committed
Commit 9bd549c · verified · 1 Parent(s): f047f18

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +186 -494
app.py CHANGED
@@ -1,27 +1,24 @@
  # app.py
  import gradio as gr
  import torch
- import torch.nn.functional as F # Needed for beam search log_softmax
- import pytorch_lightning as pl # Needed for LightningModule and loading
  import os
  import json
  import logging
  from tokenizers import Tokenizer
  from huggingface_hub import hf_hub_download
- import gc # For garbage collection on potential OOM
- import math # Potentially needed by imported classes

  # --- Configuration ---
- # Ensure these match the files uploaded to your Hugging Face Hub repository
  MODEL_REPO_ID = (
      "AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
  )
- CHECKPOINT_FILENAME = "last.ckpt" # Or "best_model.ckpt" or whatever you uploaded
  SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
  IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
- CONFIG_FILENAME = (
-     "config.json" # Assumes you saved hparams to config.json during/after training
- )
  # --- End Configuration ---

  # --- Logging ---
@@ -31,21 +28,13 @@ logging.basicConfig(

  # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
  try:
-     # We need the LightningModule definition and the mask function
-     # Ensure enhanced_trainer.py is present in the root of your HF Repo
      from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
-
      logging.info("Successfully imported from enhanced_trainer.py.")
-
-     # REMOVED: Redundant import from test_ckpt as functions are defined below
-     # from test_ckpt import beam_search_decode, translate
-
  except ImportError as e:
      logging.error(
          f"Failed to import helper code from enhanced_trainer.py: {e}. "
          f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
      )
-     # Raise error visible in Gradio UI and logs
      raise gr.Error(
          f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
      )
@@ -65,27 +54,21 @@ device: torch.device | None = None
  config: dict | None = None


- # --- Beam Search Decoding Logic (Locally defined) ---
- def beam_search_decode(
      model: pl.LightningModule,
      src: torch.Tensor,
      src_padding_mask: torch.Tensor,
      max_len: int,
      sos_idx: int,
      eos_idx: int,
-     pad_idx: int,
      device: torch.device,
-     beam_width: int = 5,
-     n_best: int = 5,
-     length_penalty: float = 0.6,
- ) -> list[torch.Tensor]:
      """
-     Performs beam search decoding using the LightningModule's model.
-     (Ensures this code is self-contained within app.py or correctly imported)
      """
      model.eval() # Ensure model is in evaluation mode
      transformer_model = model.model # Access the underlying Seq2SeqTransformer
-     n_best = min(n_best, beam_width)

      try:
          with torch.no_grad():
@@ -96,157 +79,66 @@ def beam_search_decode(
96
  memory = memory.to(device)
97
  memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
98
 
99
- # --- Initialize Beams ---
100
- initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
101
- sos_idx
102
- ) # [1, 1]
103
- initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
104
- active_beams = [(initial_beam_seq, initial_beam_score)]
105
- finished_beams = []
106
 
107
  # --- Decoding Loop ---
108
- for step in range(max_len - 1):
109
- if not active_beams:
 
 
110
  break
111
 
112
- potential_next_beams = []
113
- for current_seq, current_score in active_beams:
114
- # Check if the beam already ended
115
- if current_seq[0, -1].item() == eos_idx:
116
- # If already finished, add directly to finished beams and skip expansion
117
- finished_beams.append((current_seq, current_score))
118
- continue
119
-
120
- # Prepare inputs for the decoder
121
- tgt_input = current_seq # [1, current_len]
122
- tgt_seq_len = tgt_input.shape[1]
123
- tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
124
- device
125
- ) # [curr_len, curr_len]
126
- # No padding in target during generation yet
127
- tgt_padding_mask = torch.zeros(
128
- tgt_input.shape, dtype=torch.bool, device=device
129
- ) # [1, curr_len]
130
-
131
- # Decode one step
132
- decoder_output = transformer_model.decode(
133
- tgt=tgt_input,
134
- memory=memory,
135
- tgt_mask=tgt_mask,
136
- tgt_padding_mask=tgt_padding_mask,
137
- memory_key_padding_mask=memory_key_padding_mask,
138
- ) # [1, curr_len, emb_size]
139
-
140
- # Get logits for the *next* token prediction
141
- next_token_logits = transformer_model.generator(
142
- decoder_output[
143
- :, -1, :
144
- ] # Use output corresponding to the last input token
145
- ) # [1, tgt_vocab_size]
146
-
147
- # Calculate log probabilities and add current beam score
148
- log_probs = F.log_softmax(
149
- next_token_logits, dim=-1
150
- ) # [1, tgt_vocab_size]
151
- combined_scores = (
152
- log_probs + current_score
153
- ) # Add score of the current path
154
-
155
- # Find top k candidates for the *next* step
156
- topk_log_probs, topk_indices = torch.topk(
157
- combined_scores, beam_width, dim=-1
158
- ) # [1, beam_width], [1, beam_width]
159
-
160
- # Expand potential beams
161
- for i in range(beam_width):
162
- next_token_id = topk_indices[0, i].item()
163
- # Score is the cumulative log probability of the new sequence
164
- next_score = topk_log_probs[0, i].reshape(
165
- 1
166
- ) # Keep as tensor [1]
167
- next_token_tensor = torch.tensor(
168
- [[next_token_id]], dtype=torch.long, device=device
169
- ) # [1, 1]
170
- new_seq = torch.cat(
171
- [current_seq, next_token_tensor], dim=1
172
- ) # [1, current_len + 1]
173
- potential_next_beams.append((new_seq, next_score))
174
-
175
- # --- Prune Beams ---
176
- # Sort all potential next beams by score
177
- potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)
178
-
179
- # Select the top `beam_width` beams for the next iteration
180
- active_beams = []
181
- temp_finished_beams = [] # Collect beams finished in *this* step
182
- for seq, score in potential_next_beams:
183
- if (
184
- len(active_beams) >= beam_width
185
- and len(temp_finished_beams) >= beam_width
186
- ):
187
- break # Optimization: Stop if we have enough active and finished candidates
188
-
189
- is_finished = seq[0, -1].item() == eos_idx
190
- if is_finished:
191
- # Add to temporary finished list for this step
192
- if len(temp_finished_beams) < beam_width:
193
- temp_finished_beams.append((seq, score))
194
- elif len(active_beams) < beam_width:
195
- # Add to active beams for next step
196
- active_beams.append((seq, score))
197
-
198
- # Add the newly finished beams to the main finished list
199
- finished_beams.extend(temp_finished_beams)
200
- # Optional: Prune finished_beams if it grows too large (e.g., keep top 2*beam_width)
201
- finished_beams.sort(key=lambda x: x[1].item(), reverse=True)
202
- finished_beams = finished_beams[
203
- : beam_width * 2
204
- ] # Keep a reasonable number
205
-
206
- # --- Final Selection ---
207
- # Add any remaining active beams (which didn't finish) to the finished list
208
- finished_beams.extend(active_beams)
209
-
210
- # Apply length penalty and sort
211
- def get_score_with_penalty(beam_tuple):
212
- seq, score = beam_tuple
213
- seq_len = seq.shape[1]
214
- # Avoid division by zero or negative exponent issues
215
- if length_penalty <= 0.0 or seq_len <= 1:
216
- return score.item()
217
- else:
218
- # Length penalty calculation
219
- penalty = (
220
- (5.0 + float(seq_len)) / 6.0
221
- ) ** length_penalty # Common formula
222
- return score.item() / penalty
223
- # Alternative simpler penalty:
224
- # return score.item() / (float(seq_len) ** length_penalty)
225
-
226
- finished_beams.sort(
227
- key=get_score_with_penalty, reverse=True
228
- ) # Higher score is better
229
-
230
- # Return the top n_best sequences (excluding the initial SOS token)
231
- top_sequences = [
232
- seq[:, 1:]
233
- for seq, score in finished_beams[:n_best]
234
- if seq.shape[1] > 1 # Ensure seq not just SOS
235
- ] # seq shape [1, len] -> [1, len-1]
236
- return top_sequences
237
 
238
  except RuntimeError as e:
239
- logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
240
  if "CUDA out of memory" in str(e) and device.type == "cuda":
241
  gc.collect()
242
  torch.cuda.empty_cache()
243
- return [] # Return empty list on error
244
  except Exception as e:
245
- logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
246
- return []
247
 
248
 
249
- # --- Translation Function (Locally defined) ---
250
  def translate(
251
  model: pl.LightningModule,
252
  src_sentence: str,
@@ -257,95 +149,71 @@ def translate(
257
  sos_idx: int,
258
  eos_idx: int,
259
  pad_idx: int,
260
- beam_width: int = 5,
261
- n_best: int = 5,
262
- length_penalty: float = 0.6,
263
- ) -> list[str]:
264
  """
265
- Translates a single SMILES string using beam search.
266
- (Ensures this code is self-contained within app.py or correctly imported)
267
  """
268
  model.eval() # Ensure model is in eval mode
269
- translations = []
270
- n_best = min(n_best, beam_width) # Can't return more than beam width
271
 
272
  # --- Tokenize Source ---
273
  try:
274
  # Ensure tokenizer has truncation/padding configured if needed, or handle manually
275
- smiles_tokenizer.enable_truncation(max_length=max_len)
276
  src_encoded = smiles_tokenizer.encode(src_sentence)
277
  if not src_encoded or not src_encoded.ids:
278
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
279
- return ["[Encoding Error]"] * n_best
280
  # Use the truncated IDs directly
281
  src_ids = src_encoded.ids
282
- # Note: max_len here applies to source *tokenizer*, generation length is separate
283
  except Exception as e:
284
  logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
285
- return ["[Encoding Error]"] * n_best
286
 
287
  # --- Prepare Input Tensor and Mask ---
288
  src = (
289
  torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
290
  ) # [1, src_len]
291
- # Create padding mask (True where it's a pad token, should be all False here)
292
  src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
293
 
294
- # --- Perform Beam Search Decoding ---
295
- # Calls the beam_search_decode function defined *above in this file*
296
  # Note: max_len for generation should come from config if it dictates output length
297
  generation_max_len = config.get(
298
  "max_len", 256
299
  ) # Use config max_len for output limit
300
- tgt_tokens_list = beam_search_decode(
301
  model=model,
302
  src=src,
303
  src_padding_mask=src_padding_mask,
304
  max_len=generation_max_len, # Use generation limit
305
  sos_idx=sos_idx,
306
  eos_idx=eos_idx,
307
- pad_idx=pad_idx,
308
  device=device,
309
- beam_width=beam_width,
310
- n_best=n_best,
311
- length_penalty=length_penalty,
312
- ) # Returns list of tensors
313
 
314
  # --- Decode Generated Tokens ---
315
- if not tgt_tokens_list:
316
- logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
317
- # Provide n_best error messages
318
- return ["[Decoding Error - Empty Output]"] * n_best
319
-
320
- for i, tgt_tokens_tensor in enumerate(tgt_tokens_list):
321
- if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0:
322
- tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
323
- try:
324
- # Decode using the target tokenizer, skipping special tokens
325
- translation = iupac_tokenizer.decode(
326
- tgt_tokens, skip_special_tokens=True
327
- )
328
- translations.append(translation)
329
- except Exception as e:
330
- logging.error(
331
- f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}",
332
- exc_info=True,
333
- )
334
- translations.append("[Decoding Error]")
335
- else:
336
- logging.warning(
337
- f"Beam {i} result was empty or None for SMILES: {src_sentence}"
338
- )
339
- translations.append("[Decoding Error - Empty Tensor]")
340
-
341
- # Pad with error messages if fewer than n_best results were generated
342
- while len(translations) < n_best:
343
- translations.append("[Decoding Error - Fewer Results]")
344
 
345
- return translations
 
 
 
346
 
347
 
348
- # --- Model/Tokenizer Loading Function ---
349
  def load_model_and_tokenizers():
350
  """Loads tokenizers, config, and model from Hugging Face Hub."""
351
  global model, smiles_tokenizer, iupac_tokenizer, device, config
@@ -355,27 +223,22 @@ def load_model_and_tokenizers():

      logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
      try:
-         # Determine device - Use CPU for Gradio Spaces unless GPU is explicitly available and desired
-         # For simplicity and broader compatibility on free tier Spaces, CPU is safer.
          if torch.cuda.is_available():
              logging.warning(
                  "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
              )
              device = torch.device("cpu")
-             # Uncomment below and comment above line to try using GPU if available
              # device = torch.device("cuda")
              # logging.info("CUDA available, using GPU.")
          else:
              device = torch.device("cpu")
              logging.info("CUDA not available, using CPU.")

-         # Download files from HF Hub
          logging.info("Downloading files from Hugging Face Hub...")
          try:
-             # Use cache directory for Spaces persistence if possible
-             cache_dir = os.environ.get(
-                 "GRADIO_CACHE", "./hf_cache"
-             ) # Gradio sets cache dir
              os.makedirs(cache_dir, exist_ok=True)
              logging.info(f"Using cache directory: {cache_dir}")

@@ -398,7 +261,7 @@ def load_model_and_tokenizers():
              logging.info("Files downloaded successfully.")
          except Exception as e:
              logging.error(
-                 f"Failed to download files from {MODEL_REPO_ID}. Check filenames ({CHECKPOINT_FILENAME}, {SMILES_TOKENIZER_FILENAME}, etc.) and repo status. Error: {e}",
                  exc_info=True,
              )
              raise gr.Error(
@@ -412,100 +275,46 @@ def load_model_and_tokenizers():
412
  config = json.load(f)
413
  logging.info("Configuration loaded.")
414
  # --- Validate essential config keys ---
415
- # Use hparams logged during training if available, map them carefully
416
- # These keys are based on SmilesIupacLitModule and Seq2SeqTransformer init args
417
- # Mappings might be needed if keys in config.json differ from these exact names
418
  required_keys = [
419
- # Need vocab sizes used during *training* for loading
420
- "actual_src_vocab_size", # Assuming this was saved in hparams
421
- "actual_tgt_vocab_size", # Assuming this was saved in hparams
422
- # Model architecture params
423
- "emb_size",
424
- "nhead",
425
- "ffn_hid_dim",
426
- "num_encoder_layers",
427
- "num_decoder_layers",
428
- "dropout",
429
- "max_len", # Needed for generation limit and tokenizer setting
430
- # Special token IDs needed for generation
431
- # Assuming standard names, adjust if your config uses different keys
432
- "pad_token_id", # Often 0
433
- "bos_token_id", # Often 1 (used as SOS)
434
- "eos_token_id", # Often 2
435
  ]
436
- # Remap keys if necessary (e.g., if config.json uses 'src_vocab_size' instead of 'actual_src_vocab_size')
437
  config_key_mapping = {
438
- "actual_src_vocab_size": config.get(
439
- "actual_src_vocab_size", config.get("src_vocab_size")
440
- ),
441
- "actual_tgt_vocab_size": config.get(
442
- "actual_tgt_vocab_size", config.get("tgt_vocab_size")
443
- ),
444
- "emb_size": config.get("emb_size"),
445
- "nhead": config.get("nhead"),
446
- "ffn_hid_dim": config.get("ffn_hid_dim"),
447
- "num_encoder_layers": config.get("num_encoder_layers"),
448
- "num_decoder_layers": config.get("num_decoder_layers"),
449
- "dropout": config.get("dropout"),
450
- "max_len": config.get("max_len"),
451
- "pad_token_id": config.get(
452
- "pad_token_id"
453
- ), # Use default if missing? Risky.
454
- "bos_token_id": config.get(
455
- "bos_token_id"
456
- ), # Use default if missing? Risky.
457
- "eos_token_id": config.get(
458
- "eos_token_id"
459
- ), # Use default if missing? Risky.
460
  }
461
- # Update config with potentially remapped values
462
  config.update(config_key_mapping)
463
 
464
  missing_keys = [key for key in required_keys if config.get(key) is None]
465
  if missing_keys:
466
- # Try to load defaults for token IDs if absolutely necessary, but warn heavily
467
- defaults_used = []
468
- # Re-check missing keys after attempting defaults
469
- missing_keys = [key for key in required_keys if config.get(key) is None]
470
- if missing_keys:
471
- raise ValueError(
472
- f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
473
- f"Ensure these were saved in the hyperparameters during training."
474
- )
475
- else:
476
- logging.warning(
477
- f"Config file was missing keys, used defaults for: {defaults_used}. This might be incorrect!"
478
- )
479
-
480
- # Log the final config values being used
481
  logging.info(
482
- f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
483
  f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
484
  f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
485
  )
486
 
487
  except FileNotFoundError:
488
- logging.error(
489
- f"Config file not found locally after download attempt: {config_path}"
490
- )
491
- raise gr.Error(
492
- f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
493
- )
494
  except json.JSONDecodeError as e:
495
  logging.error(f"Error decoding JSON from config file {config_path}: {e}")
496
- raise gr.Error(
497
- f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
498
- )
499
- except ValueError as e: # Catch our custom validation error
500
  logging.error(f"Config validation error: {e}")
501
  raise gr.Error(f"Config Error: {e}")
502
- except Exception as e: # Catch other potential errors during config processing
503
- logging.error(
504
- f"Unexpected error loading or validating config: {e}", exc_info=True
505
- )
506
- raise gr.Error(
507
- f"Config Error: Unexpected error processing config. Check logs. Error: {e}"
508
- )
509
 
510
  # Load tokenizers
511
  logging.info("Loading tokenizers...")
@@ -513,309 +322,192 @@ def load_model_and_tokenizers():
513
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
514
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
515
  logging.info("Tokenizers loaded.")
516
-
517
- # --- Validate Tokenizer Special Tokens Against Config ---
518
  pad_token = "<pad>"
519
  sos_token = "<sos>"
520
  eos_token = "<eos>"
521
  unk_token = "<unk>"
522
-
523
  issues = []
524
- if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
525
- issues.append(
526
- f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
527
- )
528
- if smiles_tokenizer.token_to_id(unk_token) is None:
529
- issues.append("SMILES UNK token not found")
530
-
531
- if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
532
- issues.append(
533
- f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
534
- )
535
- if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
536
- issues.append(
537
- f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})"
538
- )
539
- if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
540
- issues.append(
541
- f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})"
542
- )
543
- if iupac_tokenizer.token_to_id(unk_token) is None:
544
- issues.append("IUPAC UNK token not found")
545
-
546
- if issues:
547
- logging.warning(
548
- "Tokenizer validation issues detected: " + "; ".join(issues)
549
- )
550
- # Decide if this is fatal or just a warning
551
- # raise gr.Error("Tokenizer Error: Special token IDs mismatch config. Check tokenizers and config.json.") # Make it fatal if IDs must match
552
 
553
  except Exception as e:
554
- logging.error(
555
- f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
556
- exc_info=True,
557
- )
558
- raise gr.Error(
559
- f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}"
560
- )
561
 
562
  # Load model
563
  logging.info("Loading model from checkpoint...")
564
  try:
565
- # Load the LightningModule state dict
566
- # Use the actual vocab sizes and hparams from the loaded config
567
  model = SmilesIupacLitModule.load_from_checkpoint(
568
  checkpoint_path,
 
569
  src_vocab_size=config["src_vocab_size"],
570
  tgt_vocab_size=config["tgt_vocab_size"],
571
- hparams_dict=config, # Pass the loaded config as hparams
572
- map_location=device, # Map model to the chosen device (CPU or CUDA)
573
- strict=True, # Be less strict about matching keys, useful for PTL versions or minor changes
 
 
574
  )
575
 
576
- # Ensure model is on the correct device, in eval mode, and frozen
577
  model.to(device)
578
  model.eval()
579
- model.freeze() # Disables gradient calculations
580
  logging.info(
581
  f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
582
  )
583
 
584
  except FileNotFoundError:
585
- logging.error(
586
- f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
587
- )
588
- raise gr.Error(
589
- f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
590
- )
591
  except Exception as e:
592
- logging.error(
593
- f"Error loading model from checkpoint {checkpoint_path}: {e}",
594
- exc_info=True,
595
- )
596
- # Check for common errors
597
  if "size mismatch" in str(e):
598
- error_detail = (
599
- f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
600
- f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers."
601
- )
602
  logging.error(error_detail)
603
  raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
604
  elif "memory" in str(e).lower():
605
- logging.warning("Potential Out-of-Memory error during model loading.")
606
- gc.collect()
607
- if device.type == "cuda":
608
- torch.cuda.empty_cache()
609
- raise gr.Error(
610
- f"Model Error: Out of memory loading model. Check Space resources. Error: {e}"
611
- )
612
  else:
613
- raise gr.Error(
614
- f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
615
- )
616
 
617
- except gr.Error: # Re-raise Gradio errors to be displayed
618
- raise
619
- except Exception as e: # Catch any other unexpected errors
620
- logging.error(
621
- f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
622
- )
623
- raise gr.Error(
624
- f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
625
- )
626
 
627
 
628
- # --- Inference Function for Gradio ---
629
- def predict_iupac(smiles_string, beam_width_str, n_best_str):
630
  """
631
- Performs SMILES to IUPAC translation using the loaded model and beam search.
632
- Takes string inputs from Gradio sliders/inputs and converts them.
633
  """
634
  global model, smiles_tokenizer, iupac_tokenizer, device, config
635
 
636
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
637
  error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
638
  logging.error(error_msg)
639
- # Try to determine n_best for error output formatting
640
- try:
641
- n_best_int = int(n_best_str)
642
- except:
643
- n_best_int = 1
644
- return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
645
 
646
  if not smiles_string or not smiles_string.strip():
647
  error_msg = "Error: Please enter a valid SMILES string."
648
- try:
649
- n_best_int = int(n_best_str)
650
- except:
651
- n_best_int = 1
652
- return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
653
 
654
  smiles_input = smiles_string.strip()
655
 
656
- # --- Safely parse numerical inputs ---
657
  try:
658
- beam_width = int(beam_width_str)
659
- n_best = int(n_best_str)
660
- if beam_width < 1 or n_best < 1 or n_best > beam_width:
661
- raise ValueError(
662
- "Beam width and n_best must be >= 1, and n_best <= beam width."
663
- )
664
- except ValueError as e:
665
- error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
666
- logging.error(error_msg)
667
- # Cannot determine n_best if its input was invalid, default to 1 error line
668
- return f"1. {error_msg}"
669
-
670
- try:
671
- # --- Call the core translation logic ---
672
- # Retrieve necessary IDs from the loaded config
673
  sos_idx = config["bos_token_id"]
674
  eos_idx = config["eos_token_id"]
675
  pad_idx = config["pad_token_id"]
676
- gen_max_len = config["max_len"] # Max length for generation
677
 
678
- predicted_names = translate(
679
  model=model,
680
  src_sentence=smiles_input,
681
  smiles_tokenizer=smiles_tokenizer,
682
  iupac_tokenizer=iupac_tokenizer,
683
  device=device,
684
- max_len=gen_max_len, # Pass generation length limit
685
  sos_idx=sos_idx,
686
  eos_idx=eos_idx,
687
  pad_idx=pad_idx,
688
- beam_width=beam_width,
689
- n_best=n_best,
690
- length_penalty=0.6,
691
  )
692
- logging.info(f"Predictions returned: {predicted_names}")
693
 
694
  # --- Format Output ---
695
- if not predicted_names:
696
- output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated (beam search might have failed)."
 
 
697
  else:
698
- # Ensure we only display up to n_best results, even if translate returned more/fewer due to errors
699
- display_names = predicted_names[:n_best]
700
  output_text = (
701
  f"Input SMILES: {smiles_input}\n\n"
702
- f"Top {len(display_names)} Predictions (Beam Width={beam_width}, n_best={n_best}):\n"
703
- )
704
- output_text += "\n".join(
705
- [f"{i + 1}. {name}" for i, name in enumerate(display_names)]
706
  )
707
- # Add a note if fewer results than requested were generated
708
- if len(display_names) < n_best:
709
- output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."
710
-
711
  return output_text
712
 
713
  except RuntimeError as e:
714
  logging.error(f"Runtime error during translation: {e}", exc_info=True)
715
  error_msg = f"Runtime Error during translation: {e}"
716
  if "memory" in str(e).lower():
717
- gc.collect()
718
- if device.type == "cuda":
719
- torch.cuda.empty_cache()
720
- error_msg += " (Potential OOM - try reducing beam width or input length)"
721
- # Return n_best error messages
722
- return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
723
 
724
  except Exception as e:
725
  logging.error(f"Unexpected error during translation: {e}", exc_info=True)
726
  error_msg = f"Unexpected Error during translation: {e}"
727
- return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
728
 
729
 
730
  # --- Load Model on App Start ---
731
- # Wrap in try/except to prevent app from crashing completely if loading fails
732
- # The error should be caught and displayed by Gradio via gr.Error raised in the function.
733
  try:
734
  load_model_and_tokenizers()
735
  except gr.Error as ge:
736
  logging.error(f"Gradio Initialization Error: {ge}")
737
- # Gradio handles displaying gr.Error, but we log it too.
738
- # We might want to display a placeholder UI or message if loading fails critically.
739
- pass # Allow Gradio to potentially start with an error message
740
  except Exception as e:
741
- # Catch any non-Gradio errors during the initial load sequence
742
- logging.error(
743
- f"Critical error during initial model loading sequence: {e}", exc_info=True
744
- )
745
- # Optionally raise gr.Error here too, although it might be too late if Gradio hasn't fully initialized.
746
- # raise gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
747
 
748
 
749
- # --- Create Gradio Interface ---
750
- title = "SMILES to IUPAC Name Translator"
751
  description = f"""
752
  Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
753
- Translation uses beam search decoding. Adjust parameters below.
754
- **Note:** Model loaded on **{str(device).upper()}**. Performance may vary. Check `config.json` in the repo for model details.
755
  """
756
 
757
- # Define examples using the input types expected by the interface
758
  examples = [
759
- ["CCO", 5, 3, 0.6], # Ethanol
760
- ["C1=CC=CC=C1", 5, 3, 0.6], # Benzene
761
- ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
762
- ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
763
- # Very complex example - might take time or fail on CPU/low memory
764
- # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1, 0.7], # Gleevec (Imatinib) - simplified SMILES structure
765
- ["INVALID_SMILES", 3, 1, 0.6], # Example of invalid input
766
  ]
767
 
768
- # Ensure input components match the `predict_iupac` function signature order and types
769
  smiles_input = gr.Textbox(
770
  label="SMILES String",
771
  placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
772
  lines=1,
773
  )
774
- # Use number inputs for sliders if direct type casting is desired, but sliders often return float/int anyway
775
- beam_width_input = gr.Slider(
776
- minimum=1,
777
- maximum=10,
778
- value=5,
779
- step=1,
780
- label="Beam Width (k)",
781
- info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage.",
782
- )
783
- n_best_input = gr.Slider(
784
- minimum=1,
785
- maximum=10,
786
- value=3,
787
- step=1,
788
- label="Number of Results (n_best)",
789
- info="How many top sequences to return (must be <= Beam Width).",
790
- )
791
 
 
792
  output_text = gr.Textbox(
793
- label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
794
  )
795
 
796
  # Create the interface instance
797
  iface = gr.Interface(
798
  fn=predict_iupac, # The function to call
799
- inputs=[ # List of input components
800
- smiles_input,
801
- beam_width_input,
802
- n_best_input,
803
- ],
804
  outputs=output_text, # Output component
805
  title=title,
806
  description=description,
807
- examples=examples, # Examples to populate the interface
808
- allow_flagging="never", # Disable flagging
809
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"), # Optional theme
810
  article="""
811
  **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
812
- Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.
813
- Beam search parameters (width, penalty) significantly impact results and performance.
814
  """,
815
- # Optional: Add live=True for real-time updates as sliders change (can be slow/resource intensive)
816
- # live=False,
817
  )
818
 
819
  # --- Launch the App ---
820
  if __name__ == "__main__":
821
- iface.launch()
 
app.py (updated file, additions marked with +):

@@ -1,27 +1,24 @@
  # app.py
  import gradio as gr
  import torch
+ # import torch.nn.functional as F # No longer needed for greedy decode directly
+ import pytorch_lightning as pl
  import os
  import json
  import logging
  from tokenizers import Tokenizer
  from huggingface_hub import hf_hub_download
+ import gc
+ import math # Potentially needed by imported classes

  # --- Configuration ---
  MODEL_REPO_ID = (
      "AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
  )
+ CHECKPOINT_FILENAME = "last.ckpt"
  SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
  IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
+ CONFIG_FILENAME = "config.json"
  # --- End Configuration ---

  # --- Logging ---
@@ -31,21 +28,13 @@ logging.basicConfig(

  # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
  try:
      from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
      logging.info("Successfully imported from enhanced_trainer.py.")
  except ImportError as e:
      logging.error(
          f"Failed to import helper code from enhanced_trainer.py: {e}. "
          f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
      )
      raise gr.Error(
          f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
      )
@@ -65,27 +54,21 @@ device: torch.device | None = None
  config: dict | None = None


+ # --- Greedy Decoding Logic (Locally defined) ---
+ def greedy_decode(
      model: pl.LightningModule,
      src: torch.Tensor,
      src_padding_mask: torch.Tensor,
      max_len: int,
      sos_idx: int,
      eos_idx: int,
      device: torch.device,
+ ) -> torch.Tensor:
      """
+     Performs greedy decoding using the LightningModule's model.
      """
      model.eval() # Ensure model is in evaluation mode
      transformer_model = model.model # Access the underlying Seq2SeqTransformer

      try:
          with torch.no_grad():
@@ -96,157 +79,66 @@ def beam_search_decode(
              memory = memory.to(device)
              memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]

+             # --- Initialize Target Sequence ---
+             # Start with the SOS token
+             ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) # [1, 1]

              # --- Decoding Loop ---
+             for _ in range(max_len - 1): # Max length limit
+                 tgt_seq_len = ys.shape[1]
+                 tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
+                     device
+                 ) # [curr_len, curr_len]
+                 # No padding in target during generation yet
+                 tgt_padding_mask = torch.zeros(
+                     ys.shape, dtype=torch.bool, device=device
+                 ) # [1, curr_len]
+
+                 # Decode one step
+                 decoder_output = transformer_model.decode(
+                     tgt=ys,
+                     memory=memory,
+                     tgt_mask=tgt_mask,
+                     tgt_padding_mask=tgt_padding_mask,
+                     memory_key_padding_mask=memory_key_padding_mask,
+                 ) # [1, curr_len, emb_size]
+
+                 # Get logits for the *next* token prediction
+                 next_token_logits = transformer_model.generator(
+                     decoder_output[:, -1, :] # Use output corresponding to the last input token
+                 ) # [1, tgt_vocab_size]
+
+                 # Find the most likely next token (greedy choice)
+                 # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
+                 # _, next_word_id_tensor = torch.max(prob, dim=1)
+                 next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # [1]
+                 next_word_id = next_word_id_tensor.item()
+
+                 # Append the chosen token to the sequence
+                 ys = torch.cat(
+                     [ys, torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)],
+                     dim=1
+                 ) # [1, current_len + 1]
+
+                 # Stop if EOS token is generated
+                 if next_word_id == eos_idx:
                      break

+             # Return the generated sequence (excluding the initial SOS token)
+             return ys[:, 1:] # Shape [1, generated_len]

      except RuntimeError as e:
+         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
          if "CUDA out of memory" in str(e) and device.type == "cuda":
              gc.collect()
              torch.cuda.empty_cache()
+         return torch.empty((1, 0), dtype=torch.long, device=device) # Return empty tensor on error
      except Exception as e:
+         logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
+         return torch.empty((1, 0), dtype=torch.long, device=device)


+ # --- Translation Function (Using Greedy Decode) ---
  def translate(
      model: pl.LightningModule,
      src_sentence: str,
@@ -257,95 +149,71 @@ def translate(
      sos_idx: int,
      eos_idx: int,
      pad_idx: int,
+ ) -> str: # Returns a single string
      """
+     Translates a single SMILES string using greedy decoding.
      """
      model.eval() # Ensure model is in eval mode

      # --- Tokenize Source ---
      try:
          # Ensure tokenizer has truncation/padding configured if needed, or handle manually
+         smiles_tokenizer.enable_truncation(max_length=max_len) # Use max_len for source truncation too
          src_encoded = smiles_tokenizer.encode(src_sentence)
          if not src_encoded or not src_encoded.ids:
              logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
+             return "[Encoding Error]"
          # Use the truncated IDs directly
          src_ids = src_encoded.ids
      except Exception as e:
          logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
+         return "[Encoding Error]"

      # --- Prepare Input Tensor and Mask ---
      src = (
          torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
      ) # [1, src_len]
+     # Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
      src_padding_mask = (src == pad_idx).to(device) # [1, src_len]

+     # --- Perform Greedy Decoding ---
+     # Calls the greedy_decode function defined *above in this file*
      # Note: max_len for generation should come from config if it dictates output length
      generation_max_len = config.get(
          "max_len", 256
      ) # Use config max_len for output limit
+     tgt_tokens_tensor = greedy_decode(
          model=model,
          src=src,
          src_padding_mask=src_padding_mask,
          max_len=generation_max_len, # Use generation limit
          sos_idx=sos_idx,
          eos_idx=eos_idx,
+         # pad_idx=pad_idx, # Not needed by greedy_decode internal loop
          device=device,
+     ) # Returns a single tensor [1, generated_len]

      # --- Decode Generated Tokens ---
+     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
+         logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
+         return "[Decoding Error - Empty Output]"

+     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
+     try:
+         # Decode using the target tokenizer, skipping special tokens
+         translation = iupac_tokenizer.decode(
+             tgt_tokens, skip_special_tokens=True
+         )
+         return translation
+     except Exception as e:
+         logging.error(
+             f"Error decoding target tokens {tgt_tokens}: {e}",
+             exc_info=True,
+         )
+         return "[Decoding Error]"


+ # --- Model/Tokenizer Loading Function (Unchanged) ---
  def load_model_and_tokenizers():
      """Loads tokenizers, config, and model from Hugging Face Hub."""
      global model, smiles_tokenizer, iupac_tokenizer, device, config
@@ -355,27 +223,22 @@ def load_model_and_tokenizers():

      logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
      try:
+         # Determine device
          if torch.cuda.is_available():
              logging.warning(
                  "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
              )
              device = torch.device("cpu")
              # device = torch.device("cuda")
              # logging.info("CUDA available, using GPU.")
          else:
              device = torch.device("cpu")
              logging.info("CUDA not available, using CPU.")

+         # Download files
          logging.info("Downloading files from Hugging Face Hub...")
          try:
+             cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
              os.makedirs(cache_dir, exist_ok=True)
              logging.info(f"Using cache directory: {cache_dir}")

@@ -398,7 +261,7 @@ def load_model_and_tokenizers():
              logging.info("Files downloaded successfully.")
          except Exception as e:
              logging.error(
+                 f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
                  exc_info=True,
              )
              raise gr.Error(
@@ -412,100 +275,46 @@ def load_model_and_tokenizers():
                  config = json.load(f)
              logging.info("Configuration loaded.")
              # --- Validate essential config keys ---
              required_keys = [
+                 "src_vocab_size", # Use the key saved in config
+                 "tgt_vocab_size", # Use the key saved in config
+                 "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers",
+                 "num_decoder_layers", "dropout", "max_len",
+                 "pad_token_id", "bos_token_id", "eos_token_id",
              ]
+             # Remap if needed (example shown, adjust if your keys differ)
              config_key_mapping = {
+                 "src_vocab_size": config.get("src_vocab_size", config.get("actual_src_vocab_size")),
+                 "tgt_vocab_size": config.get("tgt_vocab_size", config.get("actual_tgt_vocab_size")),
+                 # Add other mappings if necessary
              }
              config.update(config_key_mapping)

              missing_keys = [key for key in required_keys if config.get(key) is None]
              if missing_keys:
+                 raise ValueError(
+                     f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
+                     f"Ensure these were saved in the hyperparameters during training."
+                 )
+
              logging.info(
+                 f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
                  f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
                  f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
              )

          except FileNotFoundError:
+             logging.error(f"Config file not found: {config_path}")
+             raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
          except json.JSONDecodeError as e:
              logging.error(f"Error decoding JSON from config file {config_path}: {e}")
+             raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}")
+         except ValueError as e:
              logging.error(f"Config validation error: {e}")
              raise gr.Error(f"Config Error: {e}")
+         except Exception as e:
+             logging.error(f"Unexpected error loading config: {e}", exc_info=True)
+             raise gr.Error(f"Config Error: Unexpected error. Check logs. Error: {e}")

          # Load tokenizers
          logging.info("Loading tokenizers...")
@@ -513,309 +322,192 @@ def load_model_and_tokenizers():
              smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
              iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
              logging.info("Tokenizers loaded.")
+             # --- Optional: Validate Tokenizer Special Tokens Against Config ---
+             # (Keep validation as before, it's still useful)
              pad_token = "<pad>"
              sos_token = "<sos>"
              eos_token = "<eos>"
              unk_token = "<unk>"
              issues = []
+             # ... (keep the validation checks as in the original code) ...
+             if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]: issues.append(f"SMILES PAD ID mismatch")
+             if smiles_tokenizer.token_to_id(unk_token) is None: issues.append("SMILES UNK token not found")
+             if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]: issues.append(f"IUPAC PAD ID mismatch")
+             if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]: issues.append(f"IUPAC SOS ID mismatch")
+             if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]: issues.append(f"IUPAC EOS ID mismatch")
+             if iupac_tokenizer.token_to_id(unk_token) is None: issues.append("IUPAC UNK token not found")
+             if issues: logging.warning("Tokenizer validation issues: " + "; ".join(issues))

          except Exception as e:
+             logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
+             raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}")

          # Load model
          logging.info("Loading model from checkpoint...")
          try:
+             # Use the vocab sizes and hparams from the loaded config
              model = SmilesIupacLitModule.load_from_checkpoint(
                  checkpoint_path,
+                 # Ensure these match the keys used when saving hparams
                  src_vocab_size=config["src_vocab_size"],
                  tgt_vocab_size=config["tgt_vocab_size"],
+                 # Pass the whole config dict, load_from_checkpoint will pick what it needs
+                 # if the keys match the __init__ args or are in hparams
+                 **config, # Pass loaded config directly as keyword args
+                 map_location=device,
+                 strict=True, # Start strict, set to False if encountering key errors
              )

              model.to(device)
              model.eval()
+             model.freeze()
              logging.info(
                  f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
              )

          except FileNotFoundError:
+             logging.error(f"Checkpoint file not found: {checkpoint_path}")
+             raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
          except Exception as e:
+             logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
              if "size mismatch" in str(e):
+                 error_detail = (f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint.")
                  logging.error(error_detail)
                  raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
              elif "memory" in str(e).lower():
+                 logging.warning("Potential OOM error during model loading.")
+                 gc.collect(); torch.cuda.empty_cache() if device.type == "cuda" else None
+                 raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}")
              else:
+                 raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}")

+     except gr.Error: raise
+     except Exception as e:
+         logging.error(f"Unexpected error during loading: {e}", exc_info=True)
+         raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}")


+ # --- Inference Function for Gradio (Simplified) ---
+ def predict_iupac(smiles_string):
      """
+     Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
      """
      global model, smiles_tokenizer, iupac_tokenizer, device, config

      if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
          error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
          logging.error(error_msg)
+         return f"Error: {error_msg}" # Return single error string

      if not smiles_string or not smiles_string.strip():
          error_msg = "Error: Please enter a valid SMILES string."
+         return f"Error: {error_msg}" # Return single error string

      smiles_input = smiles_string.strip()

      try:
+         # --- Call the core translation logic (greedy) ---
          sos_idx = config["bos_token_id"]
          eos_idx = config["eos_token_id"]
          pad_idx = config["pad_token_id"]
+         gen_max_len = config["max_len"]

+         predicted_name = translate( # Returns a single string now
              model=model,
              src_sentence=smiles_input,
              smiles_tokenizer=smiles_tokenizer,
              iupac_tokenizer=iupac_tokenizer,
              device=device,
+             max_len=gen_max_len,
              sos_idx=sos_idx,
              eos_idx=eos_idx,
              pad_idx=pad_idx,
          )
+         logging.info(f"Prediction returned: {predicted_name}")

          # --- Format Output ---
+         if "[Error]" in predicted_name: # Check for error messages from translate
+             output_text = f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
+         elif not predicted_name:
+             output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
          else:
              output_text = (
                  f"Input SMILES: {smiles_input}\n\n"
+                 f"Predicted IUPAC Name (Greedy Decode):\n"
+                 f"{predicted_name}"
              )
          return output_text

      except RuntimeError as e:
          logging.error(f"Runtime error during translation: {e}", exc_info=True)
          error_msg = f"Runtime Error during translation: {e}"
          if "memory" in str(e).lower():
+             gc.collect(); torch.cuda.empty_cache() if device.type == "cuda" else None
+             error_msg += " (Potential OOM)"
+         return f"Error: {error_msg}" # Return single error string

      except Exception as e:
          logging.error(f"Unexpected error during translation: {e}", exc_info=True)
          error_msg = f"Unexpected Error during translation: {e}"
+         return f"Error: {error_msg}" # Return single error string


  # --- Load Model on App Start ---
  try:
      load_model_and_tokenizers()
  except gr.Error as ge:
      logging.error(f"Gradio Initialization Error: {ge}")
+     pass # Allow Gradio to potentially start with an error message
  except Exception as e:
+     logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
+     # Optionally raise gr.Error here too


+ # --- Create Gradio Interface (Simplified) ---
+ title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
  description = f"""
  Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
+ Translation uses **greedy decoding** (picks the most likely next word at each step).
+ **Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Performance may vary. Check `config.json` in the repo for model details.
  """

+ # Define examples - remove beam search parameters
  examples = [
+     ["CCO"], # Ethanol
+     ["C1=CC=CC=C1"], # Benzene
+     ["CC(=O)Oc1ccccc1C(=O)O"], # Aspirin
+     ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"], # Ibuprofen
+     ["INVALID_SMILES"], # Example of invalid input
  ]

+ # Input component
  smiles_input = gr.Textbox(
      label="SMILES String",
      placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
      lines=1,
  )

+ # Output component
  output_text = gr.Textbox(
+     label="Predicted IUPAC Name", lines=3, show_copy_button=True # Reduced lines slightly
  )

  # Create the interface instance
  iface = gr.Interface(
      fn=predict_iupac, # The function to call
+     inputs=smiles_input, # Single input component
      outputs=output_text, # Output component
      title=title,
      description=description,
+     examples=examples, # Examples with only SMILES input
+     allow_flagging="never",
+     theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
      article="""
  **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
+ Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations. Greedy decoding can sometimes get stuck in repetitive loops or produce suboptimal results compared to beam search.
  """,
  )

  # --- Launch the App ---
  if __name__ == "__main__":
+     iface.launch()