AdrianM0 committed
Commit 01fa093 · verified · 1 Parent(s): fae0efa

Update app.py

Files changed (1):
  app.py  +331 -235

app.py CHANGED
@@ -1,7 +1,6 @@
 # app.py
 import gradio as gr
 import torch
-
 # import torch.nn.functional as F # No longer needed for greedy decode directly
 import pytorch_lightning as pl
 import os
@@ -10,8 +9,13 @@ import logging
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 import gc
-from rdkit.Chem import CanonSmiles
-

 # --- Configuration ---
 MODEL_REPO_ID = (
@@ -30,24 +34,24 @@ logging.basicConfig(

 # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
 try:
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
-
     logging.info("Successfully imported from enhanced_trainer.py.")
 except ImportError as e:
     logging.error(
         f"Failed to import helper code from enhanced_trainer.py: {e}. "
         f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
     )
-    raise gr.Error(
-        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
-    )
 except Exception as e:
     logging.error(
         f"An unexpected error occurred during helper code import: {e}", exc_info=True
     )
-    raise gr.Error(
-        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
-    )

 # --- Global Variables (Load Model Once) ---
 model: pl.LightningModule | None = None
@@ -69,87 +73,107 @@ def greedy_decode(
 ) -> torch.Tensor:
     """
     Performs greedy decoding using the LightningModule's model.
     """
     model.eval()  # Ensure model is in evaluation mode
     transformer_model = model.model  # Access the underlying Seq2SeqTransformer

     try:
         with torch.no_grad():
             # --- Encode Source ---
             memory = transformer_model.encode(
-                src, src_padding_mask
-            )  # [1, src_len, emb_size]
             memory = memory.to(device)
-            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]

             # --- Initialize Target Sequence ---
-            # Start with the SOS token
             ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                 sos_idx
-            )  # [1, 1]

             # --- Decoding Loop ---
-            for _ in range(max_len - 1):  # Max length limit
                 tgt_seq_len = ys.shape[1]
                 tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                     device
-                )  # [curr_len, curr_len]
-                # No padding in target during generation yet
                 tgt_padding_mask = torch.zeros(
                     ys.shape, dtype=torch.bool, device=device
-                )  # [1, curr_len]

                 # Decode one step
                 decoder_output = transformer_model.decode(
-                    tgt=ys,
-                    memory=memory,
-                    tgt_mask=tgt_mask,
-                    tgt_padding_mask=tgt_padding_mask,
-                    memory_key_padding_mask=memory_key_padding_mask,
-                )  # [1, curr_len, emb_size]

                 # Get logits for the *next* token prediction
                 next_token_logits = transformer_model.generator(
-                    decoder_output[
-                        :, -1, :
-                    ]  # Use output corresponding to the last input token
-                )  # [1, tgt_vocab_size]

                 # Find the most likely next token (greedy choice)
-                # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
-                # _, next_word_id_tensor = torch.max(prob, dim=1)
-                next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # [1]
                 next_word_id = next_word_id_tensor.item()

-                # Append the chosen token to the sequence
-                ys = torch.cat(
-                    [
-                        ys,
-                        torch.ones(1, 1, dtype=torch.long, device=device).fill_(
-                            next_word_id
-                        ),
-                    ],
-                    dim=1,
-                )  # [1, current_len + 1]

                 # Stop if EOS token is generated
                 if next_word_id == eos_idx:
                     break

             # Return the generated sequence (excluding the initial SOS token)
-            return ys[:, 1:]  # Shape [1, generated_len]

     except RuntimeError as e:
         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
         if "CUDA out of memory" in str(e) and device.type == "cuda":
             gc.collect()
             torch.cuda.empty_cache()
-        return torch.empty(
-            (1, 0), dtype=torch.long, device=device
-        )  # Return empty tensor on error
     except Exception as e:
         logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
-        return torch.empty((1, 0), dtype=torch.long, device=device)


 # --- Translation Function (Using Greedy Decode) ---
@@ -159,94 +183,108 @@ def translate(
     smiles_tokenizer: Tokenizer,
     iupac_tokenizer: Tokenizer,
     device: torch.device,
-    max_len: int,
     sos_idx: int,
     eos_idx: int,
     pad_idx: int,
-) -> str:  # Returns a single string
     """
     Translates a single SMILES string using greedy decoding.
     """
     model.eval()  # Ensure model is in eval mode

     # --- Tokenize Source ---
     try:
-        # Ensure tokenizer has truncation/padding configured if needed, or handle manually
-        smiles_tokenizer.enable_truncation(
-            max_length=max_len
-        )  # Use max_len for source truncation too
         src_encoded = smiles_tokenizer.encode(src_sentence)
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
-            return "[Encoding Error]"
-        # Use the truncated IDs directly
         src_ids = src_encoded.ids
     except Exception as e:
         logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
-        return "[Encoding Error]"

     # --- Prepare Input Tensor and Mask ---
-    src = (
-        torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
-    )  # [1, src_len]
-    # Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
-    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]

     # --- Perform Greedy Decoding ---
-    # Calls the greedy_decode function defined *above in this file*
-    # Note: max_len for generation should come from config if it dictates output length
-    generation_max_len = config.get(
-        "max_len", 256
-    )  # Use config max_len for output limit
-    tgt_tokens_tensor = greedy_decode(
-        model=model,
-        src=src,
-        src_padding_mask=src_padding_mask,
-        max_len=generation_max_len,  # Use generation limit
-        sos_idx=sos_idx,
-        eos_idx=eos_idx,
-        # pad_idx=pad_idx, # Not needed by greedy_decode internal loop
-        device=device,
-    )  # Returns a single tensor [1, generated_len]

     # --- Decode Generated Tokens ---
     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
-        logging.warning(
-            f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
-        )
-        return "[Decoding Error - Empty Output]"

     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
     try:
         # Decode using the target tokenizer, skipping special tokens
         translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
-        return translation
     except Exception as e:
         logging.error(
             f"Error decoding target tokens {tgt_tokens}: {e}",
             exc_info=True,
         )
-        return "[Decoding Error]"


-# --- Model/Tokenizer Loading Function (Unchanged) ---
 def load_model_and_tokenizers():
     """Loads tokenizers, config, and model from Hugging Face Hub."""
-    global model, smiles_tokenizer, iupac_tokenizer, device, config
     if model is not None:  # Already loaded
         logging.info("Model and tokenizers already loaded.")
         return

     logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
     try:
         # Determine device
         if torch.cuda.is_available():
-            logging.warning(
-                "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
-            )
-            device = torch.device("cuda")  # Uncomment if GPU is intended
-            # device = torch.device("cuda")
-            # logging.info("CUDA available, using GPU.")
         else:
             device = torch.device("cpu")
             logging.info("CUDA not available, using CPU.")
@@ -254,27 +292,25 @@ def load_model_and_tokenizers():
         # Download files
         logging.info("Downloading files from Hugging Face Hub...")
         try:
             cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
-            os.makedirs(cache_dir, exist_ok=True)
             logging.info(f"Using cache directory: {cache_dir}")

             checkpoint_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
             )
             smiles_tokenizer_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID,
-                filename=SMILES_TOKENIZER_FILENAME,
-                cache_dir=cache_dir,
             )
             iupac_tokenizer_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID,
-                filename=IUPAC_TOKENIZER_FILENAME,
-                cache_dir=cache_dir,
             )
             config_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
             )
-            logging.info("Files downloaded successfully.")
         except Exception as e:
             logging.error(
                 f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
@@ -292,30 +328,35 @@ def load_model_and_tokenizers():
         logging.info("Configuration loaded.")
         # --- Validate essential config keys ---
         required_keys = [
-            "src_vocab_size",  # Use the key saved in config
-            "tgt_vocab_size",  # Use the key saved in config
            "emb_size",
            "nhead",
            "ffn_hid_dim",
            "num_encoder_layers",
            "num_decoder_layers",
            "dropout",
-            "max_len",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
        ]
-        # Remap if needed (example shown, adjust if your keys differ)
-        config_key_mapping = {
-            "src_vocab_size": config.get(
-                "src_vocab_size", config.get("src_vocab_size")
-            ),
-            "tgt_vocab_size": config.get(
-                "tgt_vocab_size", config.get("tgt_vocab_size")
-            ),
-            # Add other mappings if necessary
-        }
-        config.update(config_key_mapping)

        missing_keys = [key for key in required_keys if config.get(key) is None]
        if missing_keys:
@@ -325,7 +366,7 @@ def load_model_and_tokenizers():
        )

        logging.info(
-            f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
            f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
            f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
        )
@@ -350,51 +391,75 @@ def load_model_and_tokenizers():
        try:
            smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
            iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
-            logging.info("Tokenizers loaded.")
-            # --- Optional: Validate Tokenizer Special Tokens Against Config ---
-            # (Keep validation as before, it's still useful)
            pad_token = "<pad>"
            sos_token = "<sos>"
            eos_token = "<eos>"
-            unk_token = "<unk>"
            issues = []
-            # ... (keep the validation checks as in the original code) ...
-            if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
-                issues.append(f"SMILES PAD ID mismatch")
-            if smiles_tokenizer.token_to_id(unk_token) is None:
-                issues.append("SMILES UNK token not found")
-            if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
-                issues.append(f"IUPAC PAD ID mismatch")
-            if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
-                issues.append(f"IUPAC SOS ID mismatch")
-            if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
-                issues.append(f"IUPAC EOS ID mismatch")
-            if iupac_tokenizer.token_to_id(unk_token) is None:
-                issues.append("IUPAC UNK token not found")
            if issues:
-                logging.warning("Tokenizer validation issues: " + "; ".join(issues))

        except Exception as e:
            logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
            raise gr.Error(
-                f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
            )

        # Load model
        logging.info("Loading model from checkpoint...")
        try:
-            # Use the vocab sizes and hparams from the loaded config
            model = SmilesIupacLitModule.load_from_checkpoint(
                checkpoint_path,
-                **config,  # Pass loaded config directly as keyword args
                map_location=device,
-                devices=1,
-                strict=True,  # Start strict, set to False if encountering key errors
            )

            model.to(device)
            model.eval()
-            model.freeze()
            logging.info(
                f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
            )
@@ -402,172 +467,203 @@ def load_model_and_tokenizers():
        except FileNotFoundError:
            logging.error(f"Checkpoint file not found: {checkpoint_path}")
            raise gr.Error(
-                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
            )
-        except Exception as e:
            logging.error(
-                f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
            )
-            if "size mismatch" in str(e):
-                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
-                logging.error(error_detail)
-                raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
-            elif "memory" in str(e).lower():
-                logging.warning("Potential OOM error during model loading.")
-                gc.collect()
-                torch.cuda.empty_cache() if device.type == "cuda" else None
-                raise gr.Error(
-                    f"Model Error: OOM loading model. Check Space resources. Error: {e}"
-                )
-            else:
-                raise gr.Error(
-                    f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
-                )

-    except gr.Error:
-        raise
-    except Exception as e:
-        logging.error(f"Unexpected error during loading: {e}", exc_info=True)
        raise gr.Error(
-            f"Initialization Error: Unexpected error. Check logs. Error: {e}"
        )


-# --- Inference Function for Gradio (Simplified) ---
 def predict_iupac(smiles_string):
     """
     Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
     """
-    try:
-        smiles_string = CanonSmiles(smiles_string)
-    except Exception as e:
-        logging.error(f"Error during SMILES canonicalization: {e}", exc_info=True)
-        return f"Error: {e}"
     global model, smiles_tokenizer, iupac_tokenizer, device, config

     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
-        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
         logging.error(error_msg)
-        return f"Error: {error_msg}"  # Return single error string

     if not smiles_string or not smiles_string.strip():
-        error_msg = "Error: Please enter a valid SMILES string."
-        return f"Error: {error_msg}"  # Return single error string

-    smiles_input = smiles_string.strip()

     try:
-        # --- Call the core translation logic (greedy) ---
         sos_idx = config["bos_token_id"]
         eos_idx = config["eos_token_id"]
         pad_idx = config["pad_token_id"]
-        gen_max_len = config["max_len"]

-        predicted_name = translate(  # Returns a single string now
             model=model,
-            src_sentence=smiles_input,
             smiles_tokenizer=smiles_tokenizer,
             iupac_tokenizer=iupac_tokenizer,
             device=device,
-            max_len=gen_max_len,
             sos_idx=sos_idx,
             eos_idx=eos_idx,
             pad_idx=pad_idx,
         )
-        logging.info(f"Prediction returned: {predicted_name}")

         # --- Format Output ---
-        if "[Error]" in predicted_name:  # Check for error messages from translate
             output_text = (
-                f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
             )
-        elif not predicted_name:
-            output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
         else:
             output_text = (
-                f"Input SMILES: {smiles_input}\n\n"
                 f"Predicted IUPAC Name (Greedy Decode):\n"
                 f"{predicted_name}"
             )
-        return output_text

-    except RuntimeError as e:
-        logging.error(f"Runtime error during translation: {e}", exc_info=True)
-        return f"Error: {error_msg}"  # Return single error string

     except Exception as e:
-        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
-        error_msg = f"Unexpected Error during translation: {e}"
-        return f"Error: {error_msg}"  # Return single error string


 # --- Load Model on App Start ---
 try:
     load_model_and_tokenizers()
 except gr.Error as ge:
-    logging.error(f"Gradio Initialization Error: {ge}")
-    pass  # Allow Gradio to potentially start with an error message
 except Exception as e:
-    logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
-    # Optionally raise gr.Error here too


-# --- Create Gradio Interface (Simplified) ---
 title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
 description = f"""
 Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
 Translation uses **greedy decoding** (picks the most likely next word at each step).
-**Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
 """
-
-
-
-# Replace your Interface code with this:
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
     gr.Markdown(f"# {title}")
     gr.Markdown(description)
-
     with gr.Row():
-        with gr.Column():
             smiles_input = gr.Textbox(
                 label="SMILES String",
-                placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
-                lines=1,
             )
-
-            # Add a button for manual triggering
-            submit_btn = gr.Button("Translate")
-
-        with gr.Column():
             output_text = gr.Textbox(
-                label="Predicted IUPAC Name",
-                lines=3,
                 show_copy_button=True,
             )
-
-    # Connect the events properly
-    submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
-
-    # If you still want change event (auto-translation as user types)
-    smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
-    # Output component
-    output_text = gr.Textbox(
-        label="Predicted IUPAC Name",
-        lines=3,
-        show_copy_button=True,  # Reduced lines slightly
-    )


-# Create the interface instance
-iface = gr.Interface(
-    fn=predict_iupac,  # The function to call
-    inputs=smiles_input,  # Input component directly
-    outputs=output_text,  # Output component
-    title=title,
-    description=description,
-    allow_flagging="never",
-    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
-)
 # --- Launch the App ---
 if __name__ == "__main__":
-    iface.launch()
 
 
 
 # app.py
 import gradio as gr
 import torch
 # import torch.nn.functional as F # No longer needed for greedy decode directly
 import pytorch_lightning as pl
 import os
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 import gc
+try:
+    from rdkit import Chem
+    from rdkit import RDLogger  # Optional: To suppress RDKit logs
+    RDLogger.DisableLog('rdApp.*')  # Suppress RDKit warnings/errors if desired
+except ImportError:
+    logging.warning("RDKit not found. SMILES canonicalization will be skipped. Install with 'pip install rdkit'")
+    Chem = None  # Set Chem to None if RDKit is not available

 # --- Configuration ---
 MODEL_REPO_ID = (

 # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
 try:
+    # Ensure enhanced_trainer.py is in the root directory of your space repo
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
     logging.info("Successfully imported from enhanced_trainer.py.")
 except ImportError as e:
     logging.error(
         f"Failed to import helper code from enhanced_trainer.py: {e}. "
         f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
     )
+    # We raise gr.Error later during loading if the class isn't found
+    SmilesIupacLitModule = None
+    generate_square_subsequent_mask = None
 except Exception as e:
     logging.error(
         f"An unexpected error occurred during helper code import: {e}", exc_info=True
     )
+    SmilesIupacLitModule = None
+    generate_square_subsequent_mask = None
+
 # --- Global Variables (Load Model Once) ---
 model: pl.LightningModule | None = None

 ) -> torch.Tensor:
     """
     Performs greedy decoding using the LightningModule's model.
+    Assumes model has 'model.encode', 'model.decode', 'model.generator' attributes.
     """
+    if not hasattr(model, 'model') or not hasattr(model.model, 'encode') or \
+       not hasattr(model.model, 'decode') or not hasattr(model.model, 'generator'):
+        logging.error("Model object does not have the expected 'model.encode/decode/generator' structure.")
+        raise AttributeError("Model structure mismatch for greedy decoding.")
+
+    if generate_square_subsequent_mask is None:
+        logging.error("generate_square_subsequent_mask function not imported.")
+        raise ImportError("generate_square_subsequent_mask is required for greedy_decode.")
+
     model.eval()  # Ensure model is in evaluation mode
     transformer_model = model.model  # Access the underlying Seq2SeqTransformer

     try:
         with torch.no_grad():
             # --- Encode Source ---
+            # The mask should be True where the input *is* padding.
             memory = transformer_model.encode(
+                src, src_mask=None  # Standard transformer encoder doesn't usually use src_mask
+            )  # [1, src_len, emb_size] if batch_first=True in TransformerEncoder
+            # If batch_first=False (default): [src_len, 1, emb_size] -> adjust usage below
+            # Assuming batch_first=False for standard nn.Transformer
             memory = memory.to(device)
+
+            # Memory key padding mask needs to be [batch_size, src_len] -> [1, src_len]
+            memory_key_padding_mask = src_padding_mask.to(device)  # [1, src_len]

             # --- Initialize Target Sequence ---
+            # Start with the SOS token -> Shape [1, 1] (batch, seq)
             ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                 sos_idx
+            )

             # --- Decoding Loop ---
+            for i in range(max_len - 1):  # Max length limit
+                # Target processing depends on whether decoder expects batch_first
+                # Standard nn.TransformerDecoder expects [tgt_len, batch_size, emb_size]
+                # Standard nn.TransformerDecoder expects tgt_mask [tgt_len, tgt_len]
+                # Standard nn.TransformerDecoder expects memory_key_padding_mask [batch_size, src_len]
+                # Standard nn.TransformerDecoder expects tgt_key_padding_mask [batch_size, tgt_len]
+
                 tgt_seq_len = ys.shape[1]
+                # Create causal mask -> [tgt_len, tgt_len]
                 tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                     device
+                )
+
+                # Target padding mask: False for non-pad tokens -> [1, tgt_len]
                 tgt_padding_mask = torch.zeros(
                     ys.shape, dtype=torch.bool, device=device
+                )
+
+                # Prepare target for decoder (assuming batch_first=False expected)
+                # Input ys is [1, current_len] -> need [current_len, 1]
+                ys_decoder_input = ys.transpose(0, 1).to(device)  # [current_len, 1]

                 # Decode one step
                 decoder_output = transformer_model.decode(
+                    tgt=ys_decoder_input,  # [current_len, 1]
+                    memory=memory,  # [src_len, 1, emb_size]
+                    tgt_mask=tgt_mask,  # [current_len, current_len]
+                    tgt_key_padding_mask=tgt_padding_mask,  # [1, current_len]
+                    memory_key_padding_mask=memory_key_padding_mask,  # [1, src_len]
+                )  # Output shape [current_len, 1, emb_size]

                 # Get logits for the *next* token prediction
+                # Use output corresponding to the last input token -> [-1, :, :]
                 next_token_logits = transformer_model.generator(
+                    decoder_output[-1, :, :]  # Shape [1, emb_size]
+                )  # Output shape [1, tgt_vocab_size]

                 # Find the most likely next token (greedy choice)
+                next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # Shape [1]
                 next_word_id = next_word_id_tensor.item()

+                # Append the chosen token to the sequence (shape [1, 1])
+                next_word_tensor = torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)
+
+                # Concatenate along the sequence dimension (dim=1)
+                ys = torch.cat([ys, next_word_tensor], dim=1)  # [1, current_len + 1]

                 # Stop if EOS token is generated
                 if next_word_id == eos_idx:
                     break

             # Return the generated sequence (excluding the initial SOS token)
+            # Shape [1, generated_len]
+            return ys[:, 1:]

     except RuntimeError as e:
         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
         if "CUDA out of memory" in str(e) and device.type == "cuda":
             gc.collect()
             torch.cuda.empty_cache()
+            raise RuntimeError("CUDA out of memory during greedy decoding.")  # Re-raise specific error
+        raise e  # Re-raise other runtime errors
     except Exception as e:
         logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
+        raise e  # Re-raise


 # --- Translation Function (Using Greedy Decode) ---

     smiles_tokenizer: Tokenizer,
     iupac_tokenizer: Tokenizer,
     device: torch.device,
+    max_len_config: int,  # Max length from config (used for source truncation & generation limit)
     sos_idx: int,
     eos_idx: int,
     pad_idx: int,
+) -> str:  # Returns a single string or an error message
     """
     Translates a single SMILES string using greedy decoding.
     """
+    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
+        return "[Initialization Error: Components not loaded]"
+
     model.eval()  # Ensure model is in eval mode

     # --- Tokenize Source ---
     try:
+        # Ensure tokenizer has truncation configured
+        smiles_tokenizer.enable_truncation(max_length=max_len_config)
+        smiles_tokenizer.enable_padding(pad_id=pad_idx, pad_token="<pad>", length=max_len_config)  # Ensure padding for consistent input length if needed by model
+
         src_encoded = smiles_tokenizer.encode(src_sentence)
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
+            return "[Encoding Error: Empty result]"
         src_ids = src_encoded.ids
+        # Use attention mask directly for padding mask (1 for real tokens, 0 for padding)
+        # We need the opposite for PyTorch Transformer (True for padding, False for real)
+        src_attention_mask = torch.tensor(src_encoded.attention_mask, dtype=torch.long)
+        src_padding_mask = (src_attention_mask == 0)  # True where it's padded
+
     except Exception as e:
         logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
+        return f"[Encoding Error: {e}]"

     # --- Prepare Input Tensor and Mask ---
+    # Input tensor shape [1, src_len]
+    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
+    # Padding mask shape [1, src_len]
+    src_padding_mask = src_padding_mask.unsqueeze(0).to(device)

     # --- Perform Greedy Decoding ---
+    try:
+        tgt_tokens_tensor = greedy_decode(
+            model=model,
+            src=src,
+            src_padding_mask=src_padding_mask,
+            max_len=max_len_config,  # Use config max_len as generation limit
+            sos_idx=sos_idx,
+            eos_idx=eos_idx,
+            device=device,
+        )  # Returns a single tensor [1, generated_len]
+
+    except (RuntimeError, AttributeError, ImportError, Exception) as e:
+        logging.error(f"Error during greedy_decode call: {e}", exc_info=True)
+        return f"[Decoding Error: {e}]"
+

     # --- Decode Generated Tokens ---
     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
+        # Check if the source itself was just padding or EOS
+        if len(src_ids) <= 2 and all(t in [pad_idx, eos_idx, sos_idx] for t in src_ids):  # Rough check
+            logging.warning(f"Input SMILES '{src_sentence}' resulted in very short/empty encoding, leading to empty decode.")
+            return "[Decoding Warning: Input potentially too short or invalid after tokenization]"
+        else:
+            logging.warning(
+                f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
+            )
+            return "[Decoding Error: Empty Output]"

     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
     try:
         # Decode using the target tokenizer, skipping special tokens
         translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
+        return translation.strip()  # Strip leading/trailing whitespace
     except Exception as e:
         logging.error(
             f"Error decoding target tokens {tgt_tokens}: {e}",
             exc_info=True,
         )
+        return "[Decoding Error: Tokenizer failed]"

+# --- Model/Tokenizer Loading Function ---
 def load_model_and_tokenizers():
     """Loads tokenizers, config, and model from Hugging Face Hub."""
+    global model, smiles_tokenizer, iupac_tokenizer, device, config, SmilesIupacLitModule, generate_square_subsequent_mask
     if model is not None:  # Already loaded
         logging.info("Model and tokenizers already loaded.")
         return

     logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
+
+    # --- Check if helper code loaded ---
+    if SmilesIupacLitModule is None or generate_square_subsequent_mask is None:
+        error_msg = f"Initialization Error: Could not load required components from enhanced_trainer.py. Check Space logs and ensure the file exists in the repo root."
+        logging.error(error_msg)
+        raise gr.Error(error_msg)
+
     try:
         # Determine device
         if torch.cuda.is_available():
+            device = torch.device("cuda")
+            logging.info("CUDA available, using GPU.")
         else:
             device = torch.device("cpu")
             logging.info("CUDA not available, using CPU.")

         # Download files
         logging.info("Downloading files from Hugging Face Hub...")
         try:
+            # Define cache directory, default to './hf_cache' if GRADIO_CACHE is not set
             cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
+            os.makedirs(cache_dir, exist_ok=True)  # Ensure cache dir exists
             logging.info(f"Using cache directory: {cache_dir}")

+            # Download files to the specified cache directory
             checkpoint_path = hf_hub_download(
+                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir, force_download=False  # Avoid re-download if files exist
             )
             smiles_tokenizer_path = hf_hub_download(
+                repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
             )
             iupac_tokenizer_path = hf_hub_download(
+                repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
             )
             config_path = hf_hub_download(
+                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir, force_download=False
             )
+            logging.info("Files downloaded (or found in cache) successfully.")
         except Exception as e:
             logging.error(
                 f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",

         logging.info("Configuration loaded.")
         # --- Validate essential config keys ---
         required_keys = [
+            "src_vocab_size",
+            "tgt_vocab_size",
             "emb_size",
             "nhead",
             "ffn_hid_dim",
             "num_encoder_layers",
             "num_decoder_layers",
             "dropout",
+            "max_len",  # Crucial for tokenization and generation limit
             "pad_token_id",
             "bos_token_id",
             "eos_token_id",
         ]
+
+        # Check for alternative key names if needed (adjust if your config uses different names)
+        config['src_vocab_size'] = config.get('src_vocab_size', config.get('SRC_VOCAB_SIZE'))
+        config['tgt_vocab_size'] = config.get('tgt_vocab_size', config.get('TGT_VOCAB_SIZE'))
+        config['emb_size'] = config.get('emb_size', config.get('EMB_SIZE'))
+        config['nhead'] = config.get('nhead', config.get('NHEAD'))
+        config['ffn_hid_dim'] = config.get('ffn_hid_dim', config.get('FFN_HID_DIM'))
+        config['num_encoder_layers'] = config.get('num_encoder_layers', config.get('NUM_ENCODER_LAYERS'))
+        config['num_decoder_layers'] = config.get('num_decoder_layers', config.get('NUM_DECODER_LAYERS'))
+        config['dropout'] = config.get('dropout', config.get('DROPOUT'))
+        config['max_len'] = config.get('max_len', config.get('MAX_LEN'))
+        config['pad_token_id'] = config.get('pad_token_id', config.get('PAD_IDX'))
+        config['bos_token_id'] = config.get('bos_token_id', config.get('BOS_IDX'))
+        config['eos_token_id'] = config.get('eos_token_id', config.get('EOS_IDX'))
+        # Add UNK if needed by your model/tokenizer setup
+        # config['unk_token_id'] = config.get('unk_token_id', config.get('UNK_IDX', 0)) # Default to 0 if missing? Risky.

         missing_keys = [key for key in required_keys if config.get(key) is None]
         if missing_keys:

         )

         logging.info(
+            f"Using config: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
             f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
             f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
         )

         try:
             smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
             iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
+
+            # --- Validate Tokenizer Special Tokens Against Config ---
             pad_token = "<pad>"
             sos_token = "<sos>"
             eos_token = "<eos>"
+            unk_token = "<unk>"  # Assuming standard UNK token
             issues = []
+
+            # SMILES Tokenizer Checks
+            smiles_pad_id = smiles_tokenizer.token_to_id(pad_token)
+            smiles_unk_id = smiles_tokenizer.token_to_id(unk_token)
+            if smiles_pad_id is None or smiles_pad_id != config["pad_token_id"]:
+                issues.append(f"SMILES PAD ID mismatch (Tokenizer: {smiles_pad_id}, Config: {config['pad_token_id']})")
+            if smiles_unk_id is None:
+                issues.append("SMILES UNK token not found in tokenizer")
+
+            # IUPAC Tokenizer Checks
+            iupac_pad_id = iupac_tokenizer.token_to_id(pad_token)
+            iupac_sos_id = iupac_tokenizer.token_to_id(sos_token)
+            iupac_eos_id = iupac_tokenizer.token_to_id(eos_token)
+            iupac_unk_id = iupac_tokenizer.token_to_id(unk_token)
+            if iupac_pad_id is None or iupac_pad_id != config["pad_token_id"]:
+                issues.append(f"IUPAC PAD ID mismatch (Tokenizer: {iupac_pad_id}, Config: {config['pad_token_id']})")
+            if iupac_sos_id is None or iupac_sos_id != config["bos_token_id"]:
+                issues.append(f"IUPAC SOS ID mismatch (Tokenizer: {iupac_sos_id}, Config: {config['bos_token_id']})")
+            if iupac_eos_id is None or iupac_eos_id != config["eos_token_id"]:
+                issues.append(f"IUPAC EOS ID mismatch (Tokenizer: {iupac_eos_id}, Config: {config['eos_token_id']})")
+            if iupac_unk_id is None:
+                issues.append("IUPAC UNK token not found in tokenizer")
+
             if issues:
+                logging.warning("Tokenizer validation issues detected: \n - " + "\n - ".join(issues))
+                # Decide if this is critical. For inference, SOS/EOS/PAD matches are most important.
+                # raise gr.Error("Tokenizer Validation Error: Mismatch between config and tokenizer files. Check logs.")
+            else:
+                logging.info("Tokenizers loaded and special tokens validated against config.")

         except Exception as e:
             logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
             raise gr.Error(
+                f"Tokenizer Error: Could not load tokenizers. Check logs and file paths. Error: {e}"
             )

         # Load model
         logging.info("Loading model from checkpoint...")
         try:
+            # Instantiate the LightningModule using hyperparameters from config
+            # Make sure SmilesIupacLitModule's __init__ accepts these keys
+            model_instance = SmilesIupacLitModule(**config)
+
+            # Load the state dict from the checkpoint onto the instance
+            # Use load_state_dict for more control if load_from_checkpoint causes issues
+            # state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']
+            # model_instance.load_state_dict(state_dict, strict=True) # Try strict=False if needed
+
+            # Use load_from_checkpoint (simpler if it works)
             model = SmilesIupacLitModule.load_from_checkpoint(
                 checkpoint_path,
                 map_location=device,
+                # Pass hparams again ONLY if they are needed by load_from_checkpoint specifically
+                # and not just by __init__. Usually, instantiating first is cleaner.
+                **config,  # Try removing this if you instantiate first
+                strict=True  # Start strict, set to False ONLY if necessary and you understand why
             )

+
             model.to(device)
             model.eval()
+            model.freeze()  # Freeze weights for inference
             logging.info(
                 f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
             )

         except FileNotFoundError:
             logging.error(f"Checkpoint file not found: {checkpoint_path}")
             raise gr.Error(
+                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found at expected path."
             )
+        except RuntimeError as e:  # Catch specific runtime errors like size mismatch, OOM
+            logging.error(f"Runtime error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
+            if "size mismatch" in str(e):
+                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint structure. Or config doesn't match model definition."
+                logging.error(error_detail)
+                raise gr.Error(f"Model Load Error: {error_detail} Original error: {e}")
+            elif "CUDA out of memory" in str(e) or "memory" in str(e).lower():
+                logging.warning("Potential OOM error during model loading.")
+                gc.collect()
+                if device.type == "cuda": torch.cuda.empty_cache()
+                raise gr.Error(f"Model Load Error: OOM loading model. Check Space resources. Error: {e}")
+            else:
+                raise gr.Error(f"Model Load Error: Runtime error. Check logs. Error: {e}")
+        except Exception as e:  # Catch other potential errors during loading
             logging.error(
+                f"Unexpected error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
+            )
+            raise gr.Error(
+                f"Model Load Error: Failed to load checkpoint for unknown reason. Check logs. Error: {e}"
             )

+    except gr.Error as ge:  # Catch Gradio-specific errors raised above
+        raise ge  # Re-raise to stop app launch correctly
+    except Exception as e:  # Catch any other unexpected errors during the whole process
+        logging.error(f"Unexpected error during loading process: {e}", exc_info=True)
         raise gr.Error(
+            f"Initialization Error: Unexpected failure during setup. Check logs. Error: {e}"
         )

+# --- Inference Function for Gradio ---
 def predict_iupac(smiles_string):
     """
     Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
+    Handles input validation, canonicalization, translation, and output formatting.
     """
     global model, smiles_tokenizer, iupac_tokenizer, device, config

+    # --- Check Initialization ---
     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
+        error_msg = "Error: Model or tokenizers not loaded. App initialization failed. Check Space logs."
         logging.error(error_msg)
+        # Return the error directly in the output box
+        return error_msg

+    # --- Input Validation ---
     if not smiles_string or not smiles_string.strip():
+        return "Error: Please enter a valid SMILES string."

+    smiles_input_raw = smiles_string.strip()

+    # --- Canonicalize SMILES (Optional but Recommended) ---
+    smiles_input_canon = smiles_input_raw
+    if Chem:
+        try:
+            mol = Chem.MolFromSmiles(smiles_input_raw)
+            if mol:
+                smiles_input_canon = Chem.MolToSmiles(mol, canonical=True)
+                logging.info(f"Canonicalized SMILES: {smiles_input_raw} -> {smiles_input_canon}")
+            else:
+                # RDKit couldn't parse it, proceed with raw input but warn
+                logging.warning(f"Could not parse SMILES '{smiles_input_raw}' with RDKit. Using raw input.")
+                # Optionally return an error here if strict parsing is needed
+                # return f"Error: Invalid SMILES string '{smiles_input_raw}' according to RDKit."
+        except Exception as e:
+            logging.error(f"Error during RDKit canonicalization for '{smiles_input_raw}': {e}", exc_info=True)
+            # Proceed with raw input, maybe add note to output
+            # return f"Error: RDKit processing failed: {e}" # Option to fail hard
+
+    # --- Translation ---
     try:
         sos_idx = config["bos_token_id"]
         eos_idx = config["eos_token_id"]
         pad_idx = config["pad_token_id"]
+        gen_max_len = config["max_len"]  # Use max_len from config

+        predicted_name = translate(
             model=model,
+            src_sentence=smiles_input_canon,  # Use canonicalized SMILES
             smiles_tokenizer=smiles_tokenizer,
             iupac_tokenizer=iupac_tokenizer,
             device=device,
+            max_len_config=gen_max_len,
             sos_idx=sos_idx,
             eos_idx=eos_idx,
             pad_idx=pad_idx,
         )
+        logging.info(f"SMILES: '{smiles_input_canon}', Prediction: '{predicted_name}'")

         # --- Format Output ---
+        # Check if translate returned an error message
+        if predicted_name.startswith("[") and predicted_name.endswith("]"):
+            # Assume it's an error/warning message from translate()
             output_text = (
+                f"Input SMILES: {smiles_input_canon}\n"
+                f"(Raw Input: {smiles_input_raw})\n\n"  # Show raw if canonicalization happened
+                f"Prediction Failed: {predicted_name}"
+            )
+        elif not predicted_name:  # Handle empty string case
+            output_text = (
+                f"Input SMILES: {smiles_input_canon}\n"
+                f"(Raw Input: {smiles_input_raw})\n\n"
+                f"Prediction: [No name generated]"
             )
         else:
             output_text = (
+                f"Input SMILES: {smiles_input_canon}\n"
+                f"(Raw Input: {smiles_input_raw})\n\n"
                 f"Predicted IUPAC Name (Greedy Decode):\n"
                 f"{predicted_name}"
             )
+        # Remove the "(Raw Input...)" line if canonicalization didn't change the input
+        if smiles_input_raw == smiles_input_canon:
+            output_text = output_text.replace(f"(Raw Input: {smiles_input_raw})\n", "")

+        return output_text.strip()

     except Exception as e:
+        # Catch-all for unexpected errors during the prediction process
+        logging.error(f"Unexpected error during prediction for '{smiles_input_canon}': {e}", exc_info=True)
+        error_msg = f"Error: An unexpected error occurred during translation: {e}"
+        return error_msg

 # --- Load Model on App Start ---
+# Wrap in try/except to allow Gradio UI to potentially display an error
+model_load_error = None
 try:
     load_model_and_tokenizers()
 except gr.Error as ge:
+    logging.error(f"Gradio Initialization Error during load: {ge}")
+    model_load_error = str(ge)  # Store error message
 except Exception as e:
+    logging.critical(f"CRITICAL error during initial model loading: {e}", exc_info=True)
+    model_load_error = f"Critical Error: {e}. Check Space logs."


+# --- Create Gradio Interface ---
 title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
 description = f"""
 Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
 Translation uses **greedy decoding** (picks the most likely next word at each step).
 """
+if model_load_error:
+    description += f"\n\n**WARNING: Failed to load model or components.**\nReason: {model_load_error}\nFunctionality will be limited."
+elif device:
+    description += f"\n**Note:** Model loaded on **{str(device).upper()}**. Check `config.json` in the repo for model details."
+else:
+    description += f"\n**Note:** Device information unavailable (loading might have failed)."
+
+# Use gr.Blocks for layout
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
     gr.Markdown(f"# {title}")
     gr.Markdown(description)
+
     with gr.Row():
+        with gr.Column(scale=1):  # Input column takes less space
             smiles_input = gr.Textbox(
                 label="SMILES String",
+                placeholder="Enter SMILES string (e.g., CCO or c1ccccc1)",
+                lines=2,  # Slightly more lines for longer SMILES
             )
+            submit_btn = gr.Button("Translate", variant="primary")
+
+        with gr.Column(scale=2):  # Output column takes more space
             output_text = gr.Textbox(
+                label="Result",
+                lines=5,  # More lines for formatted output
                 show_copy_button=True,
+                interactive=False,  # Output box is not for user input
             )
+
+    # Define examples
+    gr.Examples(
+        examples=[
+            "CCO",
+            "C1=CC=C(C=C1)C(=O)O",  # Benzoic acid
+            "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # Ibuprofen
+            "INVALID_SMILES",
+            "ClC(Cl)(Cl)C1=CC=C(C=C1)C(C2=CC=C(Cl)C=C2)C(Cl)(Cl)Cl",  # DDT
+        ],
+        inputs=smiles_input,
+        outputs=output_text,
+        fn=predict_iupac,  # Function to run for examples
+        cache_examples=False,  # Re-run examples each time if needed, True might speed up demo loading
+    )
+
+    # Connect the button click and input change events
+    submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text, api_name="translate_smiles")
+    # Optionally trigger on text change (can be slow/resource intensive)
+    # smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)


 # --- Launch the App ---
 if __name__ == "__main__":
+    # Set share=True to get a public link (useful for testing)
+    # Set debug=True for more detailed Gradio errors during development
+    iface.launch(share=False, debug=False)
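
A minimal sketch (not part of this commit) of how the event registered above with api_name="translate_smiles" could be queried once the Space is running, using the gradio_client package; the Space id below is a placeholder and must be replaced with the actual <user>/<space> hosting this app.py:

from gradio_client import Client

# Placeholder Space id -- replace with the real repo id of the running Space.
client = Client("AdrianM0/smiles-to-iupac")

# Calls predict_iupac() through the named API endpoint and prints the formatted result string.
result = client.predict("CCO", api_name="/translate_smiles")
print(result)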