AdrianM0 committed
Commit 59543a5 · verified · 1 Parent(s): 96b0c1a

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +233 -122
app.py CHANGED
@@ -1,15 +1,15 @@
 # app.py
 import gradio as gr
 import torch
-import torch.nn.functional as F # <--- Added import
-import pytorch_lightning as pl # <--- Added import (needed for type hints, model access)
+import torch.nn.functional as F  # <--- Added import
+import pytorch_lightning as pl  # <--- Added import (needed for type hints, model access)
 import os
 import json
 import logging
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
-import gc # For garbage collection on potential OOM
-import math # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)
+import gc  # For garbage collection on potential OOM
+import math  # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)
 
 # --- Configuration ---
 MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
@@ -20,28 +20,39 @@ CONFIG_FILENAME = "config.json"
 # --- End Configuration ---
 
 # --- Logging ---
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
 
 # --- Load Helper Code (Only Model Definition Needed) ---
 try:
     # We only need the LightningModule definition and the mask function now
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
+
     logging.info("Successfully imported from enhanced_trainer.py.")
 
     # We will define beam_search_decode and translate locally in this file
     # REMOVED: from test_ckpt import beam_search_decode, translate
 
 except ImportError as e:
-    logging.error(f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'.")
-    gr.Error(f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}")
+    logging.error(
+        f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
+    )
+    gr.Error(
+        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
+    )
     exit()
 except Exception as e:
-    logging.error(f"An unexpected error occurred during helper code import: {e}", exc_info=True)
-    gr.Error(f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}")
+    logging.error(
+        f"An unexpected error occurred during helper code import: {e}", exc_info=True
+    )
+    gr.Error(
+        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
+    )
     exit()
 
 # --- Global Variables (Load Model Once) ---
-model: pl.LightningModule | None = None # Added type hint
+model: pl.LightningModule | None = None  # Added type hint
 smiles_tokenizer: Tokenizer | None = None
 iupac_tokenizer: Tokenizer | None = None
 device: torch.device | None = None
@@ -49,6 +60,7 @@ config: dict | None = None
 
 # --- Beam Search Decoding Logic (Moved from test_ckpt.py) ---
 
+
 def beam_search_decode(
     model: pl.LightningModule,
     src: torch.Tensor,
@@ -56,11 +68,11 @@ def beam_search_decode(
     max_len: int,
     sos_idx: int,
     eos_idx: int,
-    pad_idx: int, # Needed for padding mask check if src has padding
+    pad_idx: int,  # Needed for padding mask check if src has padding
     device: torch.device,
     beam_width: int = 5,
-    n_best: int = 5, # Number of top sequences to return
-    length_penalty: float = 0.6 # Alpha for length normalization (0=no penalty, 1=full penalty)
+    n_best: int = 5,  # Number of top sequences to return
+    length_penalty: float = 0.6,  # Alpha for length normalization (0=no penalty, 1=full penalty)
 ) -> list[torch.Tensor]:
     """
     Performs beam search decoding using the LightningModule's model.
@@ -68,20 +80,24 @@ def beam_search_decode(
     """
     # Ensure model is in eval mode (redundant if called after model.eval(), but safe)
    model.eval()
-    transformer_model = model.model # Access the underlying Seq2SeqTransformer
-    n_best = min(n_best, beam_width) # Cannot return more than beam_width sequences
+    transformer_model = model.model  # Access the underlying Seq2SeqTransformer
+    n_best = min(n_best, beam_width)  # Cannot return more than beam_width sequences
 
     try:
         with torch.no_grad():
             # --- Encode Source ---
-            memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size]
+            memory = transformer_model.encode(
+                src, src_padding_mask
+            )  # [1, src_len, emb_size]
             memory = memory.to(device)
             # Ensure memory_key_padding_mask is also on the correct device for decode
-            memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
+            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]
 
             # --- Initialize Beams ---
-            initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) # [1, 1]
-            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
+            initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
+                sos_idx
+            )  # [1, 1]
+            initial_beam_score = torch.zeros(1, dtype=torch.float, device=device)  # [1]
             active_beams = [(initial_beam_seq, initial_beam_score)]
             finished_beams = []
 
@@ -96,29 +112,45 @@ def beam_search_decode(
                         finished_beams.append((current_seq, current_score))
                         continue
 
-                    tgt_input = current_seq # [1, current_len]
+                    tgt_input = current_seq  # [1, current_len]
                     tgt_seq_len = tgt_input.shape[1]
-                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) # [curr_len, curr_len]
-                    tgt_padding_mask = torch.zeros(tgt_input.shape, dtype=torch.bool, device=device) # [1, curr_len]
+                    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
+                        device
+                    )  # [curr_len, curr_len]
+                    tgt_padding_mask = torch.zeros(
+                        tgt_input.shape, dtype=torch.bool, device=device
+                    )  # [1, curr_len]
 
                     decoder_output = transformer_model.decode(
                         tgt=tgt_input,
                         memory=memory,
                         tgt_mask=tgt_mask,
                         tgt_padding_mask=tgt_padding_mask,
-                        memory_key_padding_mask=memory_key_padding_mask
-                    ) # [1, curr_len, emb_size]
+                        memory_key_padding_mask=memory_key_padding_mask,
+                    )  # [1, curr_len, emb_size]
 
-                    next_token_logits = transformer_model.generator(decoder_output[:, -1, :]) # [1, tgt_vocab_size]
-                    log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size]
+                    next_token_logits = transformer_model.generator(
+                        decoder_output[:, -1, :]
+                    )  # [1, tgt_vocab_size]
+                    log_probs = F.log_softmax(
+                        next_token_logits, dim=-1
+                    )  # [1, tgt_vocab_size]
 
-                    topk_log_probs, topk_indices = torch.topk(log_probs + current_score, beam_width, dim=-1)
+                    topk_log_probs, topk_indices = torch.topk(
+                        log_probs + current_score, beam_width, dim=-1
+                    )
 
                     for i in range(beam_width):
                         next_token_id = topk_indices[0, i].item()
-                        next_score = topk_log_probs[0, i].reshape(1) # Keep as tensor [1]
-                        next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=device) # [1, 1]
-                        new_seq = torch.cat([current_seq, next_token_tensor], dim=1) # [1, current_len + 1]
+                        next_score = topk_log_probs[0, i].reshape(
+                            1
+                        )  # Keep as tensor [1]
+                        next_token_tensor = torch.tensor(
+                            [[next_token_id]], dtype=torch.long, device=device
+                        )  # [1, 1]
+                        new_seq = torch.cat(
+                            [current_seq, next_token_tensor], dim=1
+                        )  # [1, current_len + 1]
                         potential_next_beams.append((new_seq, next_score))
 
                 potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)
@@ -126,14 +158,14 @@ def beam_search_decode(
                 active_beams = []
                 added_count = 0
                 for seq, score in potential_next_beams:
-                    is_finished = seq[0, -1].item() == eos_idx
-                    if is_finished:
-                        finished_beams.append((seq, score))
-                    elif added_count < beam_width:
-                        active_beams.append((seq, score))
-                        added_count += 1
-                    elif added_count >= beam_width:
-                        break
+                    is_finished = seq[0, -1].item() == eos_idx
+                    if is_finished:
+                        finished_beams.append((seq, score))
+                    elif added_count < beam_width:
+                        active_beams.append((seq, score))
+                        added_count += 1
+                    elif added_count >= beam_width:
+                        break
 
             finished_beams.extend(active_beams)
 
@@ -148,22 +180,27 @@ def beam_search_decode(
             # Ensure seq_len is float for pow
             return score.item() / (float(seq_len) ** length_penalty)
 
-        finished_beams.sort(key=get_score, reverse=True) # Higher score is better
+        finished_beams.sort(key=get_score, reverse=True)  # Higher score is better
 
-        top_sequences = [seq[:, 1:] for seq, score in finished_beams[:n_best]] # seq shape [1, len] -> [1, len-1]
+        top_sequences = [
+            seq[:, 1:] for seq, score in finished_beams[:n_best]
+        ]  # seq shape [1, len] -> [1, len-1]
         return top_sequences
 
     except RuntimeError as e:
         logging.error(f"Runtime error during beam search decode: {e}")
         if "CUDA out of memory" in str(e):
-            gc.collect(); torch.cuda.empty_cache()
-        return [] # Return empty list on error
+            gc.collect()
+            torch.cuda.empty_cache()
+        return []  # Return empty list on error
     except Exception as e:
         logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
         return []
 
+
 # --- Translation Function (Moved from test_ckpt.py) ---
 
+
 def translate(
     model: pl.LightningModule,
     src_sentence: str,
@@ -176,13 +213,13 @@ def translate(
     pad_idx: int,
     beam_width: int = 5,
     n_best: int = 5,
-    length_penalty: float = 0.6
+    length_penalty: float = 0.6,
 ) -> list[str]:
     """
     Translates a single SMILES string using beam search.
     (Code copied and pasted from test_ckpt.py)
     """
-    model.eval() # Ensure model is in eval mode
+    model.eval()  # Ensure model is in eval mode
     translations = []
 
     # --- Tokenize Source ---
@@ -191,17 +228,19 @@ def translate(
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
             return ["[Encoding Error]"] * n_best
-        src_ids = src_encoded.ids[:max_len] # Truncate source
+        src_ids = src_encoded.ids[:max_len]  # Truncate source
         if not src_ids:
-            logging.warning(f"Source empty after truncation: {src_sentence}")
-            return ["[Encoding Error - Empty Src]"] * n_best
+            logging.warning(f"Source empty after truncation: {src_sentence}")
+            return ["[Encoding Error - Empty Src]"] * n_best
     except Exception as e:
         logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
         return ["[Encoding Error]"] * n_best
 
     # --- Prepare Input Tensor and Mask ---
-    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len]
-    src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
+    src = (
+        torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
+    )  # [1, src_len]
+    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]
 
     # --- Perform Beam Search Decoding ---
     # Calls the beam_search_decode function defined above in this file
@@ -216,19 +255,21 @@ def translate(
         device=device,
         beam_width=beam_width,
         n_best=n_best,
-        length_penalty=length_penalty
-    ) # Returns list of tensors
+        length_penalty=length_penalty,
+    )  # Returns list of tensors
 
     # --- Decode Generated Tokens ---
     if not tgt_tokens_list:
-        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
-        return ["[Decoding Error - Empty Output]"] * n_best
+        logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
+        return ["[Decoding Error - Empty Output]"] * n_best
 
     for tgt_tokens_tensor in tgt_tokens_list:
         if tgt_tokens_tensor.numel() > 0:
            tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
            try:
-                translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
+                translation = iupac_tokenizer.decode(
+                    tgt_tokens, skip_special_tokens=True
+                )
                 translations.append(translation)
             except Exception as e:
                 logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
@@ -247,7 +288,7 @@
 def load_model_and_tokenizers():
     """Loads tokenizers, config, and model from Hugging Face Hub."""
     global model, smiles_tokenizer, iupac_tokenizer, device, config
-    if model is not None: # Already loaded
+    if model is not None:  # Already loaded
         logging.info("Model and tokenizers already loaded.")
         return
 
@@ -259,42 +300,71 @@ def load_model_and_tokenizers():
     # Download files from HF Hub
     logging.info("Downloading files from Hugging Face Hub...")
     try:
-        checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
-        smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME)
-        iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME)
-        config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME)
+        checkpoint_path = hf_hub_download(
+            repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME
+        )
+        smiles_tokenizer_path = hf_hub_download(
+            repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME
+        )
+        iupac_tokenizer_path = hf_hub_download(
+            repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME
+        )
+        config_path = hf_hub_download(
+            repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME
+        )
         logging.info("Files downloaded successfully.")
     except Exception as e:
-        logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
-        raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")
+        logging.error(
+            f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
+            exc_info=True,
+        )
+        raise gr.Error(
+            f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
+        )
 
     # Load config
     logging.info("Loading configuration...")
     try:
-        with open(config_path, 'r') as f:
+        with open(config_path, "r") as f:
             config = json.load(f)
         logging.info("Configuration loaded.")
         # --- Validate essential config keys ---
         required_keys = [
-            'src_vocab_size', 'tgt_vocab_size', 'emb_size', 'nhead',
-            'ffn_hid_dim', 'num_encoder_layers', 'num_decoder_layers',
-            'dropout', 'max_len', 'bos_token_id', 'eos_token_id', 'pad_token_id'
+            "src_vocab_size",
+            "tgt_vocab_size",
+            "emb_size",
+            "nhead",
+            "ffn_hid_dim",
+            "num_encoder_layers",
+            "num_decoder_layers",
+            "dropout",
+            "max_len",
+            "bos_token_id",
+            "eos_token_id",
+            "pad_token_id",
         ]
         missing_keys = [key for key in required_keys if key not in config]
         if missing_keys:
-            raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}")
+            raise ValueError(
+                f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}"
+            )
         # --- End Validation ---
     except FileNotFoundError:
-        logging.error(f"Config file not found locally after download attempt: {config_path}")
-        raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
+        logging.error(
+            f"Config file not found locally after download attempt: {config_path}"
+        )
+        raise gr.Error(
+            f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
+        )
     except json.JSONDecodeError as e:
         logging.error(f"Error decoding JSON from config file {config_path}: {e}")
-        raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
+        raise gr.Error(
+            f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
+        )
     except ValueError as e:
         logging.error(f"Config validation error: {e}")
         raise gr.Error(f"Config Error: {e}")
 
-
     # Load tokenizers
     logging.info("Loading tokenizers...")
     try:
@@ -303,52 +373,80 @@ def load_model_and_tokenizers():
         logging.info("Tokenizers loaded.")
         # --- Validate Tokenizer Special Tokens ---
         # Add more robust checks if necessary
-        if smiles_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
-           smiles_tokenizer.token_to_id("<unk>") is None:
-            logging.warning("SMILES tokenizer special tokens might not match config or are missing.")
-        if iupac_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
-           iupac_tokenizer.token_to_id("<sos>") != config['bos_token_id'] or \
-           iupac_tokenizer.token_to_id("<eos>") != config['eos_token_id'] or \
-           iupac_tokenizer.token_to_id("<unk>") is None:
-            logging.warning("IUPAC tokenizer special tokens might not match config or are missing.")
+        if (
+            smiles_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
+            or smiles_tokenizer.token_to_id("<unk>") is None
+        ):
+            logging.warning(
+                "SMILES tokenizer special tokens might not match config or are missing."
+            )
+        if (
+            iupac_tokenizer.token_to_id("<pad>") != config["pad_token_id"]
+            or iupac_tokenizer.token_to_id("<sos>") != config["bos_token_id"]
+            or iupac_tokenizer.token_to_id("<eos>") != config["eos_token_id"]
+            or iupac_tokenizer.token_to_id("<unk>") is None
+        ):
+            logging.warning(
+                "IUPAC tokenizer special tokens might not match config or are missing."
+            )
         # --- End Validation ---
     except Exception as e:
-        logging.error(f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}", exc_info=True)
-        raise gr.Error(f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}")
+        logging.error(
+            f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}",
+            exc_info=True,
+        )
+        raise gr.Error(
+            f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}"
+        )
 
     # Load model
     logging.info("Loading model from checkpoint...")
     try:
         model = SmilesIupacLitModule.load_from_checkpoint(
             checkpoint_path,
-            src_vocab_size=config['src_vocab_size'],
-            tgt_vocab_size=config['tgt_vocab_size'],
+            src_vocab_size=config["src_vocab_size"],
+            tgt_vocab_size=config["tgt_vocab_size"],
             map_location=device,
             hparams_dict=config,
             strict=False,
-            device="cpu"
+            device="cpu",
         )
         model.to(device)
         model.eval()
         model.freeze()
-        logging.info("Model loaded successfully, set to eval mode, frozen, and moved to device.")
+        logging.info(
+            "Model loaded successfully, set to eval mode, frozen, and moved to device."
+        )
 
     except FileNotFoundError:
-        logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
-        raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
+        logging.error(
+            f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
+        )
+        raise gr.Error(
+            f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
+        )
     except Exception as e:
-        logging.error(f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True)
+        logging.error(
+            f"Error loading model from checkpoint {checkpoint_path}: {e}",
+            exc_info=True,
+        )
         if "memory" in str(e).lower():
             gc.collect()
             if device == torch.device("cuda"):
                 torch.cuda.empty_cache()
-        raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")
+        raise gr.Error(
+            f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
+        )
 
     except gr.Error:
-        raise
+        raise
     except Exception as e:
-        logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
-        raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")
+        logging.error(
+            f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
+        )
+        raise gr.Error(
+            f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
+        )
 
 
 # --- Inference Function for Gradio (Unchanged, calls local translate) ---
@@ -361,15 +459,19 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
         error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
         # Ensure n_best is int for range, default to 1 if conversion fails early
-        try: n_best_int = int(n_best)
-        except: n_best_int = 1
-        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])
+        try:
+            n_best_int = int(n_best)
+        except:
+            n_best_int = 1
+        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
 
     if not smiles_string or not smiles_string.strip():
         error_msg = "Error: Please enter a valid SMILES string."
-        try: n_best_int = int(n_best)
-        except: n_best_int = 1
-        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])
+        try:
+            n_best_int = int(n_best)
+        except:
+            n_best_int = 1
+        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])
 
     smiles_input = smiles_string.strip()
     try:
@@ -377,10 +479,12 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
         n_best = int(n_best)
         length_penalty = float(length_penalty)
     except ValueError as e:
-        error_msg = f"Error: Invalid input parameter type ({e})."
-        return f"1. {error_msg}" # Cannot determine n_best here
+        error_msg = f"Error: Invalid input parameter type ({e})."
+        return f"1. {error_msg}"  # Cannot determine n_best here
 
-    logging.info(f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})")
+    logging.info(
+        f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})"
+    )
 
     try:
         # Calls the translate function defined *above in this file*
@@ -390,21 +494,23 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
             smiles_tokenizer=smiles_tokenizer,
             iupac_tokenizer=iupac_tokenizer,
             device=device,
-            max_len=config['max_len'],
-            sos_idx=config['bos_token_id'],
-            eos_idx=config['eos_token_id'],
-            pad_idx=config['pad_token_id'],
+            max_len=config["max_len"],
+            sos_idx=config["bos_token_id"],
+            eos_idx=config["eos_token_id"],
+            pad_idx=config["pad_token_id"],
             beam_width=beam_width,
             n_best=n_best,
-            length_penalty=length_penalty
+            length_penalty=length_penalty,
         )
         logging.info(f"Predictions returned: {predicted_names}")
 
         if not predicted_names:
-            output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
+            output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
         else:
             output_text = f"Input SMILES: {smiles_input}\n\nTop {len(predicted_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
-            output_text += "\n".join([f"{i+1}. {name}" for i, name in enumerate(predicted_names)])
+            output_text += "\n".join(
+                [f"{i + 1}. {name}" for i, name in enumerate(predicted_names)]
+            )
 
         return output_text
 
@@ -416,21 +522,23 @@ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
         if device == torch.device("cuda"):
             torch.cuda.empty_cache()
         error_msg += " (Potential OOM)"
-        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
+        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
 
     except Exception as e:
         logging.error(f"Unexpected error during translation: {e}", exc_info=True)
         error_msg = f"Unexpected Error during translation: {e}"
-        return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
+        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])
 
 
 # --- Load Model on App Start (Unchanged) ---
 try:
     load_model_and_tokenizers()
 except gr.Error:
-    pass # Error already raised for Gradio UI
+    pass  # Error already raised for Gradio UI
 except Exception as e:
-    logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
+    logging.error(
+        f"Critical error during initial model loading sequence: {e}", exc_info=True
+    )
     gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
 
 
@@ -445,16 +553,21 @@ Adjust beam search parameters below. Higher beam width explores more possibiliti
 examples = [
     ["CCO", 5, 3, 0.6],
     ["C1=CC=CC=C1", 5, 3, 0.6],
-    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
-    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
-    ["CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C", 5, 1, 0.6], # Complex example
+    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6],  # Aspirin
+    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6],  # Ibuprofen
+    [
+        "CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C",
+        5,
+        1,
+        0.6,
+    ],  # Complex example
     ["INVALID_SMILES", 5, 1, 0.6],
 ]
 
 smiles_input = gr.Textbox(
     label="SMILES String",
     placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
-    lines=1
+    lines=1,
 )
 beam_width_input = gr.Slider(
     minimum=1,
@@ -462,7 +575,7 @@ beam_width_input = gr.Slider(
     value=5,
     step=1,
     label="Beam Width (k)",
-    info="Number of sequences to keep at each decoding step (higher = more exploration, slower)."
+    info="Number of sequences to keep at each decoding step (higher = more exploration, slower).",
 )
 n_best_input = gr.Slider(
     minimum=1,
@@ -470,7 +583,7 @@ n_best_input = gr.Slider(
     value=3,
     step=1,
     label="Number of Results (n_best)",
-    info="How many top-scoring sequences to return (must be <= Beam Width)."
+    info="How many top-scoring sequences to return (must be <= Beam Width).",
 )
 length_penalty_input = gr.Slider(
     minimum=0.0,
@@ -478,12 +591,10 @@ length_penalty_input = gr.Slider(
     value=0.6,
     step=0.1,
     label="Length Penalty (alpha)",
-    info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty."
+    info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty.",
 )
 output_text = gr.Textbox(
-    label="Predicted IUPAC Name(s)",
-    lines=5,
-    show_copy_button=True
+    label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True
 )
 
 iface = gr.Interface(
@@ -495,7 +606,7 @@ iface = gr.Interface(
     examples=examples,
     allow_flagging="never",
     theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
-    article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results."
+    article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results.",
 )
 
 # --- Launch the App (Unchanged) ---
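
Note on the beam scoring in this diff: get_score ranks finished beams by dividing a beam's cumulative log-probability by seq_len ** length_penalty, so the new trailing comma and comment spacing are the only changes to that logic. A minimal standalone sketch of the formula follows; the normalized_score helper and the sample numbers are illustrative, not part of the commit.

# Length-normalized beam scoring, mirroring get_score in app.py.
# Helper name and sample values are hypothetical.
def normalized_score(log_prob_sum: float, seq_len: int, length_penalty: float = 0.6) -> float:
    # log_prob_sum is the beam's summed token log-probabilities (<= 0);
    # dividing by seq_len ** alpha keeps longer hypotheses competitive.
    return log_prob_sum / (float(seq_len) ** length_penalty)

print(normalized_score(-6.0, 10))       # -6 / 10**0.6, roughly -1.51
print(normalized_score(-6.0, 10, 0.0))  # -6.0: alpha = 0 disables the penalty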
 
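For orientation, predict_iupac drives the local translate function defined in this file. Below is a sketch of an equivalent direct call, assuming load_model_and_tokenizers() has already populated the module globals; the SMILES string and beam settings are illustrative, and the keyword names match the call site shown in the diff.

# Hypothetical direct use of translate() once the globals are loaded.
names = translate(
    model=model,
    src_sentence="CCO",  # ethanol, taken from the app's examples
    smiles_tokenizer=smiles_tokenizer,
    iupac_tokenizer=iupac_tokenizer,
    device=device,
    max_len=config["max_len"],
    sos_idx=config["bos_token_id"],
    eos_idx=config["eos_token_id"],
    pad_idx=config["pad_token_id"],
    beam_width=5,
    n_best=3,
    length_penalty=0.6,
)
for i, name in enumerate(names):  # names is a list[str] of up to n_best candidates
    print(f"{i + 1}. {name}")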