AdrianM0 committed on
Commit f071a4c · verified · 1 Parent(s): fa90b40

Update app.py

Files changed (1)
  1. app.py +485 -305
app.py CHANGED
@@ -1,8 +1,7 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
4
-
5
- # import torch.nn.functional as F # No longer needed for greedy decode directly
6
  import pytorch_lightning as pl
7
  import os
8
  import json
@@ -10,9 +9,10 @@ import logging
10
  from tokenizers import Tokenizer
11
  from huggingface_hub import hf_hub_download
12
  import gc
13
- from rdkit.Chem import CanonSmiles
14
  import spaces
15
-
 
16
 
17
  # --- Configuration ---
18
  MODEL_REPO_ID = (
@@ -31,6 +31,7 @@ logging.basicConfig(
31
 
32
  # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
33
  try:
 
34
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
35
 
36
  logging.info("Successfully imported from enhanced_trainer.py.")
@@ -58,7 +59,7 @@ device: torch.device | None = None
58
  config: dict | None = None
59
 
60
 
61
- # --- Greedy Decoding Logic (Locally defined) ---
62
  def greedy_decode(
63
  model: pl.LightningModule,
64
  src: torch.Tensor,
@@ -70,90 +71,223 @@ def greedy_decode(
70
  ) -> torch.Tensor:
71
  """
72
  Performs greedy decoding using the LightningModule's model.
 
73
  """
74
- model.eval() # Ensure model is in evaluation mode
75
- transformer_model = model.model # Access the underlying Seq2SeqTransformer
76
 
77
  try:
78
  with torch.no_grad():
79
- # --- Encode Source ---
80
- memory = transformer_model.encode(
81
- src, src_padding_mask
82
- ) # [1, src_len, emb_size]
83
  memory = memory.to(device)
84
- memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
85
 
86
- # --- Initialize Target Sequence ---
87
- # Start with the SOS token
88
- ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
89
- sos_idx
90
- ) # [1, 1]
91
 
92
- # --- Decoding Loop ---
93
- for _ in range(max_len - 1): # Max length limit
94
  tgt_seq_len = ys.shape[1]
95
- tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
96
- device
97
- ) # [curr_len, curr_len]
98
- # No padding in target during generation yet
99
- tgt_padding_mask = torch.zeros(
100
- ys.shape, dtype=torch.bool, device=device
101
- ) # [1, curr_len]
102
-
103
- # Decode one step
104
  decoder_output = transformer_model.decode(
105
  tgt=ys,
106
  memory=memory,
107
  tgt_mask=tgt_mask,
108
  tgt_padding_mask=tgt_padding_mask,
109
  memory_key_padding_mask=memory_key_padding_mask,
110
- ) # [1, curr_len, emb_size]
111
-
112
- # Get logits for the *next* token prediction
113
- next_token_logits = transformer_model.generator(
114
- decoder_output[
115
- :, -1, :
116
- ] # Use output corresponding to the last input token
117
- ) # [1, tgt_vocab_size]
118
-
119
- # Find the most likely next token (greedy choice)
120
- # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
121
- # _, next_word_id_tensor = torch.max(prob, dim=1)
122
- next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # [1]
123
- next_word_id = next_word_id_tensor.item()
124
-
125
- # Append the chosen token to the sequence
126
  ys = torch.cat(
127
  [
128
  ys,
129
- torch.ones(1, 1, dtype=torch.long, device=device).fill_(
130
- next_word_id
131
- ),
132
  ],
133
  dim=1,
134
- ) # [1, current_len + 1]
135
 
136
- # Stop if EOS token is generated
137
  if next_word_id == eos_idx:
138
  break
139
 
140
- # Return the generated sequence (excluding the initial SOS token)
141
- return ys[:, 1:] # Shape [1, generated_len]
142
 
143
  except RuntimeError as e:
144
  logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
145
  if "CUDA out of memory" in str(e) and device.type == "cuda":
146
  gc.collect()
147
  torch.cuda.empty_cache()
148
- return torch.empty(
149
- (1, 0), dtype=torch.long, device=device
150
- ) # Return empty tensor on error
151
  except Exception as e:
152
  logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
153
  return torch.empty((1, 0), dtype=torch.long, device=device)
154
 
155
 
156
- # --- Translation Function (Using Greedy Decode) ---
 
157
  def translate(
158
  model: pl.LightningModule,
159
  src_sentence: str,
@@ -164,126 +298,139 @@ def translate(
164
  sos_idx: int,
165
  eos_idx: int,
166
  pad_idx: int,
167
- ) -> str: # Returns a single string
 
 
 
 
168
  """
169
- Translates a single SMILES string using greedy decoding.
170
  """
171
- model.eval() # Ensure model is in eval mode
172
 
173
  # --- Tokenize Source ---
174
  try:
175
- # Ensure tokenizer has truncation/padding configured if needed, or handle manually
176
- smiles_tokenizer.enable_truncation(
177
- max_length=max_len
178
- ) # Use max_len for source truncation too
179
  src_encoded = smiles_tokenizer.encode(src_sentence)
180
  if not src_encoded or not src_encoded.ids:
181
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
182
- return "[Encoding Error]"
183
- # Use the truncated IDs directly
184
  src_ids = src_encoded.ids
185
  except Exception as e:
186
  logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
187
- return "[Encoding Error]"
188
 
189
  # --- Prepare Input Tensor and Mask ---
190
- src = (
191
- torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
192
- ) # [1, src_len]
193
- # Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
194
- src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
195
-
196
- # --- Perform Greedy Decoding ---
197
- # Calls the greedy_decode function defined *above in this file*
198
- # Note: max_len for generation should come from config if it dictates output length
199
- generation_max_len = config.get(
200
- "max_len", 256
201
- ) # Use config max_len for output limit
202
- tgt_tokens_tensor = greedy_decode(
203
- model=model,
204
- src=src,
205
- src_padding_mask=src_padding_mask,
206
- max_len=generation_max_len, # Use generation limit
207
- sos_idx=sos_idx,
208
- eos_idx=eos_idx,
209
- # pad_idx=pad_idx, # Not needed by greedy_decode internal loop
210
- device=device,
211
- ) # Returns a single tensor [1, generated_len]
 
212
 
213
  # --- Decode Generated Tokens ---
214
- if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
215
- logging.warning(
216
- f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
217
- )
218
- return "[Decoding Error - Empty Output]"
219
 
220
- tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
221
- try:
222
- # Decode using the target tokenizer, skipping special tokens
223
- translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
224
- return translation
225
- except Exception as e:
226
- logging.error(
227
- f"Error decoding target tokens {tgt_tokens}: {e}",
228
- exc_info=True,
229
- )
230
- return "[Decoding Error]"
 
 
231
 
232
 
233
- # --- Model/Tokenizer Loading Function (Unchanged) ---
234
  def load_model_and_tokenizers():
235
  """Loads tokenizers, config, and model from Hugging Face Hub."""
236
  global model, smiles_tokenizer, iupac_tokenizer, device, config
237
- if model is not None: # Already loaded
238
  logging.info("Model and tokenizers already loaded.")
239
  return
240
 
241
  logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
242
  try:
243
- # Determine device
244
- if torch.cuda.is_available():
245
- logging.warning(
246
- "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
247
- )
248
- device = torch.device("cuda") # Uncomment if you want to use GPU
249
- # device = torch.device("cuda")
250
- # logging.info("CUDA available, using GPU.")
251
- else:
252
- device = torch.device("cpu")
253
- logging.info("CUDA not available, using CPU.")
254
 
255
  # Download files
256
  logging.info("Downloading files from Hugging Face Hub...")
 
 
 
 
257
  try:
258
- cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
259
- os.makedirs(cache_dir, exist_ok=True)
260
- logging.info(f"Using cache directory: {cache_dir}")
 
261
 
262
- checkpoint_path = hf_hub_download(
263
- repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
264
- )
265
- smiles_tokenizer_path = hf_hub_download(
266
- repo_id=MODEL_REPO_ID,
267
- filename=SMILES_TOKENIZER_FILENAME,
268
- cache_dir=cache_dir,
269
- )
270
- iupac_tokenizer_path = hf_hub_download(
271
- repo_id=MODEL_REPO_ID,
272
- filename=IUPAC_TOKENIZER_FILENAME,
273
- cache_dir=cache_dir,
274
- )
275
- config_path = hf_hub_download(
276
- repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
277
- )
278
  logging.info("Files downloaded successfully.")
279
  except Exception as e:
280
- logging.error(
281
- f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
282
- exc_info=True,
283
- )
284
- raise gr.Error(
285
- f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}"
286
- )
287
 
288
  # Load config
289
  logging.info("Loading configuration...")
@@ -291,181 +438,149 @@ def load_model_and_tokenizers():
291
  with open(config_path, "r") as f:
292
  config = json.load(f)
293
  logging.info("Configuration loaded.")
294
- # --- Validate essential config keys ---
295
- required_keys = [
296
- "src_vocab_size", # Use the key saved in config
297
- "tgt_vocab_size", # Use the key saved in config
298
- "emb_size",
299
- "nhead",
300
- "ffn_hid_dim",
301
- "num_encoder_layers",
302
- "num_decoder_layers",
303
- "dropout",
304
- "max_len",
305
- "pad_token_id",
306
- "bos_token_id",
307
- "eos_token_id",
308
- ]
309
- # Remap if needed (example shown, adjust if your keys differ)
310
- config_key_mapping = {
311
- "src_vocab_size": config.get(
312
- "src_vocab_size", config.get("src_vocab_size")
313
- ),
314
- "tgt_vocab_size": config.get(
315
- "tgt_vocab_size", config.get("tgt_vocab_size")
316
- ),
317
- # Add other mappings if necessary
318
- }
319
- config.update(config_key_mapping)
320
-
321
  missing_keys = [key for key in required_keys if config.get(key) is None]
322
  if missing_keys:
323
- raise ValueError(
324
- f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
325
- f"Ensure these were saved in the hyperparameters during training."
326
- )
327
-
328
- logging.info(
329
- f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
330
- f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
331
- f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
332
- )
333
-
334
- except FileNotFoundError:
335
- logging.error(f"Config file not found: {config_path}")
336
- raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
337
- except json.JSONDecodeError as e:
338
- logging.error(f"Error decoding JSON from config file {config_path}: {e}")
339
- raise gr.Error(
340
- f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}"
341
- )
342
- except ValueError as e:
343
- logging.error(f"Config validation error: {e}")
344
- raise gr.Error(f"Config Error: {e}")
345
  except Exception as e:
346
- logging.error(f"Unexpected error loading config: {e}", exc_info=True)
347
- raise gr.Error(f"Config Error: Unexpected error. Check logs. Error: {e}")
 
348
 
349
  # Load tokenizers
350
  logging.info("Loading tokenizers...")
351
  try:
352
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
353
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
 
 
 
 
 
354
  logging.info("Tokenizers loaded.")
355
- # --- Optional: Validate Tokenizer Special Tokens Against Config ---
356
- # (Keep validation as before, it's still useful)
357
- pad_token = "<pad>"
358
- sos_token = "<sos>"
359
- eos_token = "<eos>"
360
- unk_token = "<unk>"
361
- issues = []
362
- # ... (keep the validation checks as in the original code) ...
363
- if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
364
- issues.append(f"SMILES PAD ID mismatch")
365
- if smiles_tokenizer.token_to_id(unk_token) is None:
366
- issues.append("SMILES UNK token not found")
367
- if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
368
- issues.append(f"IUPAC PAD ID mismatch")
369
- if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
370
- issues.append(f"IUPAC SOS ID mismatch")
371
- if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
372
- issues.append(f"IUPAC EOS ID mismatch")
373
- if iupac_tokenizer.token_to_id(unk_token) is None:
374
- issues.append("IUPAC UNK token not found")
375
- if issues:
376
- logging.warning("Tokenizer validation issues: " + "; ".join(issues))
377
-
378
  except Exception as e:
379
  logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
380
- raise gr.Error(
381
- f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
382
- )
383
 
384
  # Load model
385
  logging.info("Loading model from checkpoint...")
386
  try:
387
- # Use the vocab sizes and hparams from the loaded config
 
388
  model = SmilesIupacLitModule.load_from_checkpoint(
389
  checkpoint_path,
390
- **config, # Pass loaded config directly as keyword args
391
  map_location=device,
392
- devices=1,
393
- strict=True, # Start strict, set to False if encountering key errors
 
394
  )
395
 
396
  model.to(device)
397
  model.eval()
398
  model.freeze()
399
- logging.info(
400
- f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
401
- )
402
 
403
  except FileNotFoundError:
404
  logging.error(f"Checkpoint file not found: {checkpoint_path}")
405
- raise gr.Error(
406
- f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
407
- )
408
  except Exception as e:
409
- logging.error(
410
- f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
411
- )
412
  if "size mismatch" in str(e):
413
  error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
414
  logging.error(error_detail)
415
  raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
 
 
 
 
416
  elif "memory" in str(e).lower():
417
  logging.warning("Potential OOM error during model loading.")
418
  gc.collect()
419
  torch.cuda.empty_cache() if device.type == "cuda" else None
420
- raise gr.Error(
421
- f"Model Error: OOM loading model. Check Space resources. Error: {e}"
422
- )
423
  else:
424
- raise gr.Error(
425
- f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
426
- )
427
 
428
  except gr.Error:
429
- raise
430
  except Exception as e:
431
  logging.error(f"Unexpected error during loading: {e}", exc_info=True)
432
- raise gr.Error(
433
- f"Initialization Error: Unexpected error. Check logs. Error: {e}"
434
- )
435
 
436
 
437
- # --- Inference Function for Gradio (Simplified) ---
438
- @spaces.GPU
439
- def predict_iupac(smiles_string):
440
  """
441
- Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
442
  """
443
- try:
444
- smiles_string = CanonSmiles(smiles_string)
445
- except Exception as e:
446
- logging.error(f"Error during SMILES canonicalization: {e}", exc_info=True)
447
- return f"Error: {e}"
448
  global model, smiles_tokenizer, iupac_tokenizer, device, config
449
 
 
450
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
451
  error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
452
  logging.error(error_msg)
453
- return f"Error: {error_msg}" # Return single error string
454
 
455
  if not smiles_string or not smiles_string.strip():
456
- error_msg = "Error: Please enter a valid SMILES string."
457
- return f"Error: {error_msg}" # Return single error string
458
 
459
  smiles_input = smiles_string.strip()
460
 
 
 
 
 
 
 
 
 
 
 
 
461
  try:
462
- # --- Call the core translation logic (greedy) ---
463
  sos_idx = config["bos_token_id"]
464
  eos_idx = config["eos_token_id"]
465
  pad_idx = config["pad_token_id"]
466
  gen_max_len = config["max_len"]
 
 
467
 
468
- predicted_name = translate( # Returns a single string now
469
  model=model,
470
  src_sentence=smiles_input,
471
  smiles_tokenizer=smiles_tokenizer,
@@ -475,100 +590,165 @@ def predict_iupac(smiles_string):
475
  sos_idx=sos_idx,
476
  eos_idx=eos_idx,
477
  pad_idx=pad_idx,
 
 
 
 
478
  )
479
- logging.info(f"Prediction returned: {predicted_name}")
480
 
481
  # --- Format Output ---
482
- if "[Error]" in predicted_name: # Check for error messages from translate
483
- output_text = (
484
- f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
485
- )
486
- elif not predicted_name:
487
- output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
 
 
 
 
 
 
 
488
  else:
489
- output_text = (
490
- f"Input SMILES: {smiles_input}\n\n"
491
- f"Predicted IUPAC Name (Greedy Decode):\n"
492
- f"{predicted_name}"
493
- )
494
- return output_text
 
 
495
 
496
  except RuntimeError as e:
497
  logging.error(f"Runtime error during translation: {e}", exc_info=True)
498
- return f"Error: {error_msg}" # Return single error string
499
-
 
500
  except Exception as e:
501
  logging.error(f"Unexpected error during translation: {e}", exc_info=True)
502
- error_msg = f"Unexpected Error during translation: {e}"
503
- return f"Error: {error_msg}" # Return single error string
504
 
505
 
506
  # --- Load Model on App Start ---
507
  try:
508
  load_model_and_tokenizers()
509
  except gr.Error as ge:
510
- logging.error(f"Gradio Initialization Error: {ge}")
511
- pass # Allow Gradio to potentially start with an error message
 
 
512
  except Exception as e:
513
  logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
514
- # Optionally raise gr.Error here too
515
 
516
 
517
- # --- Create Gradio Interface (Simplified) ---
518
- title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
519
  description = f"""
520
- Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
521
- Translation uses **greedy decoding** (picks the most likely next word at each step).
522
- **Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
 
523
  """
524
 
525
-
526
- # Use gr.Blocks exclusively
527
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface: # 'iface' is created here
528
  gr.Markdown(f"# {title}")
529
  gr.Markdown(description)
530
 
531
  with gr.Row():
532
- with gr.Column(scale=1): # Adjust scale if needed
533
  smiles_input = gr.Textbox(
534
  label="SMILES String",
535
- placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
536
- lines=1,
537
- # interactive=True # Default is True
 
538
  )
539
- submit_btn = gr.Button("Translate")
540
 
541
- with gr.Column(scale=2): # Give output more space if desired
 
542
  output_text = gr.Textbox(
543
- label="Predicted IUPAC Name",
544
- lines=5, # Increased lines slightly
545
  show_copy_button=True,
546
  # interactive=False # Output shouldn't be user-editable
547
  )
548
 
549
- # --- Define Event Listeners INSIDE the gr.Blocks context ---
550
- # When the button is clicked, call predict_iupac
551
  submit_btn.click(
552
  fn=predict_iupac,
553
- inputs=smiles_input,
554
  outputs=output_text,
555
- api_name="translate_smiles" # Optional: name for API endpoint
556
  )
557
 
558
- # Optional: Trigger prediction when text changes (can be resource-intensive)
559
- # If you uncomment this, consider adding a debounce or throttle if using Gradio >= 3.20
560
- # smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
561
-
562
- # Optional: Trigger prediction when text is submitted (e.g., pressing Enter)
563
  smiles_input.submit(
564
  fn=predict_iupac,
565
- inputs=smiles_input,
566
  outputs=output_text
567
  )
568
 
569
  # --- Launch the App ---
570
- # The 'iface' variable is already defined by the 'with gr.Blocks(...)' statement
571
  if __name__ == "__main__":
572
- # You can add server_name="0.0.0.0" if running in Docker/Spaces
573
- # and share=True if you want a public link (usually handled by Spaces automatically)
574
  iface.launch()
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+ import torch.nn.functional as F # Needed for log_softmax in beam search
 
5
  import pytorch_lightning as pl
6
  import os
7
  import json
 
9
  from tokenizers import Tokenizer
10
  from huggingface_hub import hf_hub_download
11
  import gc
12
+ from rdkit.Chem import CanonSmiles, MolFromSmiles # Added MolFromSmiles for validation
13
  import spaces
14
+ import heapq # For beam search priority queue
15
+ import math # For log probabilities
16
 
17
  # --- Configuration ---
18
  MODEL_REPO_ID = (
 
31
 
32
  # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
33
  try:
34
+ # Ensure enhanced_trainer.py is present in the repository root
35
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
36
 
37
  logging.info("Successfully imported from enhanced_trainer.py.")
 
59
  config: dict | None = None
60
 
61
 
62
+ # --- Greedy Decoding Logic (Unchanged) ---
63
  def greedy_decode(
64
  model: pl.LightningModule,
65
  src: torch.Tensor,
 
71
  ) -> torch.Tensor:
72
  """
73
  Performs greedy decoding using the LightningModule's model.
74
+ Returns a tensor of shape [1, sequence_length].
75
  """
76
+ model.eval()
77
+ transformer_model = model.model
78
 
79
  try:
80
  with torch.no_grad():
81
+ memory = transformer_model.encode(src, src_padding_mask)
 
 
 
82
  memory = memory.to(device)
83
+ memory_key_padding_mask = src_padding_mask.to(memory.device)
84
 
85
+ ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx)
 
 
 
 
86
 
87
+ for _ in range(max_len - 1):
 
88
  tgt_seq_len = ys.shape[1]
89
+ tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device)
90
+ tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device)
91
+
 
 
 
 
 
 
92
  decoder_output = transformer_model.decode(
93
  tgt=ys,
94
  memory=memory,
95
  tgt_mask=tgt_mask,
96
  tgt_padding_mask=tgt_padding_mask,
97
  memory_key_padding_mask=memory_key_padding_mask,
98
+ )
99
+
100
+ next_token_logits = transformer_model.generator(decoder_output[:, -1, :])
101
+ next_word_id = torch.argmax(next_token_logits, dim=1).item()
102
+
 
103
  ys = torch.cat(
104
  [
105
  ys,
106
+ torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id),
 
 
107
  ],
108
  dim=1,
109
+ )
110
 
 
111
  if next_word_id == eos_idx:
112
  break
113
 
114
+ return ys[:, 1:] # Exclude SOS
 
115
 
116
  except RuntimeError as e:
117
  logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
118
  if "CUDA out of memory" in str(e) and device.type == "cuda":
119
  gc.collect()
120
  torch.cuda.empty_cache()
121
+ return torch.empty((1, 0), dtype=torch.long, device=device)
 
 
122
  except Exception as e:
123
  logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
124
  return torch.empty((1, 0), dtype=torch.long, device=device)
125
 
126
 
127
+ # --- Beam Search Decoding Logic ---
128
+ def beam_search_decode(
129
+ model: pl.LightningModule,
130
+ src: torch.Tensor,
131
+ src_padding_mask: torch.Tensor,
132
+ max_len: int,
133
+ sos_idx: int,
134
+ eos_idx: int,
135
+ pad_idx: int, # Needed for padding shorter beams if batching
136
+ device: torch.device,
137
+ beam_width: int,
138
+ num_return_sequences: int = 1,
139
+ length_penalty_alpha: float = 0.6, # Add length penalty
140
+ ) -> list[tuple[torch.Tensor, float]]:
141
+ """
142
+ Performs beam search decoding.
143
+ Returns a list of tuples: (sequence_tensor [1, seq_len], score)
144
+ """
145
+ model.eval()
146
+ transformer_model = model.model
147
+ num_return_sequences = min(beam_width, num_return_sequences) # Cannot return more than beam width
148
+
149
+ try:
150
+ with torch.no_grad():
151
+ # --- Encode Source (Once) ---
152
+ memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size]
153
+ memory = memory.to(device)
154
+ memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
155
+
156
+ # --- Initialize Beams ---
157
+ # Each beam: (sequence_tensor [1, current_len], score (log_prob))
158
+ initial_beam = (
159
+ torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx),
160
+ 0.0, # Initial score (log probability)
161
+ )
162
+ beams = [initial_beam]
163
+ finished_hypotheses = [] # Store finished sequences: (score, sequence_tensor)
164
+
165
+ # --- Decoding Loop ---
166
+ for step in range(max_len - 1):
167
+ if not beams: # Stop if no active beams left
168
+ break
169
+
170
+ # Use a min-heap to keep track of candidates for the *next* step
171
+ # Store (-score, next_token_id, beam_index) - use negative score for max-heap behavior
172
+ candidates = []
173
+
174
+ # Process current beams (can be batched for efficiency, but simpler loop shown here)
175
+ # For batching: stack ys, expand memory, create masks for the batch
176
+ for beam_idx, (current_seq, current_score) in enumerate(beams):
177
+ if current_seq[0, -1].item() == eos_idx: # Beam already finished
178
+ # Add length penalty before storing
179
+ penalty = ((current_seq.shape[1]) ** length_penalty_alpha)
180
+ final_score = current_score / penalty if penalty > 0 else current_score
181
+ heapq.heappush(finished_hypotheses, (final_score, current_seq))
182
+ # Prune finished hypotheses if we have enough
183
+ while len(finished_hypotheses) > beam_width:
184
+ heapq.heappop(finished_hypotheses) # Remove lowest score
185
+ continue # Don't expand finished beams
186
+
187
+ # --- Prepare input for this beam ---
188
+ ys = current_seq # [1, current_len]
189
+ tgt_seq_len = ys.shape[1]
190
+ tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device)
191
+ tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device)
192
+
193
+ # --- Decode one step ---
194
+ # Note: memory and memory_key_padding_mask are reused
195
+ decoder_output = transformer_model.decode(
196
+ tgt=ys,
197
+ memory=memory, # Needs expansion if batching beams
198
+ tgt_mask=tgt_mask,
199
+ tgt_padding_mask=tgt_padding_mask,
200
+ memory_key_padding_mask=memory_key_padding_mask, # Needs expansion if batching
201
+ ) # [1, current_len, emb_size]
202
+
203
+ # Get logits for the *next* token
204
+ next_token_logits = transformer_model.generator(
205
+ decoder_output[:, -1, :]
206
+ ) # [1, tgt_vocab_size]
207
+
208
+ # Calculate log probabilities
209
+ log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size]
210
+
211
+ # Get top K candidates for *this* beam
212
+ # Adding current_score makes it the total path score
213
+ top_k_log_probs, top_k_indices = torch.topk(log_probs + current_score, beam_width, dim=1)
214
+
215
+ # Add candidates to the list for selection across all beams
216
+ for i in range(beam_width):
217
+ token_id = top_k_indices[0, i].item()
218
+ score = top_k_log_probs[0, i].item()
219
+ # Store (-score, token_id, beam_idx) for heap
220
+ heapq.heappush(candidates, (-score, token_id, beam_idx))
221
+ # Prune candidates heap if it exceeds beam_width * beam_width (can optimize)
222
+ # A simpler pruning: keep only top N overall candidates later
223
+
224
+ # --- Select Top K Beams for Next Step ---
225
+ new_beams = []
226
+ # Ensure we don't exceed beam_width overall candidates
227
+ num_candidates_to_consider = min(len(candidates), beam_width * len(beams)) # Rough upper bound
228
+
229
+ # Use heap to efficiently get top k candidates overall
230
+ top_candidates = heapq.nsmallest(beam_width, candidates) # Get k smallest (-score) -> largest score
231
+
232
+ added_sequences = set() # Prevent duplicate sequences if paths converge
233
+
234
+ for neg_score, token_id, beam_idx in top_candidates:
235
+ original_seq, _ = beams[beam_idx]
236
+ new_seq = torch.cat(
237
+ [
238
+ original_seq,
239
+ torch.ones(1, 1, dtype=torch.long, device=device).fill_(token_id),
240
+ ],
241
+ dim=1,
242
+ ) # [1, current_len + 1]
243
+
244
+ # Avoid adding duplicates (optional, but good practice)
245
+ seq_tuple = tuple(new_seq.flatten().tolist())
246
+ if seq_tuple not in added_sequences:
247
+ new_beams.append((new_seq, -neg_score)) # Store positive score
248
+ added_sequences.add(seq_tuple)
249
+
250
+ beams = new_beams # Update active beams
251
+
252
+ # Early stopping: If top beam is finished and we have enough results
253
+ if finished_hypotheses:
254
+ # Check if the best possible score from active beams is worse than the worst finished beam
255
+ best_active_score = -heapq.nsmallest(1, candidates)[0][0] if candidates else -float('inf')
256
+ worst_finished_score = finished_hypotheses[0][0] # Smallest score in min-heap
257
+ if len(finished_hypotheses) >= num_return_sequences and best_active_score < worst_finished_score:
258
+ logging.debug(f"Beam search early stopping at step {step}")
259
+ break
260
+
261
+
262
+ # --- Final Selection ---
263
+ # Add any remaining active beams to finished list (if they didn't end with EOS)
264
+ for seq, score in beams:
265
+ if seq[0, -1].item() != eos_idx:
266
+ penalty = ((seq.shape[1]) ** length_penalty_alpha)
267
+ final_score = score / penalty if penalty > 0 else score
268
+ heapq.heappush(finished_hypotheses, (final_score, seq))
269
+ while len(finished_hypotheses) > beam_width:
270
+ heapq.heappop(finished_hypotheses)
271
+
272
+ # Sort finished hypotheses by score (descending) and select top N
273
+ # heapq is min-heap, so nlargest gets the best scores
274
+ top_hypotheses = heapq.nlargest(num_return_sequences, finished_hypotheses)
275
+
276
+ # Return list of (sequence_tensor [1, seq_len], score) excluding SOS
277
+ return [(seq[:, 1:], score) for score, seq in top_hypotheses]
278
+
279
+ except RuntimeError as e:
280
+ logging.error(f"Runtime error during beam search: {e}", exc_info=True)
281
+ if "CUDA out of memory" in str(e) and device.type == "cuda":
282
+ gc.collect()
283
+ torch.cuda.empty_cache()
284
+ return [] # Return empty list on error
285
+ except Exception as e:
286
+ logging.error(f"Unexpected error during beam search: {e}", exc_info=True)
287
+ return []
288
+
289
+
290
+ # --- Translation Function (Handles both Greedy and Beam Search) ---
291
  def translate(
292
  model: pl.LightningModule,
293
  src_sentence: str,
 
298
  sos_idx: int,
299
  eos_idx: int,
300
  pad_idx: int,
301
+ decoding_strategy: str = "Greedy",
302
+ beam_width: int = 5,
303
+ num_return_sequences: int = 1,
304
+ length_penalty_alpha: float = 0.6,
305
+ ) -> list[tuple[str, float]]: # Returns list of (translation_string, score)
306
  """
307
+ Translates a single SMILES string using the specified decoding strategy.
308
  """
309
+ model.eval()
310
 
311
  # --- Tokenize Source ---
312
  try:
313
+ smiles_tokenizer.enable_truncation(max_length=max_len)
 
 
 
314
  src_encoded = smiles_tokenizer.encode(src_sentence)
315
  if not src_encoded or not src_encoded.ids:
316
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
317
+ return [("[Encoding Error]", 0.0)]
 
318
  src_ids = src_encoded.ids
319
  except Exception as e:
320
  logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
321
+ return [("[Encoding Error]", 0.0)]
322
 
323
  # --- Prepare Input Tensor and Mask ---
324
+ src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len]
325
+ src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
326
+
327
+ # --- Perform Decoding ---
328
+ generation_max_len = config.get("max_len", 256)
329
+ results = [] # List to store (tensor, score) tuples
330
+
331
+ if decoding_strategy == "Greedy":
332
+ tgt_tokens_tensor = greedy_decode(
333
+ model=model,
334
+ src=src,
335
+ src_padding_mask=src_padding_mask,
336
+ max_len=generation_max_len,
337
+ sos_idx=sos_idx,
338
+ eos_idx=eos_idx,
339
+ device=device,
340
+ ) # Returns tensor [1, generated_len]
341
+ if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0:
342
+ results = [(tgt_tokens_tensor, 0.0)] # Assign dummy score 0.0 for greedy
343
+ else:
344
+ logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
345
+ return [("[Decoding Error - Empty Output]", 0.0)]
346
+
347
+ elif decoding_strategy == "Beam Search":
348
+ results = beam_search_decode(
349
+ model=model,
350
+ src=src,
351
+ src_padding_mask=src_padding_mask,
352
+ max_len=generation_max_len,
353
+ sos_idx=sos_idx,
354
+ eos_idx=eos_idx,
355
+ pad_idx=pad_idx,
356
+ device=device,
357
+ beam_width=beam_width,
358
+ num_return_sequences=num_return_sequences,
359
+ length_penalty_alpha=length_penalty_alpha,
360
+ ) # Returns list of (tensor, score)
361
+ if not results:
362
+ logging.warning(f"Beam search returned no results for SMILES: {src_sentence}")
363
+ return [("[Decoding Error - Empty Output]", 0.0)]
364
+ else:
365
+ logging.error(f"Unknown decoding strategy: {decoding_strategy}")
366
+ return [("[Error: Unknown Strategy]", 0.0)]
367
+
368
 
369
  # --- Decode Generated Tokens ---
370
+ translations = []
371
+ for tgt_tokens_tensor, score in results:
372
+ if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
373
+ translations.append(("[Decoding Error - Empty Sequence]", score))
374
+ continue
375
 
376
+ tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
377
+ try:
378
+ # Decode using the target tokenizer, skipping special tokens
379
+ translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
380
+ translations.append((translation, score))
381
+ except Exception as e:
382
+ logging.error(
383
+ f"Error decoding target tokens {tgt_tokens}: {e}",
384
+ exc_info=True,
385
+ )
386
+ translations.append(("[Decoding Error]", score))
387
+
388
+ return translations
389
 
390
 
391
+ # --- Model/Tokenizer Loading Function (Unchanged from previous version) ---
392
  def load_model_and_tokenizers():
393
  """Loads tokenizers, config, and model from Hugging Face Hub."""
394
  global model, smiles_tokenizer, iupac_tokenizer, device, config
395
+ if model is not None:
396
  logging.info("Model and tokenizers already loaded.")
397
  return
398
 
399
  logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
400
  try:
401
+ # Determine device (Force CPU for stability in typical Space envs, uncomment cuda if needed)
402
+ # if torch.cuda.is_available():
403
+ # device = torch.device("cuda")
404
+ # logging.info("CUDA available, using GPU.")
405
+ # else:
406
+ device = torch.device("cpu")
407
+ logging.info("Using CPU. Modify code to enable GPU if available and desired.")
 
 
 
 
408
 
409
  # Download files
410
  logging.info("Downloading files from Hugging Face Hub...")
411
+ cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
412
+ os.makedirs(cache_dir, exist_ok=True)
413
+ logging.info(f"Using cache directory: {cache_dir}")
414
+
415
  try:
416
+ checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir)
417
+ smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir)
418
+ iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir)
419
+ config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir)
420
+ # Ensure enhanced_trainer.py is downloaded or present
421
+ try:
422
+ hf_hub_download(repo_id=MODEL_REPO_ID, filename="enhanced_trainer.py", cache_dir=cache_dir, local_dir=".") # Download to current dir
423
+ logging.info("Downloaded enhanced_trainer.py")
424
+ except Exception as download_err:
425
+ if os.path.exists("enhanced_trainer.py"):
426
+ logging.warning(f"Could not download enhanced_trainer.py (maybe private?), but found local file. Using local. Error: {download_err}")
427
+ else:
428
+ raise download_err # Re-raise if not found locally either
429
 
430
  logging.info("Files downloaded successfully.")
431
  except Exception as e:
432
+ logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
433
+ raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")
 
 
 
 
 
434
 
435
  # Load config
436
  logging.info("Loading configuration...")
 
438
  with open(config_path, "r") as f:
439
  config = json.load(f)
440
  logging.info("Configuration loaded.")
441
+ required_keys = ["src_vocab_size", "tgt_vocab_size", "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers", "num_decoder_layers", "dropout", "max_len", "pad_token_id", "bos_token_id", "eos_token_id"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  missing_keys = [key for key in required_keys if config.get(key) is None]
443
  if missing_keys:
444
+ raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}.")
445
+ logging.info(f"Using config: { {k: config.get(k) for k in required_keys} }") # Log key values
 
446
  except Exception as e:
447
+ logging.error(f"Error loading or validating config: {e}", exc_info=True)
448
+ raise gr.Error(f"Config Error: {e}")
449
+
450
 
451
  # Load tokenizers
452
  logging.info("Loading tokenizers...")
453
  try:
454
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
455
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
456
+ # Basic validation (can add more checks as before)
457
+ if smiles_tokenizer.get_vocab_size() != config['src_vocab_size']:
458
+ logging.warning(f"SMILES vocab size mismatch: Tokenizer={smiles_tokenizer.get_vocab_size()}, Config={config['src_vocab_size']}")
459
+ if iupac_tokenizer.get_vocab_size() != config['tgt_vocab_size']:
460
+ logging.warning(f"IUPAC vocab size mismatch: Tokenizer={iupac_tokenizer.get_vocab_size()}, Config={config['tgt_vocab_size']}")
461
  logging.info("Tokenizers loaded.")
 
462
  except Exception as e:
463
  logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
464
+ raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}")
 
 
465
 
466
  # Load model
467
  logging.info("Loading model from checkpoint...")
468
  try:
469
+ # Ensure config keys match expected arguments of SmilesIupacLitModule.__init__
470
+ # Map config keys if necessary, e.g., if config uses 'vocab_size_src' but class expects 'src_vocab_size'
471
+ model_hparams = config.copy() # Start with all config params
472
+
473
+ # Example remapping (adjust if your config/class names differ):
474
+ # model_hparams['src_vocab_size'] = model_hparams.pop('vocab_size_src', config['src_vocab_size'])
475
+ # model_hparams['tgt_vocab_size'] = model_hparams.pop('vocab_size_tgt', config['tgt_vocab_size'])
476
+ # model_hparams['bos_idx'] = model_hparams.pop('bos_token_id', config['bos_token_id'])
477
+ # model_hparams['eos_idx'] = model_hparams.pop('eos_token_id', config['eos_token_id'])
478
+ # model_hparams['padding_idx'] = model_hparams.pop('pad_token_id', config['pad_token_id'])
479
+
480
+ # Remove keys from hparams that are not expected by the LitModule's __init__
481
+ # This depends on the exact signature of SmilesIupacLitModule
482
+ # Common ones to potentially remove if not direct args: max_len (often used elsewhere)
483
+ # Check the __init__ signature of SmilesIupacLitModule in enhanced_trainer.py
484
+ expected_args = SmilesIupacLitModule.__init__.__code__.co_varnames
485
+ hparams_to_pass = {k: v for k, v in model_hparams.items() if k in expected_args}
486
+ logging.info(f"Passing hparams to LitModule: {hparams_to_pass.keys()}")
487
+
488
+
489
  model = SmilesIupacLitModule.load_from_checkpoint(
490
  checkpoint_path,
 
491
  map_location=device,
492
+ # devices=1, # Often not needed for inference loading
493
+ strict=False, # Set to False initially if encountering key errors
494
+ **hparams_to_pass # Pass relevant hparams from config
495
  )
496
 
497
  model.to(device)
498
  model.eval()
499
  model.freeze()
500
+ logging.info(f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'.")
 
 
501
 
502
  except FileNotFoundError:
503
  logging.error(f"Checkpoint file not found: {checkpoint_path}")
504
+ raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
 
 
505
  except Exception as e:
506
+ logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
 
 
507
  if "size mismatch" in str(e):
508
  error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
509
  logging.error(error_detail)
510
  raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
511
+ elif "unexpected keyword argument" in str(e) or "missing 1 required positional argument" in str(e):
512
+ error_detail = f"Mismatch between config.json keys and SmilesIupacLitModule constructor arguments. Check enhanced_trainer.py and config.json. Error: {e}"
513
+ logging.error(error_detail)
514
+ raise gr.Error(f"Model Error: {error_detail}")
515
  elif "memory" in str(e).lower():
516
  logging.warning("Potential OOM error during model loading.")
517
  gc.collect()
518
  torch.cuda.empty_cache() if device.type == "cuda" else None
519
+ raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}")
 
 
520
  else:
521
+ raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}")
 
 
522
 
523
  except gr.Error:
524
+ raise # Propagate Gradio errors directly
525
  except Exception as e:
526
  logging.error(f"Unexpected error during loading: {e}", exc_info=True)
527
+ raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}")
 
 
528
 
529
 
530
+ # --- Inference Function for Gradio ---
531
+ # @spaces.GPU # Uncomment if using GPU and have appropriate hardware tier
532
+ def predict_iupac(smiles_string, decoding_strategy, num_beams, num_return_sequences):
533
  """
534
+ Performs SMILES to IUPAC translation using the loaded model and selected strategy.
535
  """
 
 
 
 
 
536
  global model, smiles_tokenizer, iupac_tokenizer, device, config
537
 
538
+ # --- Input Validation ---
539
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
540
  error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
541
  logging.error(error_msg)
542
+ return f"Initialization Error: {error_msg}"
543
 
544
  if not smiles_string or not smiles_string.strip():
545
+ return "Error: Please enter a valid SMILES string."
 
546
 
547
  smiles_input = smiles_string.strip()
548
 
549
+ # Validate SMILES using RDKit
550
+ try:
551
+ mol = MolFromSmiles(smiles_input)
552
+ if mol is None:
553
+ return f"Error: Invalid SMILES string provided: '{smiles_input}'"
554
+ smiles_input = CanonSmiles(smiles_input) # Use canonical form
555
+ logging.info(f"Canonical SMILES: {smiles_input}")
556
+ except Exception as e:
557
+ logging.error(f"Error during SMILES validation/canonicalization: {e}", exc_info=True)
558
+ return f"Error: Could not process SMILES string '{smiles_input}'. RDKit error: {e}"
559
+
560
+ # Validate beam search parameters
561
+ if decoding_strategy == "Beam Search":
562
+ if not isinstance(num_beams, int) or num_beams <= 0:
563
+ return "Error: Beam width must be a positive integer."
564
+ if not isinstance(num_return_sequences, int) or num_return_sequences <= 0:
565
+ return "Error: Number of return sequences must be a positive integer."
566
+ if num_return_sequences > num_beams:
567
+ return f"Error: Number of return sequences ({num_return_sequences}) cannot exceed beam width ({num_beams})."
568
+ else:
569
+ # Ensure defaults are used for greedy
570
+ num_beams = 1
571
+ num_return_sequences = 1
572
+
573
+
574
  try:
575
+ # --- Call the core translation logic ---
576
  sos_idx = config["bos_token_id"]
577
  eos_idx = config["eos_token_id"]
578
  pad_idx = config["pad_token_id"]
579
  gen_max_len = config["max_len"]
580
+ # Use fixed length penalty for now, could be another slider
581
+ length_penalty = 0.6
582
 
583
+ predicted_results = translate( # Returns list of (name, score)
584
  model=model,
585
  src_sentence=smiles_input,
586
  smiles_tokenizer=smiles_tokenizer,
 
590
  sos_idx=sos_idx,
591
  eos_idx=eos_idx,
592
  pad_idx=pad_idx,
593
+ decoding_strategy=decoding_strategy,
594
+ beam_width=num_beams,
595
+ num_return_sequences=num_return_sequences,
596
+ length_penalty_alpha=length_penalty,
597
  )
598
+ logging.info(f"Prediction returned {len(predicted_results)} result(s). Strategy: {decoding_strategy}, Beams: {num_beams}, Return: {num_return_sequences}")
599
 
600
  # --- Format Output ---
601
+ output_lines = []
602
+ output_lines.append(f"Input SMILES: {smiles_input}")
603
+ output_lines.append(f"Decoding Strategy: {decoding_strategy}")
604
+ if decoding_strategy == "Beam Search":
605
+ output_lines.append(f"Beam Width: {num_beams}")
606
+ output_lines.append(f"Returned Sequences: {len(predicted_results)}")
607
+ output_lines.append(f"Length Penalty Alpha: {length_penalty:.2f}")
608
+
609
+
610
+ output_lines.append("\n--- Predictions ---")
611
+
612
+ if not predicted_results:
613
+ output_lines.append("No predictions generated.")
614
  else:
615
+ for i, (name, score) in enumerate(predicted_results):
616
+ if "[Error]" in name or not name:
617
+ output_lines.append(f"{i+1}. Prediction Failed: {name}")
618
+ else:
619
+ score_info = f"(Score: {score:.4f})" if decoding_strategy == "Beam Search" else ""
620
+ output_lines.append(f"{i+1}. {name} {score_info}")
621
+
622
+ return "\n".join(output_lines)
623
 
624
  except RuntimeError as e:
625
  logging.error(f"Runtime error during translation: {e}", exc_info=True)
626
+ gc.collect()
627
+ if device.type == 'cuda': torch.cuda.empty_cache()
628
+ return f"Runtime Error during translation: {e}. Check logs."
629
  except Exception as e:
630
  logging.error(f"Unexpected error during translation: {e}", exc_info=True)
631
+ return f"Unexpected Error during translation: {e}. Check logs."
 
632
 
633
 
634
  # --- Load Model on App Start ---
635
  try:
636
  load_model_and_tokenizers()
637
  except gr.Error as ge:
638
+ # Log the Gradio error but allow interface to load potentially showing the error message
639
+ logging.error(f"Gradio Initialization Error during load: {ge}")
640
+ # Display error in the UI if possible? Hard to do before UI is built.
641
+ # We rely on the predict function checking for loaded components.
642
  except Exception as e:
643
  logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
644
+ # This might prevent the app from starting correctly.
645
 
646
 
647
+ # --- Create Gradio Interface ---
648
+ title = "SMILES to IUPAC Name Translator"
649
  description = f"""
650
+ Translate a SMILES string into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}).
651
+ Choose between **Greedy Decoding** (fastest, picks the most likely next word) and **Beam Search Decoding** (explores multiple possibilities, potentially better results, slower).
652
+ **Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Beam search can be slow, especially with larger beam widths.
653
+ Check `config.json` in the repo for model details. SMILES input will be canonicalized using RDKit.
654
  """
655
 
656
+ # Use gr.Blocks for more layout control
657
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
 
658
  gr.Markdown(f"# {title}")
659
  gr.Markdown(description)
660
 
661
  with gr.Row():
662
+ with gr.Column(scale=1): # Input column
663
  smiles_input = gr.Textbox(
664
  label="SMILES String",
665
+ placeholder="Enter SMILES string (e.g., CCO, c1ccccc1)",
666
+ lines=2,
667
+ )
668
+ with gr.Accordion("Decoding Options", open=False): # Options collapsible
669
+ decode_strategy = gr.Radio(
670
+ ["Greedy", "Beam Search"],
671
+ label="Decoding Strategy",
672
+ value="Greedy",
673
+ info="Greedy is faster, Beam Search may be more accurate."
674
+ )
675
+ beam_width_slider = gr.Slider(
676
+ minimum=1,
677
+ maximum=20, # Keep max reasonable for performance
678
+ step=1,
679
+ value=5,
680
+ label="Beam Width",
681
+ info="Number of beams to explore (Beam Search only)",
682
+ visible=False # Initially hidden
683
+ )
684
+ num_seq_slider = gr.Slider(
685
+ minimum=1,
686
+ maximum=5, # Keep max reasonable
687
+ step=1,
688
+ value=1,
689
+ label="Number of Results",
690
+ info="How many sequences to return (Beam Search only)",
691
+ visible=False # Initially hidden
692
+ )
693
+
694
+ submit_btn = gr.Button("Translate", variant="primary")
695
+
696
+ # --- Logic to show/hide beam search options ---
697
+ def update_beam_options(strategy):
698
+ is_beam = strategy == "Beam Search"
699
+ return {
700
+ beam_width_slider: gr.update(visible=is_beam),
701
+ num_seq_slider: gr.update(visible=is_beam)
702
+ }
703
+
704
+ decode_strategy.change(
705
+ fn=update_beam_options,
706
+ inputs=decode_strategy,
707
+ outputs=[beam_width_slider, num_seq_slider]
708
  )
 
709
 
710
+
711
+ with gr.Column(scale=2): # Output column
712
  output_text = gr.Textbox(
713
+ label="Translation Results",
714
+ lines=10, # More lines for potentially multiple results
715
  show_copy_button=True,
716
  # interactive=False # Output shouldn't be user-editable
717
  )
718
 
719
+ # --- Define Event Listeners ---
 
720
  submit_btn.click(
721
  fn=predict_iupac,
722
+ inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
723
  outputs=output_text,
724
+ api_name="translate_smiles"
725
  )
726
 
727
+ # Trigger on Enter press in the SMILES box
 
 
 
 
728
  smiles_input.submit(
729
  fn=predict_iupac,
730
+ inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider],
731
  outputs=output_text
732
  )
733
 
734
+ # Add examples
735
+ gr.Examples(
736
+ examples=[
737
+ ["CCO", "Greedy", 1, 1],
738
+ ["c1ccccc1", "Greedy", 1, 1],
739
+ ["CC(C)Br", "Beam Search", 5, 3],
740
+ ["C[C@H](O)c1ccccc1", "Beam Search", 10, 5],
741
+ ["INVALID_SMILES", "Greedy", 1, 1], # Example of invalid input
742
+ ["N#CC(C)(C)OC(=O)C(C)=C", "Beam Search", 8, 2]
743
+ ],
744
+ inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], # Match inputs order
745
+ outputs=output_text, # Output component
746
+ fn=predict_iupac, # Function to run for examples
747
+ cache_examples=False, # Caching might be tricky with model state
748
+ label="Example SMILES & Settings"
749
+ )
750
+
751
+
752
  # --- Launch the App ---
 
753
  if __name__ == "__main__":
 
 
754
  iface.launch()
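
For reference, a minimal, self-contained sketch of the length-penalised ranking that the new beam_search_decode applies when scoring finished hypotheses. Only the score / length ** alpha normalisation and the 0.6 default come from the diff above; the hypothesis data and the final_score helper below are illustrative, not part of the commit.

# Hypothetical example: rank beam hypotheses the way beam_search_decode scores them.
length_penalty_alpha = 0.6  # default length_penalty_alpha used in the diff above

def final_score(log_prob: float, length: int, alpha: float) -> float:
    # Cumulative log-probability normalised by length ** alpha,
    # mirroring the penalty applied before hypotheses are kept as finished.
    penalty = length ** alpha
    return log_prob / penalty if penalty > 0 else log_prob

# Made-up (cumulative log-probability, generated length) pairs for illustration.
hypotheses = [(-4.2, 12), (-7.9, 25), (-3.1, 8)]

ranked = sorted(
    ((final_score(lp, n, length_penalty_alpha), lp, n) for lp, n in hypotheses),
    reverse=True,  # higher (less negative) normalised score ranks first
)
print(ranked)  # best-scoring hypothesis first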