AdrianM0 committed
Commit b3f7b39 · verified · 1 Parent(s): 01fa093

Update app.py

Files changed (1)
  app.py +235 -331
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # app.py
2
  import gradio as gr
3
  import torch
 
4
  # import torch.nn.functional as F # No longer needed for greedy decode directly
5
  import pytorch_lightning as pl
6
  import os
@@ -9,13 +10,8 @@ import logging
9
  from tokenizers import Tokenizer
10
  from huggingface_hub import hf_hub_download
11
  import gc
12
- try:
13
- from rdkit import Chem
14
- from rdkit import RDLogger # Optional: To suppress RDKit logs
15
- RDLogger.DisableLog('rdApp.*') # Suppress RDKit warnings/errors if desired
16
- except ImportError:
17
- logging.warning("RDKit not found. SMILES canonicalization will be skipped. Install with 'pip install rdkit'")
18
- Chem = None # Set Chem to None if RDKit is not available
19
 
20
  # --- Configuration ---
21
  MODEL_REPO_ID = (
@@ -34,24 +30,24 @@ logging.basicConfig(
34
 
35
  # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
36
  try:
37
- # Ensure enhanced_trainer.py is in the root directory of your space repo
38
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
 
39
  logging.info("Successfully imported from enhanced_trainer.py.")
40
  except ImportError as e:
41
  logging.error(
42
  f"Failed to import helper code from enhanced_trainer.py: {e}. "
43
  f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
44
  )
45
- # We raise gr.Error later during loading if the class isn't found
46
- SmilesIupacLitModule = None
47
- generate_square_subsequent_mask = None
48
  except Exception as e:
49
  logging.error(
50
  f"An unexpected error occurred during helper code import: {e}", exc_info=True
51
  )
52
- SmilesIupacLitModule = None
53
- generate_square_subsequent_mask = None
54
-
55
 
56
  # --- Global Variables (Load Model Once) ---
57
  model: pl.LightningModule | None = None
@@ -73,107 +69,87 @@ def greedy_decode(
73
  ) -> torch.Tensor:
74
  """
75
  Performs greedy decoding using the LightningModule's model.
76
- Assumes model has 'model.encode', 'model.decode', 'model.generator' attributes.
77
  """
78
- if not hasattr(model, 'model') or not hasattr(model.model, 'encode') or \
79
- not hasattr(model.model, 'decode') or not hasattr(model.model, 'generator'):
80
- logging.error("Model object does not have the expected 'model.encode/decode/generator' structure.")
81
- raise AttributeError("Model structure mismatch for greedy decoding.")
82
-
83
- if generate_square_subsequent_mask is None:
84
- logging.error("generate_square_subsequent_mask function not imported.")
85
- raise ImportError("generate_square_subsequent_mask is required for greedy_decode.")
86
-
87
-
88
  model.eval() # Ensure model is in evaluation mode
89
  transformer_model = model.model # Access the underlying Seq2SeqTransformer
90
 
91
  try:
92
  with torch.no_grad():
93
  # --- Encode Source ---
94
- # The mask should be True where the input *is* padding.
95
  memory = transformer_model.encode(
96
- src, src_mask=None # Standard transformer encoder doesn't usually use src_mask
97
- ) # [1, src_len, emb_size] if batch_first=True in TransformerEncoder
98
- # If batch_first=False (default): [src_len, 1, emb_size] -> adjust usage below
99
- # Assuming batch_first=False for standard nn.Transformer
100
  memory = memory.to(device)
101
-
102
- # Memory key padding mask needs to be [batch_size, src_len] -> [1, src_len]
103
- memory_key_padding_mask = src_padding_mask.to(device) # [1, src_len]
104
 
105
  # --- Initialize Target Sequence ---
106
- # Start with the SOS token -> Shape [1, 1] (batch, seq)
107
  ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
108
  sos_idx
109
- )
110
 
111
  # --- Decoding Loop ---
112
- for i in range(max_len - 1): # Max length limit
113
- # Target processing depends on whether decoder expects batch_first
114
- # Standard nn.TransformerDecoder expects [tgt_len, batch_size, emb_size]
115
- # Standard nn.TransformerDecoder expects tgt_mask [tgt_len, tgt_len]
116
- # Standard nn.TransformerDecoder expects memory_key_padding_mask [batch_size, src_len]
117
- # Standard nn.TransformerDecoder expects tgt_key_padding_mask [batch_size, tgt_len]
118
-
119
  tgt_seq_len = ys.shape[1]
120
- # Create causal mask -> [tgt_len, tgt_len]
121
  tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
122
  device
123
- )
124
-
125
- # Target padding mask: False for non-pad tokens -> [1, tgt_len]
126
  tgt_padding_mask = torch.zeros(
127
  ys.shape, dtype=torch.bool, device=device
128
- )
129
-
130
- # Prepare target for decoder (assuming batch_first=False expected)
131
- # Input ys is [1, current_len] -> need [current_len, 1]
132
- ys_decoder_input = ys.transpose(0, 1).to(device) # [current_len, 1]
133
 
134
  # Decode one step
135
  decoder_output = transformer_model.decode(
136
- tgt=ys_decoder_input, # [current_len, 1]
137
- memory=memory, # [src_len, 1, emb_size]
138
- tgt_mask=tgt_mask, # [current_len, current_len]
139
- tgt_key_padding_mask=tgt_padding_mask, # [1, current_len]
140
- memory_key_padding_mask=memory_key_padding_mask, # [1, src_len]
141
- ) # Output shape [current_len, 1, emb_size]
142
 
143
  # Get logits for the *next* token prediction
144
- # Use output corresponding to the last input token -> [-1, :, :]
145
  next_token_logits = transformer_model.generator(
146
- decoder_output[-1, :, :] # Shape [1, emb_size]
147
- ) # Output shape [1, tgt_vocab_size]
148
 
149
  # Find the most likely next token (greedy choice)
150
- next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # Shape [1]
151
  next_word_id = next_word_id_tensor.item()
152
 
153
- # Append the chosen token to the sequence (shape [1, 1])
154
- next_word_tensor = torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)
155
-
156
- # Concatenate along the sequence dimension (dim=1)
157
- ys = torch.cat([ys, next_word_tensor], dim=1) # [1, current_len + 1]
158
 
159
  # Stop if EOS token is generated
160
  if next_word_id == eos_idx:
161
  break
162
 
163
  # Return the generated sequence (excluding the initial SOS token)
164
- # Shape [1, generated_len]
165
- return ys[:, 1:]
166
 
167
  except RuntimeError as e:
168
  logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
169
  if "CUDA out of memory" in str(e) and device.type == "cuda":
170
  gc.collect()
171
  torch.cuda.empty_cache()
172
- raise RuntimeError("CUDA out of memory during greedy decoding.") # Re-raise specific error
173
- raise e # Re-raise other runtime errors
 
174
  except Exception as e:
175
  logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
176
- raise e # Re-raise
177
 
178
 
179
  # --- Translation Function (Using Greedy Decode) ---
@@ -183,108 +159,94 @@ def translate(
183
  smiles_tokenizer: Tokenizer,
184
  iupac_tokenizer: Tokenizer,
185
  device: torch.device,
186
- max_len_config: int, # Max length from config (used for source truncation & generation limit)
187
  sos_idx: int,
188
  eos_idx: int,
189
  pad_idx: int,
190
- ) -> str: # Returns a single string or an error message
191
  """
192
  Translates a single SMILES string using greedy decoding.
193
  """
194
- if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
195
- return "[Initialization Error: Components not loaded]"
196
-
197
  model.eval() # Ensure model is in eval mode
198
 
199
  # --- Tokenize Source ---
200
  try:
201
- # Ensure tokenizer has truncation configured
202
- smiles_tokenizer.enable_truncation(max_length=max_len_config)
203
- smiles_tokenizer.enable_padding(pad_id=pad_idx, pad_token="<pad>", length=max_len_config) # Ensure padding for consistent input length if needed by model
204
-
205
  src_encoded = smiles_tokenizer.encode(src_sentence)
206
  if not src_encoded or not src_encoded.ids:
207
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
208
- return "[Encoding Error: Empty result]"
 
209
  src_ids = src_encoded.ids
210
- # Use attention mask directly for padding mask (1 for real tokens, 0 for padding)
211
- # We need the opposite for PyTorch Transformer (True for padding, False for real)
212
- src_attention_mask = torch.tensor(src_encoded.attention_mask, dtype=torch.long)
213
- src_padding_mask = (src_attention_mask == 0) # True where it's padded
214
-
215
  except Exception as e:
216
  logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
217
- return f"[Encoding Error: {e}]"
218
 
219
  # --- Prepare Input Tensor and Mask ---
220
- # Input tensor shape [1, src_len]
221
- src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
222
- # Padding mask shape [1, src_len]
223
- src_padding_mask = src_padding_mask.unsqueeze(0).to(device)
 
224
 
225
  # --- Perform Greedy Decoding ---
226
- try:
227
- tgt_tokens_tensor = greedy_decode(
228
- model=model,
229
- src=src,
230
- src_padding_mask=src_padding_mask,
231
- max_len=max_len_config, # Use config max_len as generation limit
232
- sos_idx=sos_idx,
233
- eos_idx=eos_idx,
234
- device=device,
235
- ) # Returns a single tensor [1, generated_len]
236
-
237
- except (RuntimeError, AttributeError, ImportError, Exception) as e:
238
- logging.error(f"Error during greedy_decode call: {e}", exc_info=True)
239
- return f"[Decoding Error: {e}]"
240
-
241
 
242
  # --- Decode Generated Tokens ---
243
  if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
244
- # Check if the source itself was just padding or EOS
245
- if len(src_ids) <= 2 and all(t in [pad_idx, eos_idx, sos_idx] for t in src_ids): # Rough check
246
- logging.warning(f"Input SMILES '{src_sentence}' resulted in very short/empty encoding, leading to empty decode.")
247
- return "[Decoding Warning: Input potentially too short or invalid after tokenization]"
248
- else:
249
- logging.warning(
250
- f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
251
- )
252
- return "[Decoding Error: Empty Output]"
253
 
254
  tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
255
  try:
256
  # Decode using the target tokenizer, skipping special tokens
257
  translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
258
- return translation.strip() # Strip leading/trailing whitespace
259
  except Exception as e:
260
  logging.error(
261
  f"Error decoding target tokens {tgt_tokens}: {e}",
262
  exc_info=True,
263
  )
264
- return "[Decoding Error: Tokenizer failed]"
265
 
266
 
267
- # --- Model/Tokenizer Loading Function ---
268
  def load_model_and_tokenizers():
269
  """Loads tokenizers, config, and model from Hugging Face Hub."""
270
- global model, smiles_tokenizer, iupac_tokenizer, device, config, SmilesIupacLitModule, generate_square_subsequent_mask
271
  if model is not None: # Already loaded
272
  logging.info("Model and tokenizers already loaded.")
273
  return
274
 
275
  logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
276
-
277
- # --- Check if helper code loaded ---
278
- if SmilesIupacLitModule is None or generate_square_subsequent_mask is None:
279
- error_msg = f"Initialization Error: Could not load required components from enhanced_trainer.py. Check Space logs and ensure the file exists in the repo root."
280
- logging.error(error_msg)
281
- raise gr.Error(error_msg)
282
-
283
  try:
284
  # Determine device
285
  if torch.cuda.is_available():
286
- device = torch.device("cuda")
287
- logging.info("CUDA available, using GPU.")
 
 
 
 
288
  else:
289
  device = torch.device("cpu")
290
  logging.info("CUDA not available, using CPU.")
@@ -292,25 +254,27 @@ def load_model_and_tokenizers():
292
  # Download files
293
  logging.info("Downloading files from Hugging Face Hub...")
294
  try:
295
- # Define cache directory, default to './hf_cache' if GRADIO_CACHE is not set
296
  cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
297
- os.makedirs(cache_dir, exist_ok=True) # Ensure cache dir exists
298
  logging.info(f"Using cache directory: {cache_dir}")
299
 
300
- # Download files to the specified cache directory
301
  checkpoint_path = hf_hub_download(
302
- repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir, force_download=False # Avoid re-download if files exist
303
  )
304
  smiles_tokenizer_path = hf_hub_download(
305
- repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
306
  )
307
  iupac_tokenizer_path = hf_hub_download(
308
- repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
309
  )
310
  config_path = hf_hub_download(
311
- repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir, force_download=False
312
  )
313
- logging.info("Files downloaded (or found in cache) successfully.")
314
  except Exception as e:
315
  logging.error(
316
  f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
@@ -328,35 +292,30 @@ def load_model_and_tokenizers():
328
  logging.info("Configuration loaded.")
329
  # --- Validate essential config keys ---
330
  required_keys = [
331
- "src_vocab_size",
332
- "tgt_vocab_size",
333
  "emb_size",
334
  "nhead",
335
  "ffn_hid_dim",
336
  "num_encoder_layers",
337
  "num_decoder_layers",
338
  "dropout",
339
- "max_len", # Crucial for tokenization and generation limit
340
  "pad_token_id",
341
  "bos_token_id",
342
  "eos_token_id",
343
  ]
344
-
345
- # Check for alternative key names if needed (adjust if your config uses different names)
346
- config['src_vocab_size'] = config.get('src_vocab_size', config.get('SRC_VOCAB_SIZE'))
347
- config['tgt_vocab_size'] = config.get('tgt_vocab_size', config.get('TGT_VOCAB_SIZE'))
348
- config['emb_size'] = config.get('emb_size', config.get('EMB_SIZE'))
349
- config['nhead'] = config.get('nhead', config.get('NHEAD'))
350
- config['ffn_hid_dim'] = config.get('ffn_hid_dim', config.get('FFN_HID_DIM'))
351
- config['num_encoder_layers'] = config.get('num_encoder_layers', config.get('NUM_ENCODER_LAYERS'))
352
- config['num_decoder_layers'] = config.get('num_decoder_layers', config.get('NUM_DECODER_LAYERS'))
353
- config['dropout'] = config.get('dropout', config.get('DROPOUT'))
354
- config['max_len'] = config.get('max_len', config.get('MAX_LEN'))
355
- config['pad_token_id'] = config.get('pad_token_id', config.get('PAD_IDX'))
356
- config['bos_token_id'] = config.get('bos_token_id', config.get('BOS_IDX'))
357
- config['eos_token_id'] = config.get('eos_token_id', config.get('EOS_IDX'))
358
- # Add UNK if needed by your model/tokenizer setup
359
- # config['unk_token_id'] = config.get('unk_token_id', config.get('UNK_IDX', 0)) # Default to 0 if missing? Risky.
360
 
361
  missing_keys = [key for key in required_keys if config.get(key) is None]
362
  if missing_keys:
@@ -366,7 +325,7 @@ def load_model_and_tokenizers():
366
  )
367
 
368
  logging.info(
369
- f"Using config: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
370
  f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
371
  f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
372
  )
@@ -391,75 +350,51 @@ def load_model_and_tokenizers():
391
  try:
392
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
393
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
394
-
395
- # --- Validate Tokenizer Special Tokens Against Config ---
 
396
  pad_token = "<pad>"
397
  sos_token = "<sos>"
398
  eos_token = "<eos>"
399
- unk_token = "<unk>" # Assuming standard UNK token
400
  issues = []
401
-
402
- # SMILES Tokenizer Checks
403
- smiles_pad_id = smiles_tokenizer.token_to_id(pad_token)
404
- smiles_unk_id = smiles_tokenizer.token_to_id(unk_token)
405
- if smiles_pad_id is None or smiles_pad_id != config["pad_token_id"]:
406
- issues.append(f"SMILES PAD ID mismatch (Tokenizer: {smiles_pad_id}, Config: {config['pad_token_id']})")
407
- if smiles_unk_id is None:
408
- issues.append("SMILES UNK token not found in tokenizer")
409
-
410
- # IUPAC Tokenizer Checks
411
- iupac_pad_id = iupac_tokenizer.token_to_id(pad_token)
412
- iupac_sos_id = iupac_tokenizer.token_to_id(sos_token)
413
- iupac_eos_id = iupac_tokenizer.token_to_id(eos_token)
414
- iupac_unk_id = iupac_tokenizer.token_to_id(unk_token)
415
- if iupac_pad_id is None or iupac_pad_id != config["pad_token_id"]:
416
- issues.append(f"IUPAC PAD ID mismatch (Tokenizer: {iupac_pad_id}, Config: {config['pad_token_id']})")
417
- if iupac_sos_id is None or iupac_sos_id != config["bos_token_id"]:
418
- issues.append(f"IUPAC SOS ID mismatch (Tokenizer: {iupac_sos_id}, Config: {config['bos_token_id']})")
419
- if iupac_eos_id is None or iupac_eos_id != config["eos_token_id"]:
420
- issues.append(f"IUPAC EOS ID mismatch (Tokenizer: {iupac_eos_id}, Config: {config['eos_token_id']})")
421
- if iupac_unk_id is None:
422
- issues.append("IUPAC UNK token not found in tokenizer")
423
-
424
  if issues:
425
- logging.warning("Tokenizer validation issues detected: \n - " + "\n - ".join(issues))
426
- # Decide if this is critical. For inference, SOS/EOS/PAD matches are most important.
427
- # raise gr.Error("Tokenizer Validation Error: Mismatch between config and tokenizer files. Check logs.")
428
- else:
429
- logging.info("Tokenizers loaded and special tokens validated against config.")
430
 
431
  except Exception as e:
432
  logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
433
  raise gr.Error(
434
- f"Tokenizer Error: Could not load tokenizers. Check logs and file paths. Error: {e}"
435
  )
436
 
437
  # Load model
438
  logging.info("Loading model from checkpoint...")
439
  try:
440
- # Instantiate the LightningModule using hyperparameters from config
441
- # Make sure SmilesIupacLitModule's __init__ accepts these keys
442
- model_instance = SmilesIupacLitModule(**config)
443
-
444
- # Load the state dict from the checkpoint onto the instance
445
- # Use load_state_dict for more control if load_from_checkpoint causes issues
446
- # state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']
447
- # model_instance.load_state_dict(state_dict, strict=True) # Try strict=False if needed
448
-
449
- # Use load_from_checkpoint (simpler if it works)
450
  model = SmilesIupacLitModule.load_from_checkpoint(
451
  checkpoint_path,
 
452
  map_location=device,
453
- # Pass hparams again ONLY if they are needed by load_from_checkpoint specifically
454
- # and not just by __init__. Usually, instantiating first is cleaner.
455
- **config, # Try removing this if you instantiate first
456
- strict=True # Start strict, set to False ONLY if necessary and you understand why
457
  )
458
 
459
-
460
  model.to(device)
461
  model.eval()
462
- model.freeze() # Freeze weights for inference
463
  logging.info(
464
  f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
465
  )
@@ -467,203 +402,172 @@ def load_model_and_tokenizers():
467
  except FileNotFoundError:
468
  logging.error(f"Checkpoint file not found: {checkpoint_path}")
469
  raise gr.Error(
470
- f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found at expected path."
471
  )
472
- except RuntimeError as e: # Catch specific runtime errors like size mismatch, OOM
473
- logging.error(f"Runtime error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
474
- if "size mismatch" in str(e):
475
- error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint structure. Or config doesn't match model definition."
476
- logging.error(error_detail)
477
- raise gr.Error(f"Model Load Error: {error_detail} Original error: {e}")
478
- elif "CUDA out of memory" in str(e) or "memory" in str(e).lower():
479
- logging.warning("Potential OOM error during model loading.")
480
- gc.collect()
481
- if device.type == "cuda": torch.cuda.empty_cache()
482
- raise gr.Error(f"Model Load Error: OOM loading model. Check Space resources. Error: {e}")
483
- else:
484
- raise gr.Error(f"Model Load Error: Runtime error. Check logs. Error: {e}")
485
- except Exception as e: # Catch other potential errors during loading
486
  logging.error(
487
- f"Unexpected error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
488
- )
489
- raise gr.Error(
490
- f"Model Load Error: Failed to load checkpoint for unknown reason. Check logs. Error: {e}"
491
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
- except gr.Error as ge: # Catch Gradio-specific errors raised above
494
- raise ge # Re-raise to stop app launch correctly
495
- except Exception as e: # Catch any other unexpected errors during the whole process
496
- logging.error(f"Unexpected error during loading process: {e}", exc_info=True)
497
  raise gr.Error(
498
- f"Initialization Error: Unexpected failure during setup. Check logs. Error: {e}"
499
  )
500
 
501
 
502
- # --- Inference Function for Gradio ---
503
  def predict_iupac(smiles_string):
504
  """
505
  Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
506
- Handles input validation, canonicalization, translation, and output formatting.
507
  """
508
  global model, smiles_tokenizer, iupac_tokenizer, device, config
509
 
510
- # --- Check Initialization ---
511
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
512
- error_msg = "Error: Model or tokenizers not loaded. App initialization failed. Check Space logs."
513
  logging.error(error_msg)
514
- # Return the error directly in the output box
515
- return error_msg
516
 
517
- # --- Input Validation ---
518
  if not smiles_string or not smiles_string.strip():
519
- return "Error: Please enter a valid SMILES string."
 
520
 
521
- smiles_input_raw = smiles_string.strip()
522
 
523
- # --- Canonicalize SMILES (Optional but Recommended) ---
524
- smiles_input_canon = smiles_input_raw
525
- if Chem:
526
- try:
527
- mol = Chem.MolFromSmiles(smiles_input_raw)
528
- if mol:
529
- smiles_input_canon = Chem.MolToSmiles(mol, canonical=True)
530
- logging.info(f"Canonicalized SMILES: {smiles_input_raw} -> {smiles_input_canon}")
531
- else:
532
- # RDKit couldn't parse it, proceed with raw input but warn
533
- logging.warning(f"Could not parse SMILES '{smiles_input_raw}' with RDKit. Using raw input.")
534
- # Optionally return an error here if strict parsing is needed
535
- # return f"Error: Invalid SMILES string '{smiles_input_raw}' according to RDKit."
536
- except Exception as e:
537
- logging.error(f"Error during RDKit canonicalization for '{smiles_input_raw}': {e}", exc_info=True)
538
- # Proceed with raw input, maybe add note to output
539
- # return f"Error: RDKit processing failed: {e}" # Option to fail hard
540
-
541
- # --- Translation ---
542
  try:
 
543
  sos_idx = config["bos_token_id"]
544
  eos_idx = config["eos_token_id"]
545
  pad_idx = config["pad_token_id"]
546
- gen_max_len = config["max_len"] # Use max_len from config
547
 
548
- predicted_name = translate(
549
  model=model,
550
- src_sentence=smiles_input_canon, # Use canonicalized SMILES
551
  smiles_tokenizer=smiles_tokenizer,
552
  iupac_tokenizer=iupac_tokenizer,
553
  device=device,
554
- max_len_config=gen_max_len,
555
  sos_idx=sos_idx,
556
  eos_idx=eos_idx,
557
  pad_idx=pad_idx,
558
  )
559
- logging.info(f"SMILES: '{smiles_input_canon}', Prediction: '{predicted_name}'")
560
 
561
  # --- Format Output ---
562
- # Check if translate returned an error message
563
- if predicted_name.startswith("[") and predicted_name.endswith("]"):
564
- # Assume it's an error/warning message from translate()
565
  output_text = (
566
- f"Input SMILES: {smiles_input_canon}\n"
567
- f"(Raw Input: {smiles_input_raw})\n\n" # Show raw if canonicalization happened
568
- f"Prediction Failed: {predicted_name}"
569
- )
570
- elif not predicted_name: # Handle empty string case
571
- output_text = (
572
- f"Input SMILES: {smiles_input_canon}\n"
573
- f"(Raw Input: {smiles_input_raw})\n\n"
574
- f"Prediction: [No name generated]"
575
  )
 
 
576
  else:
577
  output_text = (
578
- f"Input SMILES: {smiles_input_canon}\n"
579
- f"(Raw Input: {smiles_input_raw})\n\n"
580
  f"Predicted IUPAC Name (Greedy Decode):\n"
581
  f"{predicted_name}"
582
  )
583
- # Remove the "(Raw Input...)" line if canonicalization didn't change the input
584
- if smiles_input_raw == smiles_input_canon:
585
- output_text = output_text.replace(f"(Raw Input: {smiles_input_raw})\n", "")
586
 
587
- return output_text.strip()
588
 
589
  except Exception as e:
590
- # Catch-all for unexpected errors during the prediction process
591
- logging.error(f"Unexpected error during prediction for '{smiles_input_canon}': {e}", exc_info=True)
592
- error_msg = f"Error: An unexpected error occurred during translation: {e}"
593
- return error_msg
594
 
595
 
596
  # --- Load Model on App Start ---
597
- # Wrap in try/except to allow Gradio UI to potentially display an error
598
- model_load_error = None
599
  try:
600
  load_model_and_tokenizers()
601
  except gr.Error as ge:
602
- logging.error(f"Gradio Initialization Error during load: {ge}")
603
- model_load_error = str(ge) # Store error message
604
  except Exception as e:
605
- logging.critical(f"CRITICAL error during initial model loading: {e}", exc_info=True)
606
- model_load_error = f"Critical Error: {e}. Check Space logs."
607
 
608
 
609
- # --- Create Gradio Interface ---
610
  title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
611
  description = f"""
612
  Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
613
  Translation uses **greedy decoding** (picks the most likely next word at each step).
 
614
  """
615
- if model_load_error:
616
- description += f"\n\n**WARNING: Failed to load model or components.**\nReason: {model_load_error}\nFunctionality will be limited."
617
- elif device:
618
- description += f"\n**Note:** Model loaded on **{str(device).upper()}**. Check `config.json` in the repo for model details."
619
- else:
620
- description += f"\n**Note:** Device information unavailable (loading might have failed)."
621
-
622
- # Use gr.Blocks for layout
623
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
624
  gr.Markdown(f"# {title}")
625
  gr.Markdown(description)
626
-
627
  with gr.Row():
628
- with gr.Column(scale=1): # Input column takes less space
629
  smiles_input = gr.Textbox(
630
  label="SMILES String",
631
- placeholder="Enter SMILES string (e.g., CCO or c1ccccc1)",
632
- lines=2, # Slightly more lines for longer SMILES
633
  )
634
- submit_btn = gr.Button("Translate", variant="primary")
635
-
636
- with gr.Column(scale=2): # Output column takes more space
 
 
637
  output_text = gr.Textbox(
638
- label="Result",
639
- lines=5, # More lines for formatted output
640
  show_copy_button=True,
641
- interactive=False, # Output box is not for user input
642
  )
643
-
644
- # Define examples
645
- gr.Examples(
646
- examples=[
647
- "CCO",
648
- "C1=CC=C(C=C1)C(=O)O", # Benzoic acid
649
- "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # Ibuprofen
650
- "INVALID_SMILES",
651
- "ClC(Cl)(Cl)C1=CC=C(C=C1)C(C2=CC=C(Cl)C=C2)C(Cl)(Cl)Cl", # DDT
652
- ],
653
- inputs=smiles_input,
654
- outputs=output_text,
655
- fn=predict_iupac, # Function to run for examples
656
- cache_examples=False, # Re-run examples each time if needed, True might speed up demo loading
657
- )
658
-
659
- # Connect the button click and input change events
660
- submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text, api_name="translate_smiles")
661
- # Optionally trigger on text change (can be slow/resource intensive)
662
- # smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
663
 
664
 
 
 
665
  # --- Launch the App ---
666
  if __name__ == "__main__":
667
- # Set share=True to get a public link (useful for testing)
668
- # Set debug=True for more detailed Gradio errors during development
669
- iface.launch(share=False, debug=False)
 
1
  # app.py
2
  import gradio as gr
3
  import torch
4
+
5
  # import torch.nn.functional as F # No longer needed for greedy decode directly
6
  import pytorch_lightning as pl
7
  import os
 
10
  from tokenizers import Tokenizer
11
  from huggingface_hub import hf_hub_download
12
  import gc
13
+ from rdkit.Chem import CanonSmiles
14
+
15
 
16
  # --- Configuration ---
17
  MODEL_REPO_ID = (
 
30
 
31
  # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
32
  try:
 
33
  from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
34
+
35
  logging.info("Successfully imported from enhanced_trainer.py.")
36
  except ImportError as e:
37
  logging.error(
38
  f"Failed to import helper code from enhanced_trainer.py: {e}. "
39
  f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
40
  )
41
+ raise gr.Error(
42
+ f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
43
+ )
44
  except Exception as e:
45
  logging.error(
46
  f"An unexpected error occurred during helper code import: {e}", exc_info=True
47
  )
48
+ raise gr.Error(
49
+ f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
50
+ )
51
 
52
  # --- Global Variables (Load Model Once) ---
53
  model: pl.LightningModule | None = None
 
69
  ) -> torch.Tensor:
70
  """
71
  Performs greedy decoding using the LightningModule's model.
72
  """
 
73
  model.eval() # Ensure model is in evaluation mode
74
  transformer_model = model.model # Access the underlying Seq2SeqTransformer
75
 
76
  try:
77
  with torch.no_grad():
78
  # --- Encode Source ---
 
79
  memory = transformer_model.encode(
80
+ src, src_padding_mask
81
+ ) # [1, src_len, emb_size]
82
  memory = memory.to(device)
83
+ memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
84
 
85
  # --- Initialize Target Sequence ---
86
+ # Start with the SOS token
87
  ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
88
  sos_idx
89
+ ) # [1, 1]
90
 
91
  # --- Decoding Loop ---
92
+ for _ in range(max_len - 1): # Max length limit
 
93
  tgt_seq_len = ys.shape[1]
 
94
  tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
95
  device
96
+ ) # [curr_len, curr_len]
97
+ # No padding in target during generation yet
 
98
  tgt_padding_mask = torch.zeros(
99
  ys.shape, dtype=torch.bool, device=device
100
+ ) # [1, curr_len]
 
101
 
102
  # Decode one step
103
  decoder_output = transformer_model.decode(
104
+ tgt=ys,
105
+ memory=memory,
106
+ tgt_mask=tgt_mask,
107
+ tgt_padding_mask=tgt_padding_mask,
108
+ memory_key_padding_mask=memory_key_padding_mask,
109
+ ) # [1, curr_len, emb_size]
110
 
111
  # Get logits for the *next* token prediction
 
112
  next_token_logits = transformer_model.generator(
113
+ decoder_output[
114
+ :, -1, :
115
+ ] # Use output corresponding to the last input token
116
+ ) # [1, tgt_vocab_size]
117
 
118
  # Find the most likely next token (greedy choice)
119
+ # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
120
+ # _, next_word_id_tensor = torch.max(prob, dim=1)
121
+ next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # [1]
122
  next_word_id = next_word_id_tensor.item()
123
 
124
+ # Append the chosen token to the sequence
125
+ ys = torch.cat(
126
+ [
127
+ ys,
128
+ torch.ones(1, 1, dtype=torch.long, device=device).fill_(
129
+ next_word_id
130
+ ),
131
+ ],
132
+ dim=1,
133
+ ) # [1, current_len + 1]
134
 
135
  # Stop if EOS token is generated
136
  if next_word_id == eos_idx:
137
  break
138
 
139
  # Return the generated sequence (excluding the initial SOS token)
140
+ return ys[:, 1:] # Shape [1, generated_len]
 
141
 
142
  except RuntimeError as e:
143
  logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
144
  if "CUDA out of memory" in str(e) and device.type == "cuda":
145
  gc.collect()
146
  torch.cuda.empty_cache()
147
+ return torch.empty(
148
+ (1, 0), dtype=torch.long, device=device
149
+ ) # Return empty tensor on error
150
  except Exception as e:
151
  logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
152
+ return torch.empty((1, 0), dtype=torch.long, device=device)
153
 
154
 
155
  # --- Translation Function (Using Greedy Decode) ---
 
159
  smiles_tokenizer: Tokenizer,
160
  iupac_tokenizer: Tokenizer,
161
  device: torch.device,
162
+ max_len: int,
163
  sos_idx: int,
164
  eos_idx: int,
165
  pad_idx: int,
166
+ ) -> str: # Returns a single string
167
  """
168
  Translates a single SMILES string using greedy decoding.
169
  """
 
 
 
170
  model.eval() # Ensure model is in eval mode
171
 
172
  # --- Tokenize Source ---
173
  try:
174
+ # Ensure tokenizer has truncation/padding configured if needed, or handle manually
175
+ smiles_tokenizer.enable_truncation(
176
+ max_length=max_len
177
+ ) # Use max_len for source truncation too
178
  src_encoded = smiles_tokenizer.encode(src_sentence)
179
  if not src_encoded or not src_encoded.ids:
180
  logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
181
+ return "[Encoding Error]"
182
+ # Use the truncated IDs directly
183
  src_ids = src_encoded.ids
 
184
  except Exception as e:
185
  logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
186
+ return "[Encoding Error]"
187
 
188
  # --- Prepare Input Tensor and Mask ---
189
+ src = (
190
+ torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
191
+ ) # [1, src_len]
192
+ # Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
193
+ src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
194
 
195
  # --- Perform Greedy Decoding ---
196
+ # Calls the greedy_decode function defined *above in this file*
197
+ # Note: max_len for generation should come from config if it dictates output length
198
+ generation_max_len = config.get(
199
+ "max_len", 256
200
+ ) # Use config max_len for output limit
201
+ tgt_tokens_tensor = greedy_decode(
202
+ model=model,
203
+ src=src,
204
+ src_padding_mask=src_padding_mask,
205
+ max_len=generation_max_len, # Use generation limit
206
+ sos_idx=sos_idx,
207
+ eos_idx=eos_idx,
208
+ # pad_idx=pad_idx, # Not needed by greedy_decode internal loop
209
+ device=device,
210
+ ) # Returns a single tensor [1, generated_len]
211
 
212
  # --- Decode Generated Tokens ---
213
  if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
214
+ logging.warning(
215
+ f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
216
+ )
217
+ return "[Decoding Error - Empty Output]"
 
 
 
 
 
218
 
219
  tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
220
  try:
221
  # Decode using the target tokenizer, skipping special tokens
222
  translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
223
+ return translation
224
  except Exception as e:
225
  logging.error(
226
  f"Error decoding target tokens {tgt_tokens}: {e}",
227
  exc_info=True,
228
  )
229
+ return "[Decoding Error]"
230
 
231
 
232
+ # --- Model/Tokenizer Loading Function (Unchanged) ---
233
  def load_model_and_tokenizers():
234
  """Loads tokenizers, config, and model from Hugging Face Hub."""
235
+ global model, smiles_tokenizer, iupac_tokenizer, device, config
236
  if model is not None: # Already loaded
237
  logging.info("Model and tokenizers already loaded.")
238
  return
239
 
240
  logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
 
 
 
 
 
 
 
241
  try:
242
  # Determine device
243
  if torch.cuda.is_available():
244
+ logging.warning(
245
+ "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
246
+ )
247
+ device = torch.device("cpu") # Uncomment if GPU is intended
248
+ # device = torch.device("cuda")
249
+ # logging.info("CUDA available, using GPU.")
250
  else:
251
  device = torch.device("cpu")
252
  logging.info("CUDA not available, using CPU.")
 
254
  # Download files
255
  logging.info("Downloading files from Hugging Face Hub...")
256
  try:
 
257
  cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
258
+ os.makedirs(cache_dir, exist_ok=True)
259
  logging.info(f"Using cache directory: {cache_dir}")
260
 
 
261
  checkpoint_path = hf_hub_download(
262
+ repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
263
  )
264
  smiles_tokenizer_path = hf_hub_download(
265
+ repo_id=MODEL_REPO_ID,
266
+ filename=SMILES_TOKENIZER_FILENAME,
267
+ cache_dir=cache_dir,
268
  )
269
  iupac_tokenizer_path = hf_hub_download(
270
+ repo_id=MODEL_REPO_ID,
271
+ filename=IUPAC_TOKENIZER_FILENAME,
272
+ cache_dir=cache_dir,
273
  )
274
  config_path = hf_hub_download(
275
+ repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
276
  )
277
+ logging.info("Files downloaded successfully.")
278
  except Exception as e:
279
  logging.error(
280
  f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
 
292
  logging.info("Configuration loaded.")
293
  # --- Validate essential config keys ---
294
  required_keys = [
295
+ "src_vocab_size", # Use the key saved in config
296
+ "tgt_vocab_size", # Use the key saved in config
297
  "emb_size",
298
  "nhead",
299
  "ffn_hid_dim",
300
  "num_encoder_layers",
301
  "num_decoder_layers",
302
  "dropout",
303
+ "max_len",
304
  "pad_token_id",
305
  "bos_token_id",
306
  "eos_token_id",
307
  ]
308
+ # Remap if needed (example shown, adjust if your keys differ)
309
+ config_key_mapping = {
310
+ "src_vocab_size": config.get(
311
+ "src_vocab_size", config.get("src_vocab_size")
312
+ ),
313
+ "tgt_vocab_size": config.get(
314
+ "tgt_vocab_size", config.get("tgt_vocab_size")
315
+ ),
316
+ # Add other mappings if necessary
317
+ }
318
+ config.update(config_key_mapping)
 
319
 
320
  missing_keys = [key for key in required_keys if config.get(key) is None]
321
  if missing_keys:
 
325
  )
326
 
327
  logging.info(
328
+ f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
329
  f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
330
  f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
331
  )
 
350
  try:
351
  smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
352
  iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
353
+ logging.info("Tokenizers loaded.")
354
+ # --- Optional: Validate Tokenizer Special Tokens Against Config ---
355
+ # (Keep validation as before, it's still useful)
356
  pad_token = "<pad>"
357
  sos_token = "<sos>"
358
  eos_token = "<eos>"
359
+ unk_token = "<unk>"
360
  issues = []
361
+ # ... (keep the validation checks as in the original code) ...
362
+ if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
363
+ issues.append(f"SMILES PAD ID mismatch")
364
+ if smiles_tokenizer.token_to_id(unk_token) is None:
365
+ issues.append("SMILES UNK token not found")
366
+ if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
367
+ issues.append(f"IUPAC PAD ID mismatch")
368
+ if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
369
+ issues.append(f"IUPAC SOS ID mismatch")
370
+ if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
371
+ issues.append(f"IUPAC EOS ID mismatch")
372
+ if iupac_tokenizer.token_to_id(unk_token) is None:
373
+ issues.append("IUPAC UNK token not found")
 
 
 
 
 
 
 
 
 
 
374
  if issues:
375
+ logging.warning("Tokenizer validation issues: " + "; ".join(issues))
 
 
 
 
376
 
377
  except Exception as e:
378
  logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
379
  raise gr.Error(
380
+ f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
381
  )
382
 
383
  # Load model
384
  logging.info("Loading model from checkpoint...")
385
  try:
386
+ # Use the vocab sizes and hparams from the loaded config
 
387
  model = SmilesIupacLitModule.load_from_checkpoint(
388
  checkpoint_path,
389
+ **config, # Pass loaded config directly as keyword args
390
  map_location=device,
391
+ devices=1,
392
+ strict=True, # Start strict, set to False if encountering key errors
 
 
393
  )
394
 
 
395
  model.to(device)
396
  model.eval()
397
+ model.freeze()
398
  logging.info(
399
  f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
400
  )
 
402
  except FileNotFoundError:
403
  logging.error(f"Checkpoint file not found: {checkpoint_path}")
404
  raise gr.Error(
405
+ f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
406
  )
407
+ except Exception as e:
 
408
  logging.error(
409
+ f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
 
 
 
410
  )
411
+ if "size mismatch" in str(e):
412
+ error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
413
+ logging.error(error_detail)
414
+ raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
415
+ elif "memory" in str(e).lower():
416
+ logging.warning("Potential OOM error during model loading.")
417
+ gc.collect()
418
+ torch.cuda.empty_cache() if device.type == "cuda" else None
419
+ raise gr.Error(
420
+ f"Model Error: OOM loading model. Check Space resources. Error: {e}"
421
+ )
422
+ else:
423
+ raise gr.Error(
424
+ f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
425
+ )
426
 
427
+ except gr.Error:
428
+ raise
429
+ except Exception as e:
430
+ logging.error(f"Unexpected error during loading: {e}", exc_info=True)
431
  raise gr.Error(
432
+ f"Initialization Error: Unexpected error. Check logs. Error: {e}"
433
  )
434
 
435
 
436
+ # --- Inference Function for Gradio (Simplified) ---
437
  def predict_iupac(smiles_string):
438
  """
439
  Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
 
440
  """
441
+ try:
442
+ smiles_string = CanonSmiles(smiles_string)
443
+ except Exception as e:
444
+ logging.error(f"Error during SMILES canonicalization: {e}", exc_info=True)
445
+ return f"Error: {e}"
446
  global model, smiles_tokenizer, iupac_tokenizer, device, config
447
 
 
448
  if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
449
+ error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
450
  logging.error(error_msg)
451
+ return f"Error: {error_msg}" # Return single error string
 
452
 
 
453
  if not smiles_string or not smiles_string.strip():
454
+ error_msg = "Error: Please enter a valid SMILES string."
455
+ return f"Error: {error_msg}" # Return single error string
456
 
457
+ smiles_input = smiles_string.strip()
458

 
459
  try:
460
+ # --- Call the core translation logic (greedy) ---
461
  sos_idx = config["bos_token_id"]
462
  eos_idx = config["eos_token_id"]
463
  pad_idx = config["pad_token_id"]
464
+ gen_max_len = config["max_len"]
465
 
466
+ predicted_name = translate( # Returns a single string now
467
  model=model,
468
+ src_sentence=smiles_input,
469
  smiles_tokenizer=smiles_tokenizer,
470
  iupac_tokenizer=iupac_tokenizer,
471
  device=device,
472
+ max_len=gen_max_len,
473
  sos_idx=sos_idx,
474
  eos_idx=eos_idx,
475
  pad_idx=pad_idx,
476
  )
477
+ logging.info(f"Prediction returned: {predicted_name}")
478
 
479
  # --- Format Output ---
480
+ if "[Error]" in predicted_name: # Check for error messages from translate
 
 
481
  output_text = (
482
+ f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
 
 
 
 
 
 
 
 
483
  )
484
+ elif not predicted_name:
485
+ output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
486
  else:
487
  output_text = (
488
+ f"Input SMILES: {smiles_input}\n\n"
 
489
  f"Predicted IUPAC Name (Greedy Decode):\n"
490
  f"{predicted_name}"
491
  )
492
+ return output_text
 
 
493
 
494
+ except RuntimeError as e:
495
+ logging.error(f"Runtime error during translation: {e}", exc_info=True)
496
+ return f"Error: {error_msg}" # Return single error string
497
 
498
  except Exception as e:
499
+ logging.error(f"Unexpected error during translation: {e}", exc_info=True)
500
+ error_msg = f"Unexpected Error during translation: {e}"
501
+ return f"Error: {error_msg}" # Return single error string
 
502
 
503
 
504
  # --- Load Model on App Start ---
 
 
505
  try:
506
  load_model_and_tokenizers()
507
  except gr.Error as ge:
508
+ logging.error(f"Gradio Initialization Error: {ge}")
509
+ pass # Allow Gradio to potentially start with an error message
510
  except Exception as e:
511
+ logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
512
+ # Optionally raise gr.Error here too
513
 
514
 
515
+ # --- Create Gradio Interface (Simplified) ---
516
  title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
517
  description = f"""
518
  Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
519
  Translation uses **greedy decoding** (picks the most likely next word at each step).
520
+ **Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
521
  """
522
+
523
+
524
+
525
+ # Replace your Interface code with this:
 
526
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
527
  gr.Markdown(f"# {title}")
528
  gr.Markdown(description)
529
+
530
  with gr.Row():
531
+ with gr.Column():
532
  smiles_input = gr.Textbox(
533
  label="SMILES String",
534
+ placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
535
+ lines=1,
536
  )
537
+
538
+ # Add a button for manual triggering
539
+ submit_btn = gr.Button("Translate")
540
+
541
+ with gr.Column():
542
  output_text = gr.Textbox(
543
+ label="Predicted IUPAC Name",
544
+ lines=3,
545
  show_copy_button=True,
 
546
  )
547
+
548
+ # Connect the events properly
549
+ submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
550
+
551
+ # If you still want change event (auto-translation as user types)
552
+ smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
553
+ # Output component
554
+ output_text = gr.Textbox(
555
+ label="Predicted IUPAC Name",
556
+ lines=3,
557
+ show_copy_button=True, # Reduced lines slightly
558
+ )
 
559
 
560
 
561
+ # Create the interface instance
562
+ iface = gr.Interface(
563
+ fn=predict_iupac, # The function to call
564
+ inputs=smiles_input, # Input component directly
565
+ outputs=output_text, # Output component
566
+ title=title,
567
+ description=description,
568
+ allow_flagging="never",
569
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
570
+ )
571
  # --- Launch the App ---
572
  if __name__ == "__main__":
573
+ iface.launch()
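
The greedy decoding used throughout this commit (pick the most likely next token at each step, starting from SOS and stopping at EOS or the `max_len` limit) can be illustrated with a minimal, self-contained sketch. The `toy_greedy_decode` helper and the dummy logits function below are illustrative assumptions, not code from the app; the real loop drives the Seq2SeqTransformer's `encode`/`decode`/`generator` methods.

```python
# Minimal sketch of the greedy loop: argmax at every step, stop at EOS or max_len.
import torch

def toy_greedy_decode(next_logits_fn, sos_idx: int, eos_idx: int, max_len: int) -> list:
    ys = [sos_idx]                                   # start from the SOS token, as in greedy_decode
    for _ in range(max_len - 1):
        logits = next_logits_fn(ys)                  # hypothetical stand-in for decode() + generator()
        next_id = int(torch.argmax(logits).item())   # greedy choice: most likely next token
        ys.append(next_id)
        if next_id == eos_idx:                       # stop once EOS is generated
            break
    return ys[1:]                                    # drop SOS, mirroring `return ys[:, 1:]`

# Dummy next-token distribution: prefers token 3 for a few steps, then EOS (id 2).
dummy = lambda ys: torch.tensor([0.0, 0.1, 3.0 if len(ys) > 3 else 0.2, 2.0])
print(toy_greedy_decode(dummy, sos_idx=1, eos_idx=2, max_len=10))  # -> [3, 3, 3, 2]
```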
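`generate_square_subsequent_mask` is imported from `enhanced_trainer.py`, which is not part of this diff. A common implementation of such a causal mask, shown here only as an assumption to clarify what the `tgt_mask` passed to `decode` looks like, is the standard upper-triangular `-inf` mask:

```python
# Sketch of a typical square subsequent (causal) mask; the real one lives in enhanced_trainer.py.
import torch

def square_subsequent_mask_sketch(sz: int, device: torch.device) -> torch.Tensor:
    # 0.0 on and below the diagonal (attention allowed), -inf above it (future positions blocked)
    return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

print(square_subsequent_mask_sketch(3, torch.device("cpu")))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```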
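The new `predict_iupac` canonicalizes the input with RDKit's `CanonSmiles` before tokenization. A quick sanity check of that step (requires `rdkit` to be installed):

```python
from rdkit.Chem import CanonSmiles

print(CanonSmiles("OCC"))          # ethanol -> "CCO"
print(CanonSmiles("C1=CC=CC=C1"))  # benzene -> "c1ccccc1" (aromatic form)
```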