AdrianM0 committed
Commit fe32912 · verified · Parent(s): 1c3db49

Upload folder using huggingface_hub
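The commit message above refers to the huggingface_hub folder-upload API; a minimal sketch of what such an upload typically looks like is shown below (the local path and target repo ID are placeholders, not taken from this commit):

# Sketch only: folder_path and repo_id are placeholders; repo_type="space" is an assumption.
from huggingface_hub import HfApi

api = HfApi()  # credentials come from `huggingface-cli login` or the HF_TOKEN environment variable
api.upload_folder(
    folder_path=".",                # local folder containing app.py and the tokenizer/config files
    repo_id="<user>/<space-name>",  # placeholder target repository
    repo_type="space",              # assumption: the target is a Gradio Space
    commit_message="Upload folder using huggingface_hub",
)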

Files changed (1)
  1. app.py +266 -155
app.py CHANGED
@@ -1,23 +1,27 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
4
- import torch.nn.functional as F # Needed for beam search log_softmax
5
- import pytorch_lightning as pl # Needed for LightningModule and loading
6
  import os
7
  import json
8
  import logging
9
  from tokenizers import Tokenizer
10
  from huggingface_hub import hf_hub_download
11
- import gc # For garbage collection on potential OOM
12
- import math # Potentially needed by imported classes
13
 
14
  # --- Configuration ---
15
  # Ensure these match the files uploaded to your Hugging Face Hub repository
16
- MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
17
- CHECKPOINT_FILENAME = "last.ckpt" # Or "best_model.ckpt" or whatever you uploaded
18
  SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
19
  IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
20
- CONFIG_FILENAME = "config.json" # Assumes you saved hparams to config.json during/after training
21
  # --- End Configuration ---
22
 
23
  # --- Logging ---
@@ -30,6 +34,7 @@ try:
30
  # We need the LightningModule definition and the mask function
31
  # Ensure enhanced_trainer.py is present in the root of your HF Repo
32
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
 
33
  logging.info("Successfully imported from enhanced_trainer.py.")
34
 
35
  # REMOVED: Redundant import from test_ckpt as functions are defined below
@@ -59,6 +64,7 @@ iupac_tokenizer: Tokenizer | None = None
59
  device: torch.device | None = None
60
  config: dict | None = None
61
 
 
62
  # --- Beam Search Decoding Logic (Locally defined) ---
63
  def beam_search_decode(
64
  model: pl.LightningModule,
@@ -77,8 +83,8 @@ def beam_search_decode(
77
  Performs beam search decoding using the LightningModule's model.
78
  (Ensures this code is self-contained within app.py or correctly imported)
79
  """
80
- model.eval() # Ensure model is in evaluation mode
81
- transformer_model = model.model # Access the underlying Seq2SeqTransformer
82
  n_best = min(n_best, beam_width)
83
 
84
  try:
@@ -86,15 +92,15 @@ def beam_search_decode(
86
  # --- Encode Source ---
87
  memory = transformer_model.encode(
88
  src, src_padding_mask
89
- ) # [1, src_len, emb_size]
90
  memory = memory.to(device)
91
- memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
92
 
93
  # --- Initialize Beams ---
94
  initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
95
  sos_idx
96
- ) # [1, 1]
97
- initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
98
  active_beams = [(initial_beam_seq, initial_beam_score)]
99
  finished_beams = []
100
 
@@ -107,20 +113,20 @@ def beam_search_decode(
107
  for current_seq, current_score in active_beams:
108
  # Check if the beam already ended
109
  if current_seq[0, -1].item() == eos_idx:
110
- # If already finished, add directly to finished beams and skip expansion
111
  finished_beams.append((current_seq, current_score))
112
  continue
113
 
114
  # Prepare inputs for the decoder
115
- tgt_input = current_seq # [1, current_len]
116
  tgt_seq_len = tgt_input.shape[1]
117
  tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
118
  device
119
- ) # [curr_len, curr_len]
120
  # No padding in target during generation yet
121
  tgt_padding_mask = torch.zeros(
122
  tgt_input.shape, dtype=torch.bool, device=device
123
- ) # [1, curr_len]
124
 
125
  # Decode one step
126
  decoder_output = transformer_model.decode(
@@ -129,35 +135,41 @@ def beam_search_decode(
129
  tgt_mask=tgt_mask,
130
  tgt_padding_mask=tgt_padding_mask,
131
  memory_key_padding_mask=memory_key_padding_mask,
132
- ) # [1, curr_len, emb_size]
133
 
134
  # Get logits for the *next* token prediction
135
  next_token_logits = transformer_model.generator(
136
- decoder_output[:, -1, :] # Use output corresponding to the last input token
137
- ) # [1, tgt_vocab_size]
138
 
139
  # Calculate log probabilities and add current beam score
140
  log_probs = F.log_softmax(
141
  next_token_logits, dim=-1
142
- ) # [1, tgt_vocab_size]
143
- combined_scores = log_probs + current_score # Add score of the current path
144
 
145
  # Find top k candidates for the *next* step
146
  topk_log_probs, topk_indices = torch.topk(
147
  combined_scores, beam_width, dim=-1
148
- ) # [1, beam_width], [1, beam_width]
149
 
150
  # Expand potential beams
151
  for i in range(beam_width):
152
  next_token_id = topk_indices[0, i].item()
153
  # Score is the cumulative log probability of the new sequence
154
- next_score = topk_log_probs[0, i].reshape(1) # Keep as tensor [1]
155
  next_token_tensor = torch.tensor(
156
  [[next_token_id]], dtype=torch.long, device=device
157
- ) # [1, 1]
158
  new_seq = torch.cat(
159
  [current_seq, next_token_tensor], dim=1
160
- ) # [1, current_len + 1]
161
  potential_next_beams.append((new_seq, next_score))
162
 
163
  # --- Prune Beams ---
@@ -166,26 +178,30 @@ def beam_search_decode(
166
 
167
  # Select the top `beam_width` beams for the next iteration
168
  active_beams = []
169
- temp_finished_beams = [] # Collect beams finished in *this* step
170
  for seq, score in potential_next_beams:
171
- if len(active_beams) >= beam_width and len(temp_finished_beams) >= beam_width:
172
- break # Optimization: Stop if we have enough active and finished candidates
 
 
 
173
 
174
  is_finished = seq[0, -1].item() == eos_idx
175
  if is_finished:
176
- # Add to temporary finished list for this step
177
- if len(temp_finished_beams) < beam_width:
178
- temp_finished_beams.append((seq, score))
179
  elif len(active_beams) < beam_width:
180
- # Add to active beams for next step
181
- active_beams.append((seq, score))
182
 
183
  # Add the newly finished beams to the main finished list
184
  finished_beams.extend(temp_finished_beams)
185
  # Optional: Prune finished_beams if it grows too large (e.g., keep top 2*beam_width)
186
  finished_beams.sort(key=lambda x: x[1].item(), reverse=True)
187
- finished_beams = finished_beams[:beam_width * 2] # Keep a reasonable number
188
-
 
189
 
190
  # --- Final Selection ---
191
  # Add any remaining active beams (which didn't finish) to the finished list
@@ -200,29 +216,36 @@ def beam_search_decode(
200
  return score.item()
201
  else:
202
  # Length penalty calculation
203
- penalty = ((5.0 + float(seq_len)) / 6.0) ** length_penalty # Common formula
204
  return score.item() / penalty
205
  # Alternative simpler penalty:
206
  # return score.item() / (float(seq_len) ** length_penalty)
207
 
208
- finished_beams.sort(key=get_score_with_penalty, reverse=True) # Higher score is better
209
 
210
  # Return the top n_best sequences (excluding the initial SOS token)
211
  top_sequences = [
212
- seq[:, 1:] for seq, score in finished_beams[:n_best] if seq.shape[1] > 1 # Ensure seq not just SOS
213
- ] # seq shape [1, len] -> [1, len-1]
214
  return top_sequences
215
 
216
  except RuntimeError as e:
217
  logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
218
- if "CUDA out of memory" in str(e) and device.type == 'cuda':
219
  gc.collect()
220
  torch.cuda.empty_cache()
221
- return [] # Return empty list on error
222
  except Exception as e:
223
  logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
224
  return []
225
 
 
226
  # --- Translation Function (Locally defined) ---
227
  def translate(
228
  model: pl.LightningModule,
@@ -242,9 +265,9 @@ def translate(
242
  Translates a single SMILES string using beam search.
243
  (Ensures this code is self-contained within app.py or correctly imported)
244
  """
245
- model.eval() # Ensure model is in eval mode
246
  translations = []
247
- n_best = min(n_best, beam_width) # Can't return more than beam width
248
 
249
  # --- Tokenize Source ---
250
  try:
@@ -264,19 +287,21 @@ def translate(
264
  # --- Prepare Input Tensor and Mask ---
265
  src = (
266
  torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
267
- ) # [1, src_len]
268
  # Create padding mask (True where it's a pad token, should be all False here)
269
- src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
270
 
271
  # --- Perform Beam Search Decoding ---
272
  # Calls the beam_search_decode function defined *above in this file*
273
  # Note: max_len for generation should come from config if it dictates output length
274
- generation_max_len = config.get("max_len", 256) # Use config max_len for output limit
275
  tgt_tokens_list = beam_search_decode(
276
  model=model,
277
  src=src,
278
  src_padding_mask=src_padding_mask,
279
- max_len=generation_max_len, # Use generation limit
280
  sos_idx=sos_idx,
281
  eos_idx=eos_idx,
282
  pad_idx=pad_idx,
@@ -284,7 +309,7 @@ def translate(
284
  beam_width=beam_width,
285
  n_best=n_best,
286
  length_penalty=length_penalty,
287
- ) # Returns list of tensors
288
 
289
  # --- Decode Generated Tokens ---
290
  if not tgt_tokens_list:
@@ -302,10 +327,15 @@ def translate(
302
  )
303
  translations.append(translation)
304
  except Exception as e:
305
- logging.error(f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}", exc_info=True)
306
  translations.append("[Decoding Error]")
307
  else:
308
- logging.warning(f"Beam {i} result was empty or None for SMILES: {src_sentence}")
309
  translations.append("[Decoding Error - Empty Tensor]")
310
 
311
  # Pad with error messages if fewer than n_best results were generated
@@ -314,11 +344,12 @@ def translate(
314
 
315
  return translations
316
 
 
317
  # --- Model/Tokenizer Loading Function ---
318
  def load_model_and_tokenizers():
319
  """Loads tokenizers, config, and model from Hugging Face Hub."""
320
  global model, smiles_tokenizer, iupac_tokenizer, device, config
321
- if model is not None: # Already loaded
322
  logging.info("Model and tokenizers already loaded.")
323
  return
324
 
@@ -327,21 +358,24 @@ def load_model_and_tokenizers():
327
  # Determine device - Use CPU for Gradio Spaces unless GPU is explicitly available and desired
328
  # For simplicity and broader compatibility on free tier Spaces, CPU is safer.
329
  if torch.cuda.is_available():
330
- logging.warning("CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended.")
331
- device = torch.device("cpu")
332
- # Uncomment below and comment above line to try using GPU if available
333
- # device = torch.device("cuda")
334
- # logging.info("CUDA available, using GPU.")
335
  else:
336
- device = torch.device("cpu")
337
- logging.info("CUDA not available, using CPU.")
338
-
339
 
340
  # Download files from HF Hub
341
  logging.info("Downloading files from Hugging Face Hub...")
342
  try:
343
  # Use cache directory for Spaces persistence if possible
344
- cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache") # Gradio sets cache dir
345
  os.makedirs(cache_dir, exist_ok=True)
346
  logging.info(f"Using cache directory: {cache_dir}")
347
 
@@ -349,10 +383,14 @@ def load_model_and_tokenizers():
349
  repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
350
  )
351
  smiles_tokenizer_path = hf_hub_download(
352
- repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir
353
  )
354
  iupac_tokenizer_path = hf_hub_download(
355
- repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir
356
  )
357
  config_path = hf_hub_download(
358
  repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
@@ -379,8 +417,8 @@ def load_model_and_tokenizers():
379
  # Mappings might be needed if keys in config.json differ from these exact names
380
  required_keys = [
381
  # Need vocab sizes used during *training* for loading
382
- "actual_src_vocab_size", # Assuming this was saved in hparams
383
- "actual_tgt_vocab_size", # Assuming this was saved in hparams
384
  # Model architecture params
385
  "emb_size",
386
  "nhead",
@@ -388,17 +426,21 @@ def load_model_and_tokenizers():
388
  "num_encoder_layers",
389
  "num_decoder_layers",
390
  "dropout",
391
- "max_len", # Needed for generation limit and tokenizer setting
392
  # Special token IDs needed for generation
393
  # Assuming standard names, adjust if your config uses different keys
394
- "pad_token_id", # Often 0
395
- "bos_token_id", # Often 1 (used as SOS)
396
- "eos_token_id", # Often 2
397
  ]
398
  # Remap keys if necessary (e.g., if config.json uses 'src_vocab_size' instead of 'actual_src_vocab_size')
399
  config_key_mapping = {
400
- "actual_src_vocab_size": config.get("actual_src_vocab_size", config.get("src_vocab_size")),
401
- "actual_tgt_vocab_size": config.get("actual_tgt_vocab_size", config.get("tgt_vocab_size")),
402
  "emb_size": config.get("emb_size"),
403
  "nhead": config.get("nhead"),
404
  "ffn_hid_dim": config.get("ffn_hid_dim"),
@@ -406,9 +448,15 @@ def load_model_and_tokenizers():
406
  "num_decoder_layers": config.get("num_decoder_layers"),
407
  "dropout": config.get("dropout"),
408
  "max_len": config.get("max_len"),
409
- "pad_token_id": config.get("pad_token_id"), # Use default if missing? Risky.
410
- "bos_token_id": config.get("bos_token_id"), # Use default if missing? Risky.
411
- "eos_token_id": config.get("eos_token_id"), # Use default if missing? Risky.
412
  }
413
  # Update config with potentially remapped values
414
  config.update(config_key_mapping)
@@ -420,31 +468,44 @@ def load_model_and_tokenizers():
420
  # Re-check missing keys after attempting defaults
421
  missing_keys = [key for key in required_keys if config.get(key) is None]
422
  if missing_keys:
423
- raise ValueError(
424
  f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
425
  f"Ensure these were saved in the hyperparameters during training."
426
- )
427
  else:
428
- logging.warning(f"Config file was missing keys, used defaults for: {defaults_used}. This might be incorrect!")
429
 
430
  # Log the final config values being used
431
- logging.info(f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
432
- f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
433
- f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}")
434
 
435
  except FileNotFoundError:
436
- logging.error(f"Config file not found locally after download attempt: {config_path}")
437
- raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
438
  except json.JSONDecodeError as e:
439
  logging.error(f"Error decoding JSON from config file {config_path}: {e}")
440
- raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
441
- except ValueError as e: # Catch our custom validation error
442
  logging.error(f"Config validation error: {e}")
443
  raise gr.Error(f"Config Error: {e}")
444
- except Exception as e: # Catch other potential errors during config processing
445
- logging.error(f"Unexpected error loading or validating config: {e}", exc_info=True)
446
- raise gr.Error(f"Config Error: Unexpected error processing config. Check logs. Error: {e}")
447
-
448
 
449
  # Load tokenizers
450
  logging.info("Loading tokenizers...")
@@ -453,7 +514,7 @@ def load_model_and_tokenizers():
453
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
454
  logging.info("Tokenizers loaded.")
455
 
456
- # --- Validate Tokenizer Special Tokens Against Config ---
457
  pad_token = "<pad>"
458
  sos_token = "<sos>"
459
  eos_token = "<eos>"
@@ -461,23 +522,33 @@ def load_model_and_tokenizers():
461
 
462
  issues = []
463
  if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
464
- issues.append(f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})")
465
  if smiles_tokenizer.token_to_id(unk_token) is None:
466
  issues.append("SMILES UNK token not found")
467
 
468
  if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
469
- issues.append(f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})")
470
  if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
471
- issues.append(f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})")
472
  if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
473
- issues.append(f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})")
474
  if iupac_tokenizer.token_to_id(unk_token) is None:
475
  issues.append("IUPAC UNK token not found")
476
 
477
  if issues:
478
- logging.warning("Tokenizer validation issues detected: " + "; ".join(issues))
479
- # Decide if this is fatal or just a warning
480
- # raise gr.Error("Tokenizer Error: Special token IDs mismatch config. Check tokenizers and config.json.") # Make it fatal if IDs must match
481
 
482
  except Exception as e:
483
  logging.error(
@@ -499,46 +570,62 @@ def load_model_and_tokenizers():
499
  # Ensure these keys exist in your loaded 'config' dict after validation/mapping
500
  src_vocab_size=config["actual_src_vocab_size"],
501
  tgt_vocab_size=config["actual_tgt_vocab_size"],
502
- hparams_dict=config, # Pass the loaded config as hparams
503
- map_location=device, # Map model to the chosen device (CPU or CUDA)
504
- strict=False, # Be less strict about matching keys, useful for PTL versions or minor changes
505
  # REMOVED invalid argument: device="cpu",
506
  )
507
 
508
  # Ensure model is on the correct device, in eval mode, and frozen
509
  model.to(device)
510
  model.eval()
511
- model.freeze() # Disables gradient calculations
512
  logging.info(
513
  f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
514
  )
515
 
516
  except FileNotFoundError:
517
- logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
518
- raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
 
 
 
 
519
  except Exception as e:
520
  logging.error(
521
- f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True
 
522
  )
523
  # Check for common errors
524
  if "size mismatch" in str(e):
525
- error_detail = (f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
526
- f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers.")
527
- logging.error(error_detail)
528
- raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
529
  elif "memory" in str(e).lower():
530
  logging.warning("Potential Out-of-Memory error during model loading.")
531
  gc.collect()
532
- if device.type == 'cuda': torch.cuda.empty_cache()
533
- raise gr.Error(f"Model Error: Out of memory loading model. Check Space resources. Error: {e}")
534
  else:
535
- raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")
536
 
537
- except gr.Error: # Re-raise Gradio errors to be displayed
538
  raise
539
- except Exception as e: # Catch any other unexpected errors
540
- logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
541
- raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")
542
 
543
 
544
  # --- Inference Function for Gradio ---
@@ -553,14 +640,18 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
553
  error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
554
  logging.error(error_msg)
555
  # Try to determine n_best for error output formatting
556
- try: n_best_int = int(n_best_str)
557
- except: n_best_int = 1
558
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
559
 
560
  if not smiles_string or not smiles_string.strip():
561
  error_msg = "Error: Please enter a valid SMILES string."
562
- try: n_best_int = int(n_best_str)
563
- except: n_best_int = 1
564
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
565
 
566
  smiles_input = smiles_string.strip()
@@ -571,10 +662,14 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
571
  n_best = int(n_best_str)
572
  length_penalty = float(length_penalty_str)
573
  if beam_width < 1 or n_best < 1 or n_best > beam_width:
574
- raise ValueError("Beam width and n_best must be >= 1, and n_best <= beam width.")
575
  if length_penalty < 0:
576
- logging.warning(f"Length penalty {length_penalty} is negative, using 0.0 instead.")
577
- length_penalty = 0.0
578
  except ValueError as e:
579
  error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
580
  logging.error(error_msg)
@@ -588,10 +683,10 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
588
  try:
589
  # --- Call the core translation logic ---
590
  # Retrieve necessary IDs from the loaded config
591
- sos_idx = config['bos_token_id']
592
- eos_idx = config['eos_token_id']
593
- pad_idx = config['pad_token_id']
594
- gen_max_len = config['max_len'] # Max length for generation
595
 
596
  predicted_names = translate(
597
  model=model,
@@ -599,7 +694,7 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
599
  smiles_tokenizer=smiles_tokenizer,
600
  iupac_tokenizer=iupac_tokenizer,
601
  device=device,
602
- max_len=gen_max_len, # Pass generation length limit
603
  sos_idx=sos_idx,
604
  eos_idx=eos_idx,
605
  pad_idx=pad_idx,
@@ -615,15 +710,16 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
615
  else:
616
  # Ensure we only display up to n_best results, even if translate returned more/fewer due to errors
617
  display_names = predicted_names[:n_best]
618
- output_text = (f"Input SMILES: {smiles_input}\n\n"
619
- f"Top {len(display_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n")
620
  output_text += "\n".join(
621
  [f"{i + 1}. {name}" for i, name in enumerate(display_names)]
622
  )
623
  # Add a note if fewer results than requested were generated
624
  if len(display_names) < n_best:
625
- output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."
626
-
627
 
628
  return output_text
629
 
@@ -632,7 +728,8 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
632
  error_msg = f"Runtime Error during translation: {e}"
633
  if "memory" in str(e).lower():
634
  gc.collect()
635
- if device.type == 'cuda': torch.cuda.empty_cache()
 
636
  error_msg += " (Potential OOM - try reducing beam width or input length)"
637
  # Return n_best error messages
638
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
@@ -649,13 +746,15 @@ def predict_iupac(smiles_string, beam_width_str, n_best_str, length_penalty_str)
649
  try:
650
  load_model_and_tokenizers()
651
  except gr.Error as ge:
652
- logging.error(f"Gradio Initialization Error: {ge}")
653
- # Gradio handles displaying gr.Error, but we log it too.
654
- # We might want to display a placeholder UI or message if loading fails critically.
655
- pass # Allow Gradio to potentially start with an error message
656
  except Exception as e:
657
  # Catch any non-Gradio errors during the initial load sequence
658
- logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
659
  # Optionally raise gr.Error here too, although it might be too late if Gradio hasn't fully initialized.
660
  # raise gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
661
 
@@ -670,13 +769,13 @@ Translation uses beam search decoding. Adjust parameters below.
670
 
671
  # Define examples using the input types expected by the interface
672
  examples = [
673
- ["CCO", 5, 3, 0.6], # Ethanol
674
- ["C1=CC=CC=C1", 5, 3, 0.6], # Benzene
675
- ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
676
- ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
677
  # Very complex example - might take time or fail on CPU/low memory
678
  # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1, 0.7], # Gleevec (Imatinib) - simplified SMILES structure
679
- ["INVALID_SMILES", 3, 1, 0.6], # Example of invalid input
680
  ]
681
 
682
  # Ensure input components match the `predict_iupac` function signature order and types
@@ -687,16 +786,28 @@ smiles_input = gr.Textbox(
687
  )
688
  # Use number inputs for sliders if direct type casting is desired, but sliders often return float/int anyway
689
  beam_width_input = gr.Slider(
690
- minimum=1, maximum=10, value=5, step=1, label="Beam Width (k)",
691
- info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage."
692
  )
693
  n_best_input = gr.Slider(
694
- minimum=1, maximum=10, value=3, step=1, label="Number of Results (n_best)",
695
- info="How many top sequences to return (must be <= Beam Width)."
696
  )
697
  length_penalty_input = gr.Slider(
698
- minimum=0.0, maximum=2.0, value=0.6, step=0.1, label="Length Penalty (alpha)",
699
- info="Controls preference for sequence length. >1 favors longer, <1 favors shorter, 0 no penalty."
700
  )
701
  output_text = gr.Textbox(
702
  label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
@@ -704,19 +815,19 @@ output_text = gr.Textbox(
704
 
705
  # Create the interface instance
706
  iface = gr.Interface(
707
- fn=predict_iupac, # The function to call
708
- inputs=[ # List of input components
709
  smiles_input,
710
  beam_width_input,
711
  n_best_input,
712
- length_penalty_input
713
  ],
714
- outputs=output_text, # Output component
715
  title=title,
716
  description=description,
717
- examples=examples, # Examples to populate the interface
718
- allow_flagging="never", # Disable flagging
719
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"), # Optional theme
720
  article="""
721
  **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
722
  Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.
@@ -733,4 +844,4 @@ if __name__ == "__main__":
733
  # Set share=False or remove for deployment on Spaces.
734
  # Use server_name="0.0.0.0" to make it accessible on the network if running locally
735
  # Use auth=("username", "password") for basic authentication
736
- iface.launch() # share=True is deprecated, use launch()
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+ import torch.nn.functional as F # Needed for beam search log_softmax
5
+ import pytorch_lightning as pl # Needed for LightningModule and loading
6
  import os
7
  import json
8
  import logging
9
  from tokenizers import Tokenizer
10
  from huggingface_hub import hf_hub_download
11
+ import gc # For garbage collection on potential OOM
12
+ import math # Potentially needed by imported classes
13
 
14
  # --- Configuration ---
15
  # Ensure these match the files uploaded to your Hugging Face Hub repository
16
+ MODEL_REPO_ID = (
17
+ "AdrianM0/smiles-to-iupac-translator" # <-- Make sure this is your repo ID
18
+ )
19
+ CHECKPOINT_FILENAME = "last.ckpt" # Or "best_model.ckpt" or whatever you uploaded
20
  SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
21
  IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
22
+ CONFIG_FILENAME = (
23
+ "config.json" # Assumes you saved hparams to config.json during/after training
24
+ )
25
  # --- End Configuration ---
26
 
27
  # --- Logging ---
 
34
  # We need the LightningModule definition and the mask function
35
  # Ensure enhanced_trainer.py is present in the root of your HF Repo
36
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
37
+
38
  logging.info("Successfully imported from enhanced_trainer.py.")
39
 
40
  # REMOVED: Redundant import from test_ckpt as functions are defined below
 
64
  device: torch.device | None = None
65
  config: dict | None = None
66
 
67
+
68
  # --- Beam Search Decoding Logic (Locally defined) ---
69
  def beam_search_decode(
70
  model: pl.LightningModule,
 
83
  Performs beam search decoding using the LightningModule's model.
84
  (Ensures this code is self-contained within app.py or correctly imported)
85
  """
86
+ model.eval() # Ensure model is in evaluation mode
87
+ transformer_model = model.model # Access the underlying Seq2SeqTransformer
88
  n_best = min(n_best, beam_width)
89
 
90
  try:
 
92
  # --- Encode Source ---
93
  memory = transformer_model.encode(
94
  src, src_padding_mask
95
+ ) # [1, src_len, emb_size]
96
  memory = memory.to(device)
97
+ memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
98
 
99
  # --- Initialize Beams ---
100
  initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
101
  sos_idx
102
+ ) # [1, 1]
103
+ initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
104
  active_beams = [(initial_beam_seq, initial_beam_score)]
105
  finished_beams = []
106
 
 
113
  for current_seq, current_score in active_beams:
114
  # Check if the beam already ended
115
  if current_seq[0, -1].item() == eos_idx:
116
+ # If already finished, add directly to finished beams and skip expansion
117
  finished_beams.append((current_seq, current_score))
118
  continue
119
 
120
  # Prepare inputs for the decoder
121
+ tgt_input = current_seq # [1, current_len]
122
  tgt_seq_len = tgt_input.shape[1]
123
  tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
124
  device
125
+ ) # [curr_len, curr_len]
126
  # No padding in target during generation yet
127
  tgt_padding_mask = torch.zeros(
128
  tgt_input.shape, dtype=torch.bool, device=device
129
+ ) # [1, curr_len]
130
 
131
  # Decode one step
132
  decoder_output = transformer_model.decode(
 
135
  tgt_mask=tgt_mask,
136
  tgt_padding_mask=tgt_padding_mask,
137
  memory_key_padding_mask=memory_key_padding_mask,
138
+ ) # [1, curr_len, emb_size]
139
 
140
  # Get logits for the *next* token prediction
141
  next_token_logits = transformer_model.generator(
142
+ decoder_output[
143
+ :, -1, :
144
+ ] # Use output corresponding to the last input token
145
+ ) # [1, tgt_vocab_size]
146
 
147
  # Calculate log probabilities and add current beam score
148
  log_probs = F.log_softmax(
149
  next_token_logits, dim=-1
150
+ ) # [1, tgt_vocab_size]
151
+ combined_scores = (
152
+ log_probs + current_score
153
+ ) # Add score of the current path
154
 
155
  # Find top k candidates for the *next* step
156
  topk_log_probs, topk_indices = torch.topk(
157
  combined_scores, beam_width, dim=-1
158
+ ) # [1, beam_width], [1, beam_width]
159
 
160
  # Expand potential beams
161
  for i in range(beam_width):
162
  next_token_id = topk_indices[0, i].item()
163
  # Score is the cumulative log probability of the new sequence
164
+ next_score = topk_log_probs[0, i].reshape(
165
+ 1
166
+ ) # Keep as tensor [1]
167
  next_token_tensor = torch.tensor(
168
  [[next_token_id]], dtype=torch.long, device=device
169
+ ) # [1, 1]
170
  new_seq = torch.cat(
171
  [current_seq, next_token_tensor], dim=1
172
+ ) # [1, current_len + 1]
173
  potential_next_beams.append((new_seq, next_score))
174
 
175
  # --- Prune Beams ---
 
178
 
179
  # Select the top `beam_width` beams for the next iteration
180
  active_beams = []
181
+ temp_finished_beams = [] # Collect beams finished in *this* step
182
  for seq, score in potential_next_beams:
183
+ if (
184
+ len(active_beams) >= beam_width
185
+ and len(temp_finished_beams) >= beam_width
186
+ ):
187
+ break # Optimization: Stop if we have enough active and finished candidates
188
 
189
  is_finished = seq[0, -1].item() == eos_idx
190
  if is_finished:
191
+ # Add to temporary finished list for this step
192
+ if len(temp_finished_beams) < beam_width:
193
+ temp_finished_beams.append((seq, score))
194
  elif len(active_beams) < beam_width:
195
+ # Add to active beams for next step
196
+ active_beams.append((seq, score))
197
 
198
  # Add the newly finished beams to the main finished list
199
  finished_beams.extend(temp_finished_beams)
200
  # Optional: Prune finished_beams if it grows too large (e.g., keep top 2*beam_width)
201
  finished_beams.sort(key=lambda x: x[1].item(), reverse=True)
202
+ finished_beams = finished_beams[
203
+ : beam_width * 2
204
+ ] # Keep a reasonable number
205
 
206
  # --- Final Selection ---
207
  # Add any remaining active beams (which didn't finish) to the finished list
 
216
  return score.item()
217
  else:
218
  # Length penalty calculation
219
+ penalty = (
220
+ (5.0 + float(seq_len)) / 6.0
221
+ ) ** length_penalty # Common formula
222
  return score.item() / penalty
223
  # Alternative simpler penalty:
224
  # return score.item() / (float(seq_len) ** length_penalty)
225
 
226
+ finished_beams.sort(
227
+ key=get_score_with_penalty, reverse=True
228
+ ) # Higher score is better
229
 
230
  # Return the top n_best sequences (excluding the initial SOS token)
231
  top_sequences = [
232
+ seq[:, 1:]
233
+ for seq, score in finished_beams[:n_best]
234
+ if seq.shape[1] > 1 # Ensure seq not just SOS
235
+ ] # seq shape [1, len] -> [1, len-1]
236
  return top_sequences
237
 
238
  except RuntimeError as e:
239
  logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
240
+ if "CUDA out of memory" in str(e) and device.type == "cuda":
241
  gc.collect()
242
  torch.cuda.empty_cache()
243
+ return [] # Return empty list on error
244
  except Exception as e:
245
  logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
246
  return []
247
 
248
+
249
  # --- Translation Function (Locally defined) ---
250
  def translate(
251
  model: pl.LightningModule,
 
265
  Translates a single SMILES string using beam search.
266
  (Ensures this code is self-contained within app.py or correctly imported)
267
  """
268
+ model.eval() # Ensure model is in eval mode
269
  translations = []
270
+ n_best = min(n_best, beam_width) # Can't return more than beam width
271
 
272
  # --- Tokenize Source ---
273
  try:
 
287
  # --- Prepare Input Tensor and Mask ---
288
  src = (
289
  torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
290
+ ) # [1, src_len]
291
  # Create padding mask (True where it's a pad token, should be all False here)
292
+ src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
293
 
294
  # --- Perform Beam Search Decoding ---
295
  # Calls the beam_search_decode function defined *above in this file*
296
  # Note: max_len for generation should come from config if it dictates output length
297
+ generation_max_len = config.get(
298
+ "max_len", 256
299
+ ) # Use config max_len for output limit
300
  tgt_tokens_list = beam_search_decode(
301
  model=model,
302
  src=src,
303
  src_padding_mask=src_padding_mask,
304
+ max_len=generation_max_len, # Use generation limit
305
  sos_idx=sos_idx,
306
  eos_idx=eos_idx,
307
  pad_idx=pad_idx,
 
309
  beam_width=beam_width,
310
  n_best=n_best,
311
  length_penalty=length_penalty,
312
+ ) # Returns list of tensors
313
 
314
  # --- Decode Generated Tokens ---
315
  if not tgt_tokens_list:
 
327
  )
328
  translations.append(translation)
329
  except Exception as e:
330
+ logging.error(
331
+ f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}",
332
+ exc_info=True,
333
+ )
334
  translations.append("[Decoding Error]")
335
  else:
336
+ logging.warning(
337
+ f"Beam {i} result was empty or None for SMILES: {src_sentence}"
338
+ )
339
  translations.append("[Decoding Error - Empty Tensor]")
340
 
341
  # Pad with error messages if fewer than n_best results were generated
 
344
 
345
  return translations
346
 
347
+
348
  # --- Model/Tokenizer Loading Function ---
349
  def load_model_and_tokenizers():
350
  """Loads tokenizers, config, and model from Hugging Face Hub."""
351
  global model, smiles_tokenizer, iupac_tokenizer, device, config
352
+ if model is not None: # Already loaded
353
  logging.info("Model and tokenizers already loaded.")
354
  return
355
 
 
358
  # Determine device - Use CPU for Gradio Spaces unless GPU is explicitly available and desired
359
  # For simplicity and broader compatibility on free tier Spaces, CPU is safer.
360
  if torch.cuda.is_available():
361
+ logging.warning(
362
+ "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
363
+ )
364
+ device = torch.device("cpu")
365
+ # Uncomment below and comment above line to try using GPU if available
366
+ # device = torch.device("cuda")
367
+ # logging.info("CUDA available, using GPU.")
368
  else:
369
+ device = torch.device("cpu")
370
+ logging.info("CUDA not available, using CPU.")
 
371
 
372
  # Download files from HF Hub
373
  logging.info("Downloading files from Hugging Face Hub...")
374
  try:
375
  # Use cache directory for Spaces persistence if possible
376
+ cache_dir = os.environ.get(
377
+ "GRADIO_CACHE", "./hf_cache"
378
+ ) # Gradio sets cache dir
379
  os.makedirs(cache_dir, exist_ok=True)
380
  logging.info(f"Using cache directory: {cache_dir}")
381
 
 
383
  repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
384
  )
385
  smiles_tokenizer_path = hf_hub_download(
386
+ repo_id=MODEL_REPO_ID,
387
+ filename=SMILES_TOKENIZER_FILENAME,
388
+ cache_dir=cache_dir,
389
  )
390
  iupac_tokenizer_path = hf_hub_download(
391
+ repo_id=MODEL_REPO_ID,
392
+ filename=IUPAC_TOKENIZER_FILENAME,
393
+ cache_dir=cache_dir,
394
  )
395
  config_path = hf_hub_download(
396
  repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
 
417
  # Mappings might be needed if keys in config.json differ from these exact names
418
  required_keys = [
419
  # Need vocab sizes used during *training* for loading
420
+ "actual_src_vocab_size", # Assuming this was saved in hparams
421
+ "actual_tgt_vocab_size", # Assuming this was saved in hparams
422
  # Model architecture params
423
  "emb_size",
424
  "nhead",
 
426
  "num_encoder_layers",
427
  "num_decoder_layers",
428
  "dropout",
429
+ "max_len", # Needed for generation limit and tokenizer setting
430
  # Special token IDs needed for generation
431
  # Assuming standard names, adjust if your config uses different keys
432
+ "pad_token_id", # Often 0
433
+ "bos_token_id", # Often 1 (used as SOS)
434
+ "eos_token_id", # Often 2
435
  ]
436
  # Remap keys if necessary (e.g., if config.json uses 'src_vocab_size' instead of 'actual_src_vocab_size')
437
  config_key_mapping = {
438
+ "actual_src_vocab_size": config.get(
439
+ "actual_src_vocab_size", config.get("src_vocab_size")
440
+ ),
441
+ "actual_tgt_vocab_size": config.get(
442
+ "actual_tgt_vocab_size", config.get("tgt_vocab_size")
443
+ ),
444
  "emb_size": config.get("emb_size"),
445
  "nhead": config.get("nhead"),
446
  "ffn_hid_dim": config.get("ffn_hid_dim"),
 
448
  "num_decoder_layers": config.get("num_decoder_layers"),
449
  "dropout": config.get("dropout"),
450
  "max_len": config.get("max_len"),
451
+ "pad_token_id": config.get(
452
+ "pad_token_id"
453
+ ), # Use default if missing? Risky.
454
+ "bos_token_id": config.get(
455
+ "bos_token_id"
456
+ ), # Use default if missing? Risky.
457
+ "eos_token_id": config.get(
458
+ "eos_token_id"
459
+ ), # Use default if missing? Risky.
460
  }
461
  # Update config with potentially remapped values
462
  config.update(config_key_mapping)
 
468
  # Re-check missing keys after attempting defaults
469
  missing_keys = [key for key in required_keys if config.get(key) is None]
470
  if missing_keys:
471
+ raise ValueError(
472
  f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
473
  f"Ensure these were saved in the hyperparameters during training."
474
+ )
475
  else:
476
+ logging.warning(
477
+ f"Config file was missing keys, used defaults for: {defaults_used}. This might be incorrect!"
478
+ )
479
 
480
  # Log the final config values being used
481
+ logging.info(
482
+ f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
483
+ f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
484
+ f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
485
+ )
486
 
487
  except FileNotFoundError:
488
+ logging.error(
489
+ f"Config file not found locally after download attempt: {config_path}"
490
+ )
491
+ raise gr.Error(
492
+ f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
493
+ )
494
  except json.JSONDecodeError as e:
495
  logging.error(f"Error decoding JSON from config file {config_path}: {e}")
496
+ raise gr.Error(
497
+ f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
498
+ )
499
+ except ValueError as e: # Catch our custom validation error
500
  logging.error(f"Config validation error: {e}")
501
  raise gr.Error(f"Config Error: {e}")
502
+ except Exception as e: # Catch other potential errors during config processing
503
+ logging.error(
504
+ f"Unexpected error loading or validating config: {e}", exc_info=True
505
+ )
506
+ raise gr.Error(
507
+ f"Config Error: Unexpected error processing config. Check logs. Error: {e}"
508
+ )
509
 
510
  # Load tokenizers
511
  logging.info("Loading tokenizers...")
 
514
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
515
  logging.info("Tokenizers loaded.")
516
 
517
+ # --- Validate Tokenizer Special Tokens Against Config ---
518
  pad_token = "<pad>"
519
  sos_token = "<sos>"
520
  eos_token = "<eos>"
 
522
 
523
  issues = []
524
  if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
525
+ issues.append(
526
+ f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
527
+ )
528
  if smiles_tokenizer.token_to_id(unk_token) is None:
529
  issues.append("SMILES UNK token not found")
530
 
531
  if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
532
+ issues.append(
533
+ f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
534
+ )
535
  if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
536
+ issues.append(
537
+ f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})"
538
+ )
539
  if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
540
+ issues.append(
541
+ f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})"
542
+ )
543
  if iupac_tokenizer.token_to_id(unk_token) is None:
544
  issues.append("IUPAC UNK token not found")
545
 
546
  if issues:
547
+ logging.warning(
548
+ "Tokenizer validation issues detected: " + "; ".join(issues)
549
+ )
550
+ # Decide if this is fatal or just a warning
551
+ # raise gr.Error("Tokenizer Error: Special token IDs mismatch config. Check tokenizers and config.json.") # Make it fatal if IDs must match
552
 
553
  except Exception as e:
554
  logging.error(
 
570
  # Ensure these keys exist in your loaded 'config' dict after validation/mapping
571
  src_vocab_size=config["actual_src_vocab_size"],
572
  tgt_vocab_size=config["actual_tgt_vocab_size"],
573
+ hparams_dict=config, # Pass the loaded config as hparams
574
+ map_location=device, # Map model to the chosen device (CPU or CUDA)
575
+ strict=False, # Be less strict about matching keys, useful for PTL versions or minor changes
576
  # REMOVED invalid argument: device="cpu",
577
  )
578
 
579
  # Ensure model is on the correct device, in eval mode, and frozen
580
  model.to(device)
581
  model.eval()
582
+ model.freeze() # Disables gradient calculations
583
  logging.info(
584
  f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
585
  )
586
 
587
  except FileNotFoundError:
588
+ logging.error(
589
+ f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
590
+ )
591
+ raise gr.Error(
592
+ f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
593
+ )
594
  except Exception as e:
595
  logging.error(
596
+ f"Error loading model from checkpoint {checkpoint_path}: {e}",
597
+ exc_info=True,
598
  )
599
  # Check for common errors
600
  if "size mismatch" in str(e):
601
+ error_detail = (
602
+ f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
603
+ f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers."
604
+ )
605
+ logging.error(error_detail)
606
+ raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
607
  elif "memory" in str(e).lower():
608
  logging.warning("Potential Out-of-Memory error during model loading.")
609
  gc.collect()
610
+ if device.type == "cuda":
611
+ torch.cuda.empty_cache()
612
+ raise gr.Error(
613
+ f"Model Error: Out of memory loading model. Check Space resources. Error: {e}"
614
+ )
615
  else:
616
+ raise gr.Error(
617
+ f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
618
+ )
619
 
620
+ except gr.Error: # Re-raise Gradio errors to be displayed
621
  raise
622
+ except Exception as e: # Catch any other unexpected errors
623
+ logging.error(
624
+ f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
625
+ )
626
+ raise gr.Error(
627
+ f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
628
+ )
629
 
630
 
631
  # --- Inference Function for Gradio ---
 
640
  error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
641
  logging.error(error_msg)
642
  # Try to determine n_best for error output formatting
643
+ try:
644
+ n_best_int = int(n_best_str)
645
+ except:
646
+ n_best_int = 1
647
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
648
 
649
  if not smiles_string or not smiles_string.strip():
650
  error_msg = "Error: Please enter a valid SMILES string."
651
+ try:
652
+ n_best_int = int(n_best_str)
653
+ except:
654
+ n_best_int = 1
655
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
656
 
657
  smiles_input = smiles_string.strip()
 
662
  n_best = int(n_best_str)
663
  length_penalty = float(length_penalty_str)
664
  if beam_width < 1 or n_best < 1 or n_best > beam_width:
665
+ raise ValueError(
666
+ "Beam width and n_best must be >= 1, and n_best <= beam width."
667
+ )
668
  if length_penalty < 0:
669
+ logging.warning(
670
+ f"Length penalty {length_penalty} is negative, using 0.0 instead."
671
+ )
672
+ length_penalty = 0.0
673
  except ValueError as e:
674
  error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
675
  logging.error(error_msg)
 
683
  try:
684
  # --- Call the core translation logic ---
685
  # Retrieve necessary IDs from the loaded config
686
+ sos_idx = config["bos_token_id"]
687
+ eos_idx = config["eos_token_id"]
688
+ pad_idx = config["pad_token_id"]
689
+ gen_max_len = config["max_len"] # Max length for generation
690
 
691
  predicted_names = translate(
692
  model=model,
 
694
  smiles_tokenizer=smiles_tokenizer,
695
  iupac_tokenizer=iupac_tokenizer,
696
  device=device,
697
+ max_len=gen_max_len, # Pass generation length limit
698
  sos_idx=sos_idx,
699
  eos_idx=eos_idx,
700
  pad_idx=pad_idx,
 
710
  else:
711
  # Ensure we only display up to n_best results, even if translate returned more/fewer due to errors
712
  display_names = predicted_names[:n_best]
713
+ output_text = (
714
+ f"Input SMILES: {smiles_input}\n\n"
715
+ f"Top {len(display_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
716
+ )
717
  output_text += "\n".join(
718
  [f"{i + 1}. {name}" for i, name in enumerate(display_names)]
719
  )
720
  # Add a note if fewer results than requested were generated
721
  if len(display_names) < n_best:
722
+ output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."
 
723
 
724
  return output_text
725
 
 
728
  error_msg = f"Runtime Error during translation: {e}"
729
  if "memory" in str(e).lower():
730
  gc.collect()
731
+ if device.type == "cuda":
732
+ torch.cuda.empty_cache()
733
  error_msg += " (Potential OOM - try reducing beam width or input length)"
734
  # Return n_best error messages
735
  return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
 
746
  try:
747
  load_model_and_tokenizers()
748
  except gr.Error as ge:
749
+ logging.error(f"Gradio Initialization Error: {ge}")
750
+ # Gradio handles displaying gr.Error, but we log it too.
751
+ # We might want to display a placeholder UI or message if loading fails critically.
752
+ pass # Allow Gradio to potentially start with an error message
753
  except Exception as e:
754
  # Catch any non-Gradio errors during the initial load sequence
755
+ logging.error(
756
+ f"Critical error during initial model loading sequence: {e}", exc_info=True
757
+ )
758
  # Optionally raise gr.Error here too, although it might be too late if Gradio hasn't fully initialized.
759
  # raise gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
760
 
 
769
 
770
  # Define examples using the input types expected by the interface
771
  examples = [
772
+ ["CCO", 5, 3, 0.6], # Ethanol
773
+ ["C1=CC=CC=C1", 5, 3, 0.6], # Benzene
774
+ ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
775
+ ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
776
  # Very complex example - might take time or fail on CPU/low memory
777
  # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1, 0.7], # Gleevec (Imatinib) - simplified SMILES structure
778
+ ["INVALID_SMILES", 3, 1, 0.6], # Example of invalid input
779
  ]
780
 
781
  # Ensure input components match the `predict_iupac` function signature order and types
 
786
  )
787
  # Use number inputs for sliders if direct type casting is desired, but sliders often return float/int anyway
788
  beam_width_input = gr.Slider(
789
+ minimum=1,
790
+ maximum=10,
791
+ value=5,
792
+ step=1,
793
+ label="Beam Width (k)",
794
+ info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage.",
795
  )
796
  n_best_input = gr.Slider(
797
+ minimum=1,
798
+ maximum=10,
799
+ value=3,
800
+ step=1,
801
+ label="Number of Results (n_best)",
802
+ info="How many top sequences to return (must be <= Beam Width).",
803
  )
804
  length_penalty_input = gr.Slider(
805
+ minimum=0.0,
806
+ maximum=2.0,
807
+ value=0.6,
808
+ step=0.1,
809
+ label="Length Penalty (alpha)",
810
+ info="Controls preference for sequence length. >1 favors longer, <1 favors shorter, 0 no penalty.",
811
  )
812
  output_text = gr.Textbox(
813
  label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
 
815
 
816
  # Create the interface instance
817
  iface = gr.Interface(
818
+ fn=predict_iupac, # The function to call
819
+ inputs=[ # List of input components
820
  smiles_input,
821
  beam_width_input,
822
  n_best_input,
823
+ length_penalty_input,
824
  ],
825
+ outputs=output_text, # Output component
826
  title=title,
827
  description=description,
828
+ examples=examples, # Examples to populate the interface
829
+ allow_flagging="never", # Disable flagging
830
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"), # Optional theme
831
  article="""
832
  **Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
833
  Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.
 
844
  # Set share=False or remove for deployment on Spaces.
845
  # Use server_name="0.0.0.0" to make it accessible on the network if running locally
846
  # Use auth=("username", "password") for basic authentication
847
+ iface.launch() # share=True is deprecated, use launch()