AdrianM0 committed
Commit e6c42e6 · verified · 1 Parent(s): 21ea065

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +110 -56
app.py CHANGED
@@ -1,6 +1,7 @@
 # app.py
 import gradio as gr
 import torch
+
 # import torch.nn.functional as F # No longer needed for greedy decode directly
 import pytorch_lightning as pl
 import os
@@ -28,6 +29,7 @@ logging.basicConfig(
 # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
 try:
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
+
     logging.info("Successfully imported from enhanced_trainer.py.")
 except ImportError as e:
     logging.error(
@@ -80,10 +82,12 @@ def greedy_decode(

         # --- Initialize Target Sequence ---
         # Start with the SOS token
-        ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) # [1, 1]
+        ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
+            sos_idx
+        )  # [1, 1]

         # --- Decoding Loop ---
-        for _ in range(max_len - 1): # Max length limit
+        for _ in range(max_len - 1):  # Max length limit
             tgt_seq_len = ys.shape[1]
             tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                 device
@@ -104,34 +108,43 @@

             # Get logits for the *next* token prediction
             next_token_logits = transformer_model.generator(
-                decoder_output[:, -1, :] # Use output corresponding to the last input token
+                decoder_output[
+                    :, -1, :
+                ]  # Use output corresponding to the last input token
             )  # [1, tgt_vocab_size]

             # Find the most likely next token (greedy choice)
             # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
             # _, next_word_id_tensor = torch.max(prob, dim=1)
-            next_word_id_tensor = torch.argmax(next_token_logits, dim=1) # [1]
+            next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # [1]
             next_word_id = next_word_id_tensor.item()

             # Append the chosen token to the sequence
             ys = torch.cat(
-                [ys, torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)],
-                dim=1
-            ) # [1, current_len + 1]
+                [
+                    ys,
+                    torch.ones(1, 1, dtype=torch.long, device=device).fill_(
+                        next_word_id
+                    ),
+                ],
+                dim=1,
+            )  # [1, current_len + 1]

             # Stop if EOS token is generated
             if next_word_id == eos_idx:
                 break

         # Return the generated sequence (excluding the initial SOS token)
-        return ys[:, 1:] # Shape [1, generated_len]
+        return ys[:, 1:]  # Shape [1, generated_len]

     except RuntimeError as e:
         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
         if "CUDA out of memory" in str(e) and device.type == "cuda":
             gc.collect()
             torch.cuda.empty_cache()
-        return torch.empty((1, 0), dtype=torch.long, device=device) # Return empty tensor on error
+        return torch.empty(
+            (1, 0), dtype=torch.long, device=device
+        )  # Return empty tensor on error
     except Exception as e:
         logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
         return torch.empty((1, 0), dtype=torch.long, device=device)
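Note: `generate_square_subsequent_mask` is imported from enhanced_trainer.py, which is not part of this diff. A minimal sketch of the standard PyTorch causal-mask helper it presumably corresponds to (an assumption; the repo's actual implementation may differ):

# Sketch only -- the usual causal mask for autoregressive decoding; the
# version in enhanced_trainer.py is not shown in this commit.
import torch

def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
    # True above the diagonal marks future positions; -inf blocks attention there.
    future = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
    return torch.zeros(sz, sz, device=device).masked_fill(future, float("-inf"))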
@@ -148,7 +161,7 @@ def translate(
     sos_idx: int,
     eos_idx: int,
     pad_idx: int,
-) -> str: # Returns a single string
+) -> str:  # Returns a single string
     """
     Translates a single SMILES string using greedy decoding.
     """
@@ -157,7 +170,9 @@ def translate(
     # --- Tokenize Source ---
     try:
         # Ensure tokenizer has truncation/padding configured if needed, or handle manually
-        smiles_tokenizer.enable_truncation(max_length=max_len) # Use max_len for source truncation too
+        smiles_tokenizer.enable_truncation(
+            max_length=max_len
+        )  # Use max_len for source truncation too
         src_encoded = smiles_tokenizer.encode(src_sentence)
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
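For context: `enable_truncation` here is the Hugging Face `tokenizers` API, which caps every subsequent `encode` call at `max_length` tokens. A small standalone sketch (the tokenizer file name is a placeholder, not necessarily the one in this repo):

# Standalone sketch of the truncation API used in the hunk above.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("smiles_tokenizer.json")  # placeholder file name
tok.enable_truncation(max_length=256)
enc = tok.encode("CCO")
print(enc.ids)  # at most 256 token ids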
@@ -194,15 +209,15 @@

     # --- Decode Generated Tokens ---
     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
-        logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
+        logging.warning(
+            f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
+        )
         return "[Decoding Error - Empty Output]"

     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
     try:
         # Decode using the target tokenizer, skipping special tokens
-        translation = iupac_tokenizer.decode(
-            tgt_tokens, skip_special_tokens=True
-        )
+        translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
         return translation
     except Exception as e:
         logging.error(
@@ -275,23 +290,34 @@ def load_model_and_tokenizers():
         logging.info("Configuration loaded.")
         # --- Validate essential config keys ---
         required_keys = [
-            "src_vocab_size", # Use the key saved in config
-            "tgt_vocab_size", # Use the key saved in config
-            "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers",
-            "num_decoder_layers", "dropout", "max_len",
-            "pad_token_id", "bos_token_id", "eos_token_id",
+            "src_vocab_size",  # Use the key saved in config
+            "tgt_vocab_size",  # Use the key saved in config
+            "emb_size",
+            "nhead",
+            "ffn_hid_dim",
+            "num_encoder_layers",
+            "num_decoder_layers",
+            "dropout",
+            "max_len",
+            "pad_token_id",
+            "bos_token_id",
+            "eos_token_id",
         ]
         # Remap if needed (example shown, adjust if your keys differ)
         config_key_mapping = {
-            "src_vocab_size": config.get("src_vocab_size", config.get("actual_src_vocab_size")),
-            "tgt_vocab_size": config.get("tgt_vocab_size", config.get("actual_tgt_vocab_size")),
+            "src_vocab_size": config.get(
+                "src_vocab_size", config.get("actual_src_vocab_size")
+            ),
+            "tgt_vocab_size": config.get(
+                "tgt_vocab_size", config.get("actual_tgt_vocab_size")
+            ),
             # Add other mappings if necessary
         }
         config.update(config_key_mapping)

         missing_keys = [key for key in required_keys if config.get(key) is None]
         if missing_keys:
-            raise ValueError(
+            raise ValueError(
                 f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
                 f"Ensure these were saved in the hyperparameters during training."
             )
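The `config_key_mapping` block above is a fallback lookup: `config.get(key, config.get(legacy_key))` prefers the key name used by newer configs and falls back to the legacy `actual_*` name, so older checkpoints keep working. A self-contained illustration (the sample values are invented):

# Fallback-lookup pattern from the hunk above; the numbers are made up.
config = {"actual_src_vocab_size": 72, "tgt_vocab_size": 5000}
config["src_vocab_size"] = config.get(
    "src_vocab_size", config.get("actual_src_vocab_size")
)
print(config["src_vocab_size"])  # 72, taken from the legacy key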
@@ -307,7 +333,9 @@ def load_model_and_tokenizers():
         raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
     except json.JSONDecodeError as e:
         logging.error(f"Error decoding JSON from config file {config_path}: {e}")
-        raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}")
+        raise gr.Error(
+            f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}"
+        )
     except ValueError as e:
         logging.error(f"Config validation error: {e}")
         raise gr.Error(f"Config Error: {e}")
@@ -329,17 +357,26 @@ def load_model_and_tokenizers():
         unk_token = "<unk>"
         issues = []
         # ... (keep the validation checks as in the original code) ...
-        if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]: issues.append(f"SMILES PAD ID mismatch")
-        if smiles_tokenizer.token_to_id(unk_token) is None: issues.append("SMILES UNK token not found")
-        if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]: issues.append(f"IUPAC PAD ID mismatch")
-        if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]: issues.append(f"IUPAC SOS ID mismatch")
-        if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]: issues.append(f"IUPAC EOS ID mismatch")
-        if iupac_tokenizer.token_to_id(unk_token) is None: issues.append("IUPAC UNK token not found")
-        if issues: logging.warning("Tokenizer validation issues: " + "; ".join(issues))
+        if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
+            issues.append(f"SMILES PAD ID mismatch")
+        if smiles_tokenizer.token_to_id(unk_token) is None:
+            issues.append("SMILES UNK token not found")
+        if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
+            issues.append(f"IUPAC PAD ID mismatch")
+        if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
+            issues.append(f"IUPAC SOS ID mismatch")
+        if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
+            issues.append(f"IUPAC EOS ID mismatch")
+        if iupac_tokenizer.token_to_id(unk_token) is None:
+            issues.append("IUPAC UNK token not found")
+        if issues:
+            logging.warning("Tokenizer validation issues: " + "; ".join(issues))

     except Exception as e:
         logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
-        raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}")
+        raise gr.Error(
+            f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
+        )

     # Load model
     logging.info("Loading model from checkpoint...")
@@ -352,9 +389,9 @@ def load_model_and_tokenizers():
             tgt_vocab_size=config["tgt_vocab_size"],
             # Pass the whole config dict, load_from_checkpoint will pick what it needs
             # if the keys match the __init__ args or are in hparams
-            **config, # Pass loaded config directly as keyword args
+            **config,  # Pass loaded config directly as keyword args
             map_location=device,
-            strict=True, # Start strict, set to False if encountering key errors
+            strict=True,  # Start strict, set to False if encountering key errors
         )

         model.to(device)
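`strict=True` makes PyTorch Lightning's `load_from_checkpoint` raise on any state-dict key mismatch; the inline comment suggests relaxing it if key errors appear. Note also that the call passes `src_vocab_size=`/`tgt_vocab_size=` explicitly and again via `**config`; since `config` is required to contain those keys, Python will raise a duplicate-keyword `TypeError` at call time. A sketch of the retry-with-`strict=False` fallback, reusing names from this file and avoiding the duplicate kwargs (illustrative only, not the repo's code):

# Illustrative retry for the "set to False if encountering key errors" comment.
try:
    model = SmilesIupacLitModule.load_from_checkpoint(
        checkpoint_path, map_location=device, strict=True, **config
    )
except RuntimeError:
    # Tolerate missing/unexpected state-dict keys (e.g. renamed submodules).
    model = SmilesIupacLitModule.load_from_checkpoint(
        checkpoint_path, map_location=device, strict=False, **config
    )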
@@ -366,24 +403,36 @@

     except FileNotFoundError:
         logging.error(f"Checkpoint file not found: {checkpoint_path}")
-        raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
+        raise gr.Error(
+            f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
+        )
     except Exception as e:
-        logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
+        logging.error(
+            f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
+        )
         if "size mismatch" in str(e):
-            error_detail = (f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint.")
+            error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
             logging.error(error_detail)
             raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
         elif "memory" in str(e).lower():
             logging.warning("Potential OOM error during model loading.")
-            gc.collect(); torch.cuda.empty_cache() if device.type == "cuda" else None
-            raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}")
+            gc.collect()
+            torch.cuda.empty_cache() if device.type == "cuda" else None
+            raise gr.Error(
+                f"Model Error: OOM loading model. Check Space resources. Error: {e}"
+            )
         else:
-            raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}")
+            raise gr.Error(
+                f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
+            )

-    except gr.Error: raise
+    except gr.Error:
+        raise
     except Exception as e:
         logging.error(f"Unexpected error during loading: {e}", exc_info=True)
-        raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}")
+        raise gr.Error(
+            f"Initialization Error: Unexpected error. Check logs. Error: {e}"
+        )


 # --- Inference Function for Gradio (Simplified) ---
@@ -396,11 +445,11 @@ def predict_iupac(smiles_string):
     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
         error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
         logging.error(error_msg)
-        return f"Error: {error_msg}" # Return single error string
+        return f"Error: {error_msg}"  # Return single error string

     if not smiles_string or not smiles_string.strip():
         error_msg = "Error: Please enter a valid SMILES string."
-        return f"Error: {error_msg}" # Return single error string
+        return f"Error: {error_msg}"  # Return single error string

     smiles_input = smiles_string.strip()

@@ -411,7 +460,7 @@
         pad_idx = config["pad_token_id"]
         gen_max_len = config["max_len"]

-        predicted_name = translate( # Returns a single string now
+        predicted_name = translate(  # Returns a single string now
             model=model,
             src_sentence=smiles_input,
             smiles_tokenizer=smiles_tokenizer,
@@ -425,10 +474,12 @@
         logging.info(f"Prediction returned: {predicted_name}")

         # --- Format Output ---
-        if "[Error]" in predicted_name: # Check for error messages from translate
-            output_text = f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
+        if "[Error]" in predicted_name:  # Check for error messages from translate
+            output_text = (
+                f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
+            )
         elif not predicted_name:
-            output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
+            output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
         else:
             output_text = (
                 f"Input SMILES: {smiles_input}\n\n"
@@ -439,12 +490,12 @@

     except RuntimeError as e:
         logging.error(f"Runtime error during translation: {e}", exc_info=True)
-        return f"Error: {error_msg}" # Return single error string
+        return f"Error: Runtime error during translation: {e}"  # error_msg is undefined in this branch; report the exception directly

     except Exception as e:
         logging.error(f"Unexpected error during translation: {e}", exc_info=True)
         error_msg = f"Unexpected Error during translation: {e}"
-        return f"Error: {error_msg}" # Return single error string
+        return f"Error: {error_msg}"  # Return single error string


 # --- Load Model on App Start ---
@@ -452,7 +503,7 @@ try:
     load_model_and_tokenizers()
 except gr.Error as ge:
     logging.error(f"Gradio Initialization Error: {ge}")
-    pass # Allow Gradio to potentially start with an error message
+    pass  # Allow Gradio to potentially start with an error message
 except Exception as e:
     logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
     # Optionally raise gr.Error here too
@@ -463,7 +514,7 @@ title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
 description = f"""
 Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
 Translation uses **greedy decoding** (picks the most likely next word at each step).
-**Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Performance may vary. Check `config.json` in the repo for model details.
+**Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
 """

 # Define examples - remove beam search parameters
@@ -481,16 +532,19 @@ smiles_input = gr.Textbox(
     placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
     lines=1,
 )
-
+from rdkit.Chem import CanonSmiles
+smiles_input = CanonSmiles(smiles_input)
 # Output component
 output_text = gr.Textbox(
-    label="Predicted IUPAC Name", lines=3, show_copy_button=True # Reduced lines slightly
+    label="Predicted IUPAC Name",
+    lines=3,
+    show_copy_button=True,  # Reduced lines slightly
 )

 # Create the interface instance
 iface = gr.Interface(
     fn=predict_iupac,  # The function to call
-    inputs=smiles_input, # Single input component
+    inputs=smiles_input,  # Single input component
     outputs=output_text,  # Output component
     title=title,
     description=description,
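Note on the two `rdkit` lines added in this hunk: they run at import time, when `smiles_input` is already bound to the `gr.Textbox` component, so `CanonSmiles` receives a component rather than a SMILES string, and `inputs=smiles_input` below then no longer refers to a component. A sketch of canonicalizing the raw string inside `predict_iupac` instead (assumes RDKit is installed; `_canonicalize` is a hypothetical helper, not part of this commit):

# Hypothetical helper: canonicalize inside predict_iupac, not at module scope.
from rdkit.Chem import CanonSmiles

def _canonicalize(smiles: str) -> str:
    try:
        return CanonSmiles(smiles)  # may raise on invalid SMILES
    except Exception:
        return smiles  # fall back to the raw input

# In predict_iupac, after `smiles_input = smiles_string.strip()`:
#     smiles_input = _canonicalize(smiles_input)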
@@ -505,4 +559,4 @@

 # --- Launch the App ---
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
 