Upload folder using huggingface_hub
app.py CHANGED
@@ -1,6 +1,7 @@
 # app.py
 import gradio as gr
 import torch
+
 # import torch.nn.functional as F # No longer needed for greedy decode directly
 import pytorch_lightning as pl
 import os
@@ -28,6 +29,7 @@ logging.basicConfig(
 # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
 try:
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
+
     logging.info("Successfully imported from enhanced_trainer.py.")
 except ImportError as e:
     logging.error(
@@ -80,10 +82,12 @@ def greedy_decode(
 
         # --- Initialize Target Sequence ---
         # Start with the SOS token
-        ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx)  # [1, 1]
+        ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
+            sos_idx
+        )  # [1, 1]
 
         # --- Decoding Loop ---
-        for _ in range(max_len - 1):
+        for _ in range(max_len - 1):  # Max length limit
             tgt_seq_len = ys.shape[1]
             tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                 device
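For reference, `generate_square_subsequent_mask` (imported from `enhanced_trainer` above) supplies the causal mask used here. Its exact body lives in that repo; the sketch below is only the conventional PyTorch recipe such a helper usually follows, not the repo's code:

```python
import torch

def square_subsequent_mask(size: int, device: torch.device) -> torch.Tensor:
    # True above the diagonal marks positions a decoding step must NOT attend to;
    # converted to additive form: 0.0 where attention is allowed, -inf where masked.
    blocked = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
    return torch.zeros(size, size, device=device).masked_fill(blocked, float("-inf"))
```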
@@ -104,34 +108,43 @@ def greedy_decode(
 
             # Get logits for the *next* token prediction
             next_token_logits = transformer_model.generator(
-                decoder_output[:, -1, :]  # Use output corresponding to the last input token
+                decoder_output[
+                    :, -1, :
+                ]  # Use output corresponding to the last input token
             )  # [1, tgt_vocab_size]
 
             # Find the most likely next token (greedy choice)
             # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
             # _, next_word_id_tensor = torch.max(prob, dim=1)
-            next_word_id_tensor = torch.argmax(next_token_logits, dim=1)
+            next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # [1]
            next_word_id = next_word_id_tensor.item()
 
             # Append the chosen token to the sequence
             ys = torch.cat(
-                [
-                    ys, torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)
-                ], dim=1)  # [1, current_len + 1]
+                [
+                    ys,
+                    torch.ones(1, 1, dtype=torch.long, device=device).fill_(
+                        next_word_id
+                    ),
+                ],
+                dim=1,
+            )  # [1, current_len + 1]
 
             # Stop if EOS token is generated
             if next_word_id == eos_idx:
                 break
 
         # Return the generated sequence (excluding the initial SOS token)
-        return ys[:, 1:]
+        return ys[:, 1:]  # Shape [1, generated_len]
 
     except RuntimeError as e:
         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
         if "CUDA out of memory" in str(e) and device.type == "cuda":
             gc.collect()
             torch.cuda.empty_cache()
-        return torch.empty((1, 0), dtype=torch.long, device=device)
+        return torch.empty(
+            (1, 0), dtype=torch.long, device=device
+        )  # Return empty tensor on error
     except Exception as e:
         logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
         return torch.empty((1, 0), dtype=torch.long, device=device)
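The loop above is a textbook greedy decoder: score, take the argmax, append, stop at EOS. The same control flow in a self-contained sketch, with a stub `next_logits` standing in for the model's decoder-plus-generator (everything here is illustrative, not the app's code):

```python
import torch

SOS, EOS, VOCAB, MAX_LEN = 1, 2, 10, 8

def next_logits(ys: torch.Tensor) -> torch.Tensor:
    # Stub scorer; a real model would run its decoder on ys here.
    torch.manual_seed(ys.shape[1])
    return torch.randn(1, VOCAB)

ys = torch.full((1, 1), SOS, dtype=torch.long)
for _ in range(MAX_LEN - 1):
    next_id = int(torch.argmax(next_logits(ys), dim=1).item())
    ys = torch.cat([ys, torch.tensor([[next_id]])], dim=1)
    if next_id == EOS:  # stop as soon as EOS is chosen
        break

print(ys[:, 1:])  # generated ids with the SOS stripped, as greedy_decode returns
```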
@@ -148,7 +161,7 @@ def translate(
     sos_idx: int,
     eos_idx: int,
     pad_idx: int,
-) -> str:
+) -> str:  # Returns a single string
     """
     Translates a single SMILES string using greedy decoding.
     """
@@ -157,7 +170,9 @@ def translate(
     # --- Tokenize Source ---
     try:
         # Ensure tokenizer has truncation/padding configured if needed, or handle manually
-        smiles_tokenizer.enable_truncation(max_length=max_len)  # Use max_len for source truncation too
+        smiles_tokenizer.enable_truncation(
+            max_length=max_len
+        )  # Use max_len for source truncation too
         src_encoded = smiles_tokenizer.encode(src_sentence)
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
@@ -194,15 +209,15 @@ def translate(
 
     # --- Decode Generated Tokens ---
     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
-        logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
+        logging.warning(
+            f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
+        )
         return "[Decoding Error - Empty Output]"
 
     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
     try:
         # Decode using the target tokenizer, skipping special tokens
-        translation = iupac_tokenizer.decode(
-            tgt_tokens, skip_special_tokens=True
-        )
+        translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
         return translation
     except Exception as e:
         logging.error(
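Both tokenizers behave like Hugging Face `tokenizers.Tokenizer` objects: `enable_truncation`, `encode(...).ids`, and `decode(..., skip_special_tokens=True)` are that library's API. A minimal round trip, assuming a trained tokenizer file (the file name is hypothetical):

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("smiles_tokenizer.json")  # hypothetical file name
tok.enable_truncation(max_length=128)  # mirrors the max_len guard above

ids = tok.encode("CCO").ids                       # token ids fed to the model
text = tok.decode(ids, skip_special_tokens=True)  # back to plain text
```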
@@ -275,23 +290,34 @@ def load_model_and_tokenizers():
         logging.info("Configuration loaded.")
         # --- Validate essential config keys ---
         required_keys = [
-            "src_vocab_size",
-            "tgt_vocab_size",
-            "emb_size",
-            "
-            "
+            "src_vocab_size",  # Use the key saved in config
+            "tgt_vocab_size",  # Use the key saved in config
+            "emb_size",
+            "nhead",
+            "ffn_hid_dim",
+            "num_encoder_layers",
+            "num_decoder_layers",
+            "dropout",
+            "max_len",
+            "pad_token_id",
+            "bos_token_id",
+            "eos_token_id",
         ]
         # Remap if needed (example shown, adjust if your keys differ)
         config_key_mapping = {
-            "src_vocab_size": config.get(
-                "src_vocab_size", config.get("actual_src_vocab_size")),
+            "src_vocab_size": config.get(
+                "src_vocab_size", config.get("actual_src_vocab_size")
+            ),
+            "tgt_vocab_size": config.get(
+                "tgt_vocab_size", config.get("actual_tgt_vocab_size")
+            ),
             # Add other mappings if necessary
         }
         config.update(config_key_mapping)
 
         missing_keys = [key for key in required_keys if config.get(key) is None]
         if missing_keys:
-
+            raise ValueError(
                 f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
                 f"Ensure these were saved in the hyperparameters during training."
             )
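The pattern above (remap alternate key names, then fail fast on anything missing) is easy to exercise in isolation. A condensed sketch using the same key names the diff references; the local `config.json` path is an assumption:

```python
import json

with open("config.json") as fh:  # hypothetical local copy of the Space's config
    config = json.load(fh)

# Prefer the canonical key, fall back to the alternate name saved at train time.
config["src_vocab_size"] = config.get("src_vocab_size", config.get("actual_src_vocab_size"))
config["tgt_vocab_size"] = config.get("tgt_vocab_size", config.get("actual_tgt_vocab_size"))

required = ["src_vocab_size", "tgt_vocab_size", "emb_size", "max_len", "pad_token_id"]
missing = [k for k in required if config.get(k) is None]
if missing:
    raise ValueError(f"config.json is missing required keys: {missing}")
```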
@@ -307,7 +333,9 @@ def load_model_and_tokenizers():
         raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found.")
     except json.JSONDecodeError as e:
         logging.error(f"Error decoding JSON from config file {config_path}: {e}")
-        raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}")
+        raise gr.Error(
+            f"Config Error: Could not parse '{CONFIG_FILENAME}'. Error: {e}"
+        )
     except ValueError as e:
         logging.error(f"Config validation error: {e}")
         raise gr.Error(f"Config Error: {e}")
@@ -329,17 +357,26 @@ def load_model_and_tokenizers():
         unk_token = "<unk>"
         issues = []
         # ... (keep the validation checks as in the original code) ...
-        if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
-
-        if
-
-        if iupac_tokenizer.token_to_id(
-
-        if
+        if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
+            issues.append(f"SMILES PAD ID mismatch")
+        if smiles_tokenizer.token_to_id(unk_token) is None:
+            issues.append("SMILES UNK token not found")
+        if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
+            issues.append(f"IUPAC PAD ID mismatch")
+        if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
+            issues.append(f"IUPAC SOS ID mismatch")
+        if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
+            issues.append(f"IUPAC EOS ID mismatch")
+        if iupac_tokenizer.token_to_id(unk_token) is None:
+            issues.append("IUPAC UNK token not found")
+        if issues:
+            logging.warning("Tokenizer validation issues: " + "; ".join(issues))
 
     except Exception as e:
         logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
-        raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}")
+        raise gr.Error(
+            f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
+        )
 
     # Load model
     logging.info("Loading model from checkpoint...")
@@ -352,9 +389,9 @@ def load_model_and_tokenizers():
         tgt_vocab_size=config["tgt_vocab_size"],
         # Pass the whole config dict, load_from_checkpoint will pick what it needs
         # if the keys match the __init__ args or are in hparams
-        **config,
+        # Drop the keys already passed explicitly so they are not supplied twice
+        **{k: v for k, v in config.items()
+           if k not in ("src_vocab_size", "tgt_vocab_size")},
         map_location=device,
-        strict=True,
+        strict=True,  # Start strict, set to False if encountering key errors
     )
 
     model.to(device)
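The filtering above matters because Python rejects a keyword supplied both explicitly and via `**`: the call would fail with `TypeError: ... got multiple values for keyword argument` before `load_from_checkpoint` ever ran. A two-line demonstration of the pitfall, with generic names:

```python
def load(a, **kwargs):
    return a

cfg = {"a": 1, "b": 2}
# load(a=cfg["a"], **cfg)  # TypeError: got multiple values for keyword argument 'a'
load(a=cfg["a"], **{k: v for k, v in cfg.items() if k != "a"})  # OK
```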
@@ -366,24 +403,36 @@ def load_model_and_tokenizers():
 
     except FileNotFoundError:
         logging.error(f"Checkpoint file not found: {checkpoint_path}")
-        raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
+        raise gr.Error(
+            f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
+        )
     except Exception as e:
-        logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
+        logging.error(
+            f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
+        )
         if "size mismatch" in str(e):
-            error_detail =
+            error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
             logging.error(error_detail)
             raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
         elif "memory" in str(e).lower():
             logging.warning("Potential OOM error during model loading.")
-            gc.collect()
-
+            gc.collect()
+            torch.cuda.empty_cache() if device.type == "cuda" else None
+            raise gr.Error(
+                f"Model Error: OOM loading model. Check Space resources. Error: {e}"
+            )
         else:
-            raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}")
+            raise gr.Error(
+                f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
+            )
 
-    except gr.Error:
+    except gr.Error:
+        raise
     except Exception as e:
         logging.error(f"Unexpected error during loading: {e}", exc_info=True)
-        raise gr.Error(
+        raise gr.Error(
+            f"Initialization Error: Unexpected error. Check logs. Error: {e}"
+        )
 
 
 # --- Inference Function for Gradio (Simplified) ---
@@ -396,11 +445,11 @@ def predict_iupac(smiles_string):
     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
         error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
         logging.error(error_msg)
-        return f"Error: {error_msg}"
+        return f"Error: {error_msg}"  # Return single error string
 
     if not smiles_string or not smiles_string.strip():
         error_msg = "Error: Please enter a valid SMILES string."
-        return f"Error: {error_msg}"
+        return f"Error: {error_msg}"  # Return single error string
 
     smiles_input = smiles_string.strip()
+    try:
+        smiles_input = CanonSmiles(smiles_input)  # canonicalize the input SMILES
+    except Exception:
+        logging.warning(f"Could not canonicalize SMILES, using raw input: {smiles_input}")
 
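The canonicalization above is guarded because `CanonSmiles` (RDKit's parse-and-reserialize canonicalizer) fails on unparseable input rather than returning the string unchanged. A quick illustration:

```python
from rdkit.Chem import CanonSmiles

print(CanonSmiles("OCC"))  # -> "CCO": same molecule, canonical spelling
```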
@@ -411,7 +460,7 @@ def predict_iupac(smiles_string):
         pad_idx = config["pad_token_id"]
         gen_max_len = config["max_len"]
 
-        predicted_name = translate(
+        predicted_name = translate(  # Returns a single string now
             model=model,
             src_sentence=smiles_input,
             smiles_tokenizer=smiles_tokenizer,
|
|
425 |
logging.info(f"Prediction returned: {predicted_name}")
|
426 |
|
427 |
# --- Format Output ---
|
428 |
-
if "[Error]" in predicted_name:
|
429 |
-
|
|
|
|
|
430 |
elif not predicted_name:
|
431 |
-
|
432 |
else:
|
433 |
output_text = (
|
434 |
f"Input SMILES: {smiles_input}\n\n"
|
@@ -439,12 +490,12 @@ def predict_iupac(smiles_string):
 
     except RuntimeError as e:
         logging.error(f"Runtime error during translation: {e}", exc_info=True)
-        return f"Error: {error_msg}"
+        error_msg = f"Runtime Error during translation: {e}"
+        return f"Error: {error_msg}"  # Return single error string
 
     except Exception as e:
         logging.error(f"Unexpected error during translation: {e}", exc_info=True)
         error_msg = f"Unexpected Error during translation: {e}"
-        return f"Error: {error_msg}"
+        return f"Error: {error_msg}"  # Return single error string
 
 
 # --- Load Model on App Start ---
@@ -452,7 +503,7 @@ try:
     load_model_and_tokenizers()
 except gr.Error as ge:
     logging.error(f"Gradio Initialization Error: {ge}")
-    pass
+    pass  # Allow Gradio to potentially start with an error message
 except Exception as e:
     logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
     # Optionally raise gr.Error here too
@@ -463,7 +514,7 @@ title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
 description = f"""
 Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
 Translation uses **greedy decoding** (picks the most likely next word at each step).
-**Note:** Model loaded on **{str(device).upper() if device else
+**Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
 """
 
 # Define examples - remove beam search parameters
@@ -481,16 +532,19 @@ smiles_input = gr.Textbox(
     placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
     lines=1,
 )
-
+from rdkit.Chem import CanonSmiles  # used inside predict_iupac to canonicalize input
 # Output component
 output_text = gr.Textbox(
-    label="Predicted IUPAC Name",
+    label="Predicted IUPAC Name",
+    lines=3,  # Reduced lines slightly
+    show_copy_button=True,
 )
 
 # Create the interface instance
 iface = gr.Interface(
     fn=predict_iupac,  # The function to call
-    inputs=smiles_input,
+    inputs=smiles_input,  # Single input component
     outputs=output_text,  # Output component
     title=title,
     description=description,
|
|
505 |
|
506 |
# --- Launch the App ---
|
507 |
if __name__ == "__main__":
|
508 |
-
iface.launch()
|
|
|