Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,6 @@
 # app.py
 import gradio as gr
 import torch
 # import torch.nn.functional as F # No longer needed for greedy decode directly
 import pytorch_lightning as pl
 import os
@@ -10,8 +9,13 @@ import logging
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 import gc
+try:
+    from rdkit import Chem
+    from rdkit import RDLogger  # Optional: To suppress RDKit logs
+    RDLogger.DisableLog('rdApp.*')  # Suppress RDKit warnings/errors if desired
+except ImportError:
+    logging.warning("RDKit not found. SMILES canonicalization will be skipped. Install with 'pip install rdkit'")
+    Chem = None  # Set Chem to None if RDKit is not available

 # --- Configuration ---
 MODEL_REPO_ID = (
@@ -30,24 +34,24 @@ logging.basicConfig(

 # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
 try:
+    # Ensure enhanced_trainer.py is in the root directory of your space repo
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
     logging.info("Successfully imported from enhanced_trainer.py.")
 except ImportError as e:
     logging.error(
         f"Failed to import helper code from enhanced_trainer.py: {e}. "
         f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
     )
+    # We raise gr.Error later during loading if the class isn't found
+    SmilesIupacLitModule = None
+    generate_square_subsequent_mask = None
 except Exception as e:
     logging.error(
         f"An unexpected error occurred during helper code import: {e}", exc_info=True
     )
+    SmilesIupacLitModule = None
+    generate_square_subsequent_mask = None
+

 # --- Global Variables (Load Model Once) ---
 model: pl.LightningModule | None = None
@@ -69,87 +73,107 @@ def greedy_decode(
 ) -> torch.Tensor:
     """
     Performs greedy decoding using the LightningModule's model.
+    Assumes model has 'model.encode', 'model.decode', 'model.generator' attributes.
     """
+    if not hasattr(model, 'model') or not hasattr(model.model, 'encode') or \
+       not hasattr(model.model, 'decode') or not hasattr(model.model, 'generator'):
+        logging.error("Model object does not have the expected 'model.encode/decode/generator' structure.")
+        raise AttributeError("Model structure mismatch for greedy decoding.")
+
+    if generate_square_subsequent_mask is None:
+        logging.error("generate_square_subsequent_mask function not imported.")
+        raise ImportError("generate_square_subsequent_mask is required for greedy_decode.")
+
     model.eval()  # Ensure model is in evaluation mode
     transformer_model = model.model  # Access the underlying Seq2SeqTransformer

     try:
         with torch.no_grad():
             # --- Encode Source ---
+            # The mask should be True where the input *is* padding.
             memory = transformer_model.encode(
+                src, src_mask=None  # Standard transformer encoder doesn't usually use src_mask
+            )  # [1, src_len, emb_size] if batch_first=True in TransformerEncoder
+            # If batch_first=False (default): [src_len, 1, emb_size] -> adjust usage below
+            # Assuming batch_first=False for standard nn.Transformer
             memory = memory.to(device)
+
+            # Memory key padding mask needs to be [batch_size, src_len] -> [1, src_len]
+            memory_key_padding_mask = src_padding_mask.to(device)  # [1, src_len]

             # --- Initialize Target Sequence ---
+            # Start with the SOS token -> Shape [1, 1] (batch, seq)
             ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                 sos_idx
+            )

             # --- Decoding Loop ---
+            for i in range(max_len - 1):  # Max length limit
+                # Target processing depends on whether decoder expects batch_first
+                # Standard nn.TransformerDecoder expects [tgt_len, batch_size, emb_size]
+                # Standard nn.TransformerDecoder expects tgt_mask [tgt_len, tgt_len]
+                # Standard nn.TransformerDecoder expects memory_key_padding_mask [batch_size, src_len]
+                # Standard nn.TransformerDecoder expects tgt_key_padding_mask [batch_size, tgt_len]
+
                 tgt_seq_len = ys.shape[1]
+                # Create causal mask -> [tgt_len, tgt_len]
                 tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                     device
+                )
+
+                # Target padding mask: False for non-pad tokens -> [1, tgt_len]
                 tgt_padding_mask = torch.zeros(
                     ys.shape, dtype=torch.bool, device=device
+                )
+
+                # Prepare target for decoder (assuming batch_first=False expected)
+                # Input ys is [1, current_len] -> need [current_len, 1]
+                ys_decoder_input = ys.transpose(0, 1).to(device)  # [current_len, 1]

                 # Decode one step
                 decoder_output = transformer_model.decode(
+                    tgt=ys_decoder_input,  # [current_len, 1]
+                    memory=memory,  # [src_len, 1, emb_size]
+                    tgt_mask=tgt_mask,  # [current_len, current_len]
+                    tgt_key_padding_mask=tgt_padding_mask,  # [1, current_len]
+                    memory_key_padding_mask=memory_key_padding_mask,  # [1, src_len]
+                )  # Output shape [current_len, 1, emb_size]

                 # Get logits for the *next* token prediction
+                # Use output corresponding to the last input token -> [-1, :, :]
                 next_token_logits = transformer_model.generator(
+                    decoder_output[-1, :, :]  # Shape [1, emb_size]
+                )  # Output shape [1, tgt_vocab_size]

                 # Find the most likely next token (greedy choice)
+                next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # Shape [1]
                 next_word_id = next_word_id_tensor.item()

+                # Append the chosen token to the sequence (shape [1, 1])
+                next_word_tensor = torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id)
+
+                # Concatenate along the sequence dimension (dim=1)
+                ys = torch.cat([ys, next_word_tensor], dim=1)  # [1, current_len + 1]

                 # Stop if EOS token is generated
                 if next_word_id == eos_idx:
                     break

             # Return the generated sequence (excluding the initial SOS token)
+            # Shape [1, generated_len]
+            return ys[:, 1:]

     except RuntimeError as e:
         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
         if "CUDA out of memory" in str(e) and device.type == "cuda":
             gc.collect()
             torch.cuda.empty_cache()
+            raise RuntimeError("CUDA out of memory during greedy decoding.")  # Re-raise specific error
+        raise e  # Re-raise other runtime errors
     except Exception as e:
         logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
+        raise e  # Re-raise


 # --- Translation Function (Using Greedy Decode) ---
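The decoding loop above relies on `generate_square_subsequent_mask` from `enhanced_trainer.py`, which is not part of this diff. A minimal sketch of the standard PyTorch causal-mask pattern it is assumed to follow (the authoritative definition lives in `enhanced_trainer.py`):

import torch

def generate_square_subsequent_mask(size: int, device: torch.device) -> torch.Tensor:
    # Positions above the diagonal are filled with -inf so token i cannot
    # attend to tokens > i; positions on/below the diagonal stay 0.0.
    upper = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
    return torch.zeros(size, size, device=device).masked_fill(upper, float("-inf"))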
@@ -159,94 +183,108 @@ def translate(
     smiles_tokenizer: Tokenizer,
     iupac_tokenizer: Tokenizer,
     device: torch.device,
+    max_len_config: int,  # Max length from config (used for source truncation & generation limit)
     sos_idx: int,
     eos_idx: int,
     pad_idx: int,
+) -> str:  # Returns a single string or an error message
     """
     Translates a single SMILES string using greedy decoding.
     """
+    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
+        return "[Initialization Error: Components not loaded]"
+
     model.eval()  # Ensure model is in eval mode

     # --- Tokenize Source ---
     try:
+        # Ensure tokenizer has truncation configured
+        smiles_tokenizer.enable_truncation(max_length=max_len_config)
+        smiles_tokenizer.enable_padding(pad_id=pad_idx, pad_token="<pad>", length=max_len_config)  # Ensure padding for consistent input length if needed by model
+
         src_encoded = smiles_tokenizer.encode(src_sentence)
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
+            return "[Encoding Error: Empty result]"
         src_ids = src_encoded.ids
+        # Use attention mask directly for padding mask (1 for real tokens, 0 for padding)
+        # We need the opposite for PyTorch Transformer (True for padding, False for real)
+        src_attention_mask = torch.tensor(src_encoded.attention_mask, dtype=torch.long)
+        src_padding_mask = (src_attention_mask == 0)  # True where it's padded
+
     except Exception as e:
         logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
+        return f"[Encoding Error: {e}]"

     # --- Prepare Input Tensor and Mask ---
+    # Input tensor shape [1, src_len]
+    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
+    # Padding mask shape [1, src_len]
+    src_padding_mask = src_padding_mask.unsqueeze(0).to(device)

     # --- Perform Greedy Decoding ---
+    try:
+        tgt_tokens_tensor = greedy_decode(
+            model=model,
+            src=src,
+            src_padding_mask=src_padding_mask,
+            max_len=max_len_config,  # Use config max_len as generation limit
+            sos_idx=sos_idx,
+            eos_idx=eos_idx,
+            device=device,
+        )  # Returns a single tensor [1, generated_len]
+
+    except (RuntimeError, AttributeError, ImportError, Exception) as e:
+        logging.error(f"Error during greedy_decode call: {e}", exc_info=True)
+        return f"[Decoding Error: {e}]"
+

     # --- Decode Generated Tokens ---
     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
+        # Check if the source itself was just padding or EOS
+        if len(src_ids) <= 2 and all(t in [pad_idx, eos_idx, sos_idx] for t in src_ids):  # Rough check
+            logging.warning(f"Input SMILES '{src_sentence}' resulted in very short/empty encoding, leading to empty decode.")
+            return "[Decoding Warning: Input potentially too short or invalid after tokenization]"
+        else:
+            logging.warning(
+                f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
+            )
+            return "[Decoding Error: Empty Output]"

     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
     try:
         # Decode using the target tokenizer, skipping special tokens
         translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
+        return translation.strip()  # Strip leading/trailing whitespace
     except Exception as e:
         logging.error(
             f"Error decoding target tokens {tgt_tokens}: {e}",
             exc_info=True,
         )
+        return "[Decoding Error: Tokenizer failed]"
+
# --- Model/Tokenizer Loading Function ---
|
268 |
def load_model_and_tokenizers():
|
269 |
"""Loads tokenizers, config, and model from Hugging Face Hub."""
|
270 |
+
global model, smiles_tokenizer, iupac_tokenizer, device, config, SmilesIupacLitModule, generate_square_subsequent_mask
|
271 |
if model is not None: # Already loaded
|
272 |
logging.info("Model and tokenizers already loaded.")
|
273 |
return
|
274 |
|
275 |
logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
|
276 |
+
|
277 |
+
# --- Check if helper code loaded ---
|
278 |
+
if SmilesIupacLitModule is None or generate_square_subsequent_mask is None:
|
279 |
+
error_msg = f"Initialization Error: Could not load required components from enhanced_trainer.py. Check Space logs and ensure the file exists in the repo root."
|
280 |
+
logging.error(error_msg)
|
281 |
+
raise gr.Error(error_msg)
|
282 |
+
|
283 |
try:
|
284 |
# Determine device
|
285 |
if torch.cuda.is_available():
|
286 |
+
device = torch.device("cuda")
|
287 |
+
logging.info("CUDA available, using GPU.")
|
|
|
|
|
|
|
|
|
288 |
else:
|
289 |
device = torch.device("cpu")
|
290 |
logging.info("CUDA not available, using CPU.")
|
|
|
292 |
# Download files
|
293 |
logging.info("Downloading files from Hugging Face Hub...")
|
294 |
try:
|
295 |
+
# Define cache directory, default to './hf_cache' if GRADIO_CACHE is not set
|
296 |
cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
|
297 |
+
os.makedirs(cache_dir, exist_ok=True) # Ensure cache dir exists
|
298 |
logging.info(f"Using cache directory: {cache_dir}")
|
299 |
|
300 |
+
# Download files to the specified cache directory
|
301 |
checkpoint_path = hf_hub_download(
|
302 |
+
repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir, force_download=False # Avoid re-download if files exist
|
303 |
)
|
304 |
smiles_tokenizer_path = hf_hub_download(
|
305 |
+
repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
|
|
|
|
|
306 |
)
|
307 |
iupac_tokenizer_path = hf_hub_download(
|
308 |
+
repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir, force_download=False
|
|
|
|
|
309 |
)
|
310 |
config_path = hf_hub_download(
|
311 |
+
repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir, force_download=False
|
312 |
)
|
313 |
+
logging.info("Files downloaded (or found in cache) successfully.")
|
314 |
except Exception as e:
|
315 |
logging.error(
|
316 |
f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
|
|
|
328 |
logging.info("Configuration loaded.")
|
329 |
# --- Validate essential config keys ---
|
330 |
required_keys = [
|
331 |
+
"src_vocab_size",
|
332 |
+
"tgt_vocab_size",
|
333 |
"emb_size",
|
334 |
"nhead",
|
335 |
"ffn_hid_dim",
|
336 |
"num_encoder_layers",
|
337 |
"num_decoder_layers",
|
338 |
"dropout",
|
339 |
+
"max_len", # Crucial for tokenization and generation limit
|
340 |
"pad_token_id",
|
341 |
"bos_token_id",
|
342 |
"eos_token_id",
|
343 |
]
|
344 |
+
|
345 |
+
# Check for alternative key names if needed (adjust if your config uses different names)
|
346 |
+
config['src_vocab_size'] = config.get('src_vocab_size', config.get('SRC_VOCAB_SIZE'))
|
347 |
+
config['tgt_vocab_size'] = config.get('tgt_vocab_size', config.get('TGT_VOCAB_SIZE'))
|
348 |
+
config['emb_size'] = config.get('emb_size', config.get('EMB_SIZE'))
|
349 |
+
config['nhead'] = config.get('nhead', config.get('NHEAD'))
|
350 |
+
config['ffn_hid_dim'] = config.get('ffn_hid_dim', config.get('FFN_HID_DIM'))
|
351 |
+
config['num_encoder_layers'] = config.get('num_encoder_layers', config.get('NUM_ENCODER_LAYERS'))
|
352 |
+
config['num_decoder_layers'] = config.get('num_decoder_layers', config.get('NUM_DECODER_LAYERS'))
|
353 |
+
config['dropout'] = config.get('dropout', config.get('DROPOUT'))
|
354 |
+
config['max_len'] = config.get('max_len', config.get('MAX_LEN'))
|
355 |
+
config['pad_token_id'] = config.get('pad_token_id', config.get('PAD_IDX'))
|
356 |
+
config['bos_token_id'] = config.get('bos_token_id', config.get('BOS_IDX'))
|
357 |
+
config['eos_token_id'] = config.get('eos_token_id', config.get('EOS_IDX'))
|
358 |
+
# Add UNK if needed by your model/tokenizer setup
|
359 |
+
# config['unk_token_id'] = config.get('unk_token_id', config.get('UNK_IDX', 0)) # Default to 0 if missing? Risky.
|
360 |
|
361 |
missing_keys = [key for key in required_keys if config.get(key) is None]
|
362 |
if missing_keys:
|
|
|
366 |
)
|
367 |
|
368 |
logging.info(
|
369 |
+
f"Using config: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
|
370 |
f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
|
371 |
f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
|
372 |
)
|
|
|
391 |
try:
|
392 |
smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
|
393 |
iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
|
394 |
+
|
395 |
+
# --- Validate Tokenizer Special Tokens Against Config ---
|
|
|
396 |
pad_token = "<pad>"
|
397 |
sos_token = "<sos>"
|
398 |
eos_token = "<eos>"
|
399 |
+
unk_token = "<unk>" # Assuming standard UNK token
|
400 |
issues = []
|
401 |
+
|
402 |
+
# SMILES Tokenizer Checks
|
403 |
+
smiles_pad_id = smiles_tokenizer.token_to_id(pad_token)
|
404 |
+
smiles_unk_id = smiles_tokenizer.token_to_id(unk_token)
|
405 |
+
if smiles_pad_id is None or smiles_pad_id != config["pad_token_id"]:
|
406 |
+
issues.append(f"SMILES PAD ID mismatch (Tokenizer: {smiles_pad_id}, Config: {config['pad_token_id']})")
|
407 |
+
if smiles_unk_id is None:
|
408 |
+
issues.append("SMILES UNK token not found in tokenizer")
|
409 |
+
|
410 |
+
# IUPAC Tokenizer Checks
|
411 |
+
iupac_pad_id = iupac_tokenizer.token_to_id(pad_token)
|
412 |
+
iupac_sos_id = iupac_tokenizer.token_to_id(sos_token)
|
413 |
+
iupac_eos_id = iupac_tokenizer.token_to_id(eos_token)
|
414 |
+
iupac_unk_id = iupac_tokenizer.token_to_id(unk_token)
|
415 |
+
if iupac_pad_id is None or iupac_pad_id != config["pad_token_id"]:
|
416 |
+
issues.append(f"IUPAC PAD ID mismatch (Tokenizer: {iupac_pad_id}, Config: {config['pad_token_id']})")
|
417 |
+
if iupac_sos_id is None or iupac_sos_id != config["bos_token_id"]:
|
418 |
+
issues.append(f"IUPAC SOS ID mismatch (Tokenizer: {iupac_sos_id}, Config: {config['bos_token_id']})")
|
419 |
+
if iupac_eos_id is None or iupac_eos_id != config["eos_token_id"]:
|
420 |
+
issues.append(f"IUPAC EOS ID mismatch (Tokenizer: {iupac_eos_id}, Config: {config['eos_token_id']})")
|
421 |
+
if iupac_unk_id is None:
|
422 |
+
issues.append("IUPAC UNK token not found in tokenizer")
|
423 |
+
|
424 |
if issues:
|
425 |
+
logging.warning("Tokenizer validation issues detected: \n - " + "\n - ".join(issues))
|
426 |
+
# Decide if this is critical. For inference, SOS/EOS/PAD matches are most important.
|
427 |
+
# raise gr.Error("Tokenizer Validation Error: Mismatch between config and tokenizer files. Check logs.")
|
428 |
+
else:
|
429 |
+
logging.info("Tokenizers loaded and special tokens validated against config.")
|
430 |
|
431 |
except Exception as e:
|
432 |
logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
|
433 |
raise gr.Error(
|
434 |
+
f"Tokenizer Error: Could not load tokenizers. Check logs and file paths. Error: {e}"
|
435 |
)
|
436 |
|
437 |
# Load model
|
438 |
logging.info("Loading model from checkpoint...")
|
439 |
try:
|
440 |
+
# Instantiate the LightningModule using hyperparameters from config
|
441 |
+
# Make sure SmilesIupacLitModule's __init__ accepts these keys
|
442 |
+
model_instance = SmilesIupacLitModule(**config)
|
443 |
+
|
444 |
+
# Load the state dict from the checkpoint onto the instance
|
445 |
+
# Use load_state_dict for more control if load_from_checkpoint causes issues
|
446 |
+
# state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']
|
447 |
+
# model_instance.load_state_dict(state_dict, strict=True) # Try strict=False if needed
|
448 |
+
|
449 |
+
# Use load_from_checkpoint (simpler if it works)
|
450 |
model = SmilesIupacLitModule.load_from_checkpoint(
|
451 |
checkpoint_path,
|
|
|
452 |
map_location=device,
|
453 |
+
# Pass hparams again ONLY if they are needed by load_from_checkpoint specifically
|
454 |
+
# and not just by __init__. Usually, instantiating first is cleaner.
|
455 |
+
**config, # Try removing this if you instantiate first
|
456 |
+
strict=True # Start strict, set to False ONLY if necessary and you understand why
|
457 |
)
|
458 |
|
459 |
+
|
460 |
model.to(device)
|
461 |
model.eval()
|
462 |
+
model.freeze() # Freeze weights for inference
|
463 |
logging.info(
|
464 |
f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
|
465 |
)
|
|
|
467 |
except FileNotFoundError:
|
468 |
logging.error(f"Checkpoint file not found: {checkpoint_path}")
|
469 |
raise gr.Error(
|
470 |
+
f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found at expected path."
|
471 |
)
|
472 |
+
except RuntimeError as e: # Catch specific runtime errors like size mismatch, OOM
|
473 |
+
logging.error(f"Runtime error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
|
474 |
+
if "size mismatch" in str(e):
|
475 |
+
error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint structure. Or config doesn't match model definition."
|
476 |
+
logging.error(error_detail)
|
477 |
+
raise gr.Error(f"Model Load Error: {error_detail} Original error: {e}")
|
478 |
+
elif "CUDA out of memory" in str(e) or "memory" in str(e).lower():
|
479 |
+
logging.warning("Potential OOM error during model loading.")
|
480 |
+
gc.collect()
|
481 |
+
if device.type == "cuda": torch.cuda.empty_cache()
|
482 |
+
raise gr.Error(f"Model Load Error: OOM loading model. Check Space resources. Error: {e}")
|
483 |
+
else:
|
484 |
+
raise gr.Error(f"Model Load Error: Runtime error. Check logs. Error: {e}")
|
485 |
+
except Exception as e: # Catch other potential errors during loading
|
486 |
logging.error(
|
487 |
+
f"Unexpected error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
|
488 |
+
)
|
489 |
+
raise gr.Error(
|
490 |
+
f"Model Load Error: Failed to load checkpoint for unknown reason. Check logs. Error: {e}"
|
491 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
492 |
|
493 |
+
except gr.Error as ge: # Catch Gradio-specific errors raised above
|
494 |
+
raise ge # Re-raise to stop app launch correctly
|
495 |
+
except Exception as e: # Catch any other unexpected errors during the whole process
|
496 |
+
logging.error(f"Unexpected error during loading process: {e}", exc_info=True)
|
497 |
raise gr.Error(
|
498 |
+
f"Initialization Error: Unexpected failure during setup. Check logs. Error: {e}"
|
499 |
)
|
500 |
|
501 |
|
+# --- Inference Function for Gradio ---
 def predict_iupac(smiles_string):
     """
     Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
+    Handles input validation, canonicalization, translation, and output formatting.
     """
     global model, smiles_tokenizer, iupac_tokenizer, device, config

+    # --- Check Initialization ---
     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
+        error_msg = "Error: Model or tokenizers not loaded. App initialization failed. Check Space logs."
         logging.error(error_msg)
+        # Return the error directly in the output box
+        return error_msg

+    # --- Input Validation ---
     if not smiles_string or not smiles_string.strip():
+        return "Error: Please enter a valid SMILES string."

+    smiles_input_raw = smiles_string.strip()

+    # --- Canonicalize SMILES (Optional but Recommended) ---
+    smiles_input_canon = smiles_input_raw
+    if Chem:
+        try:
+            mol = Chem.MolFromSmiles(smiles_input_raw)
+            if mol:
+                smiles_input_canon = Chem.MolToSmiles(mol, canonical=True)
+                logging.info(f"Canonicalized SMILES: {smiles_input_raw} -> {smiles_input_canon}")
+            else:
+                # RDKit couldn't parse it, proceed with raw input but warn
+                logging.warning(f"Could not parse SMILES '{smiles_input_raw}' with RDKit. Using raw input.")
+                # Optionally return an error here if strict parsing is needed
+                # return f"Error: Invalid SMILES string '{smiles_input_raw}' according to RDKit."
+        except Exception as e:
+            logging.error(f"Error during RDKit canonicalization for '{smiles_input_raw}': {e}", exc_info=True)
+            # Proceed with raw input, maybe add note to output
+            # return f"Error: RDKit processing failed: {e}"  # Option to fail hard
+
+    # --- Translation ---
     try:
         sos_idx = config["bos_token_id"]
         eos_idx = config["eos_token_id"]
         pad_idx = config["pad_token_id"]
+        gen_max_len = config["max_len"]  # Use max_len from config

+        predicted_name = translate(
             model=model,
+            src_sentence=smiles_input_canon,  # Use canonicalized SMILES
             smiles_tokenizer=smiles_tokenizer,
             iupac_tokenizer=iupac_tokenizer,
             device=device,
+            max_len_config=gen_max_len,
             sos_idx=sos_idx,
             eos_idx=eos_idx,
             pad_idx=pad_idx,
         )
+        logging.info(f"SMILES: '{smiles_input_canon}', Prediction: '{predicted_name}'")

         # --- Format Output ---
+        # Check if translate returned an error message
+        if predicted_name.startswith("[") and predicted_name.endswith("]"):
+            # Assume it's an error/warning message from translate()
             output_text = (
+                f"Input SMILES: {smiles_input_canon}\n"
+                f"(Raw Input: {smiles_input_raw})\n\n"  # Show raw if canonicalization happened
+                f"Prediction Failed: {predicted_name}"
+            )
+        elif not predicted_name:  # Handle empty string case
+            output_text = (
+                f"Input SMILES: {smiles_input_canon}\n"
+                f"(Raw Input: {smiles_input_raw})\n\n"
+                f"Prediction: [No name generated]"
             )
         else:
             output_text = (
+                f"Input SMILES: {smiles_input_canon}\n"
+                f"(Raw Input: {smiles_input_raw})\n\n"
                 f"Predicted IUPAC Name (Greedy Decode):\n"
                 f"{predicted_name}"
             )
+        # Remove the "(Raw Input...)" line if canonicalization didn't change the input
+        if smiles_input_raw == smiles_input_canon:
+            output_text = output_text.replace(f"(Raw Input: {smiles_input_raw})\n", "")

+        return output_text.strip()

     except Exception as e:
+        # Catch-all for unexpected errors during the prediction process
+        logging.error(f"Unexpected error during prediction for '{smiles_input_canon}': {e}", exc_info=True)
+        error_msg = f"Error: An unexpected error occurred during translation: {e}"
+        return error_msg
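A note on the canonicalization step in `predict_iupac` above: RDKit maps equivalent SMILES spellings to one canonical string, so the model sees the same input however the user wrote the molecule. A small illustration (assumes RDKit is installed and its default canonicalization settings):

from rdkit import Chem

mol = Chem.MolFromSmiles("C1=CC=CC=C1")       # Kekulé benzene
print(Chem.MolToSmiles(mol, canonical=True))  # -> c1ccccc1 (canonical aromatic form)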
 # --- Load Model on App Start ---
+# Wrap in try/except to allow Gradio UI to potentially display an error
+model_load_error = None
 try:
     load_model_and_tokenizers()
 except gr.Error as ge:
+    logging.error(f"Gradio Initialization Error during load: {ge}")
+    model_load_error = str(ge)  # Store error message
 except Exception as e:
+    logging.critical(f"CRITICAL error during initial model loading: {e}", exc_info=True)
+    model_load_error = f"Critical Error: {e}. Check Space logs."


+# --- Create Gradio Interface ---
 title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
 description = f"""
 Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
 Translation uses **greedy decoding** (picks the most likely next word at each step).
 """
+if model_load_error:
+    description += f"\n\n**WARNING: Failed to load model or components.**\nReason: {model_load_error}\nFunctionality will be limited."
+elif device:
+    description += f"\n**Note:** Model loaded on **{str(device).upper()}**. Check `config.json` in the repo for model details."
+else:
+    description += f"\n**Note:** Device information unavailable (loading might have failed)."
+
+# Use gr.Blocks for layout
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
     gr.Markdown(f"# {title}")
     gr.Markdown(description)
+
     with gr.Row():
+        with gr.Column(scale=1):  # Input column takes less space
             smiles_input = gr.Textbox(
                 label="SMILES String",
+                placeholder="Enter SMILES string (e.g., CCO or c1ccccc1)",
+                lines=2,  # Slightly more lines for longer SMILES
             )
+            submit_btn = gr.Button("Translate", variant="primary")
+
+        with gr.Column(scale=2):  # Output column takes more space
             output_text = gr.Textbox(
+                label="Result",
+                lines=5,  # More lines for formatted output
                 show_copy_button=True,
+                interactive=False,  # Output box is not for user input
             )
+
+    # Define examples
+    gr.Examples(
+        examples=[
+            "CCO",
+            "C1=CC=C(C=C1)C(=O)O",  # Benzoic acid
+            "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",  # Ibuprofen
+            "INVALID_SMILES",
+            "ClC(Cl)(Cl)C1=CC=C(C=C1)C(C2=CC=C(Cl)C=C2)C(Cl)(Cl)Cl",  # DDT
+        ],
+        inputs=smiles_input,
+        outputs=output_text,
+        fn=predict_iupac,  # Function to run for examples
+        cache_examples=False,  # Re-run examples each time if needed, True might speed up demo loading
+    )
+
+    # Connect the button click and input change events
+    submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text, api_name="translate_smiles")
+    # Optionally trigger on text change (can be slow/resource intensive)
+    # smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)


 # --- Launch the App ---
 if __name__ == "__main__":
+    # Set share=True to get a public link (useful for testing)
+    # Set debug=True for more detailed Gradio errors during development
+    iface.launch(share=False, debug=False)
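Since the button click above is registered with `api_name="translate_smiles"`, the running Space can also be called programmatically. A minimal sketch using `gradio_client` (the Space id below is a placeholder, not the actual repo):

from gradio_client import Client

client = Client("user/smiles-to-iupac-space")  # hypothetical Space id
result = client.predict("CCO", api_name="/translate_smiles")
print(result)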