Upload folder using huggingface_hub
app.py
CHANGED
# app.py
import gradio as gr
import torch
import torch.nn.functional as F  # Needed for beam search log_softmax
import pytorch_lightning as pl  # Needed for LightningModule and loading
import os
import json
import logging
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import gc  # For garbage collection on potential OOM
import math  # Potentially needed by imported classes

# --- Configuration ---
# Ensure these match the files uploaded to your Hugging Face Hub repository
MODEL_REPO_ID = (
    "AdrianM0/smiles-to-iupac-translator"  # <-- Make sure this is your repo ID
)
CHECKPOINT_FILENAME = "last.ckpt"  # Or "best_model.ckpt" or whatever you uploaded
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
CONFIG_FILENAME = (
    "config.json"  # Assumes you saved hparams to config.json during/after training
)
# --- End Configuration ---

# --- Logging ---

# ...

    # We need the LightningModule definition and the mask function
    # Ensure enhanced_trainer.py is present in the root of your HF Repo
    from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask

    logging.info("Successfully imported from enhanced_trainer.py.")

    # REMOVED: Redundant import from test_ckpt as functions are defined below

# ...

device: torch.device | None = None
config: dict | None = None


# --- Beam Search Decoding Logic (Locally defined) ---
def beam_search_decode(
    model: pl.LightningModule,

# ...

    Performs beam search decoding using the LightningModule's model.
    (Ensures this code is self-contained within app.py or correctly imported)
    """
    model.eval()  # Ensure model is in evaluation mode
    transformer_model = model.model  # Access the underlying Seq2SeqTransformer
    n_best = min(n_best, beam_width)

    try:
        # ...

        # --- Encode Source ---
        memory = transformer_model.encode(
            src, src_padding_mask
        )  # [1, src_len, emb_size]
        memory = memory.to(device)
        memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]

        # --- Initialize Beams ---
        initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
            sos_idx
        )  # [1, 1]
        initial_beam_score = torch.zeros(1, dtype=torch.float, device=device)  # [1]
        active_beams = [(initial_beam_seq, initial_beam_score)]
        finished_beams = []

        # ...

            for current_seq, current_score in active_beams:
                # Check if the beam already ended
                if current_seq[0, -1].item() == eos_idx:
                    # If already finished, add directly to finished beams and skip expansion
                    finished_beams.append((current_seq, current_score))
                    continue

                # Prepare inputs for the decoder
                tgt_input = current_seq  # [1, current_len]
                tgt_seq_len = tgt_input.shape[1]
                tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                    device
                )  # [curr_len, curr_len]
                # No padding in target during generation yet
                tgt_padding_mask = torch.zeros(
                    tgt_input.shape, dtype=torch.bool, device=device
                )  # [1, curr_len]

                # Decode one step
                decoder_output = transformer_model.decode(
                    # ...
                    tgt_mask=tgt_mask,
                    tgt_padding_mask=tgt_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                )  # [1, curr_len, emb_size]

                # Get logits for the *next* token prediction
                next_token_logits = transformer_model.generator(
                    decoder_output[
                        :, -1, :
                    ]  # Use output corresponding to the last input token
                )  # [1, tgt_vocab_size]

                # Calculate log probabilities and add current beam score
                log_probs = F.log_softmax(
                    next_token_logits, dim=-1
                )  # [1, tgt_vocab_size]
                combined_scores = (
                    log_probs + current_score
                )  # Add score of the current path

                # Find top k candidates for the *next* step
                topk_log_probs, topk_indices = torch.topk(
                    combined_scores, beam_width, dim=-1
                )  # [1, beam_width], [1, beam_width]

                # Expand potential beams
                for i in range(beam_width):
                    next_token_id = topk_indices[0, i].item()
                    # Score is the cumulative log probability of the new sequence
                    next_score = topk_log_probs[0, i].reshape(
                        1
                    )  # Keep as tensor [1]
                    next_token_tensor = torch.tensor(
                        [[next_token_id]], dtype=torch.long, device=device
                    )  # [1, 1]
                    new_seq = torch.cat(
                        [current_seq, next_token_tensor], dim=1
                    )  # [1, current_len + 1]
                    potential_next_beams.append((new_seq, next_score))

            # --- Prune Beams ---

            # ...

            # Select the top `beam_width` beams for the next iteration
            active_beams = []
            temp_finished_beams = []  # Collect beams finished in *this* step
            for seq, score in potential_next_beams:
                if (
                    len(active_beams) >= beam_width
                    and len(temp_finished_beams) >= beam_width
                ):
                    break  # Optimization: Stop if we have enough active and finished candidates

                is_finished = seq[0, -1].item() == eos_idx
                if is_finished:
                    # Add to temporary finished list for this step
                    if len(temp_finished_beams) < beam_width:
                        temp_finished_beams.append((seq, score))
                elif len(active_beams) < beam_width:
                    # Add to active beams for next step
                    active_beams.append((seq, score))

            # Add the newly finished beams to the main finished list
            finished_beams.extend(temp_finished_beams)
            # Optional: Prune finished_beams if it grows too large (e.g., keep top 2*beam_width)
            finished_beams.sort(key=lambda x: x[1].item(), reverse=True)
            finished_beams = finished_beams[
                : beam_width * 2
            ]  # Keep a reasonable number

        # --- Final Selection ---
        # Add any remaining active beams (which didn't finish) to the finished list

        # ...

                return score.item()
            else:
                # Length penalty calculation
                penalty = (
                    (5.0 + float(seq_len)) / 6.0
                ) ** length_penalty  # Common formula
                return score.item() / penalty
            # Alternative simpler penalty:
            # return score.item() / (float(seq_len) ** length_penalty)

        finished_beams.sort(
            key=get_score_with_penalty, reverse=True
        )  # Higher score is better

        # Return the top n_best sequences (excluding the initial SOS token)
        top_sequences = [
            seq[:, 1:]
            for seq, score in finished_beams[:n_best]
            if seq.shape[1] > 1  # Ensure seq not just SOS
        ]  # seq shape [1, len] -> [1, len-1]
        return top_sequences

    except RuntimeError as e:
        logging.error(f"Runtime error during beam search decode: {e}", exc_info=True)
        if "CUDA out of memory" in str(e) and device.type == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        return []  # Return empty list on error
    except Exception as e:
        logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
        return []


# --- Translation Function (Locally defined) ---
def translate(
    model: pl.LightningModule,

# ...

    Translates a single SMILES string using beam search.
    (Ensures this code is self-contained within app.py or correctly imported)
    """
    model.eval()  # Ensure model is in eval mode
    translations = []
    n_best = min(n_best, beam_width)  # Can't return more than beam width

    # --- Tokenize Source ---
    try:
        # ...

    # --- Prepare Input Tensor and Mask ---
    src = (
        torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    )  # [1, src_len]
    # Create padding mask (True where it's a pad token, should be all False here)
    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]

    # --- Perform Beam Search Decoding ---
    # Calls the beam_search_decode function defined *above in this file*
    # Note: max_len for generation should come from config if it dictates output length
    generation_max_len = config.get(
        "max_len", 256
    )  # Use config max_len for output limit
    tgt_tokens_list = beam_search_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=generation_max_len,  # Use generation limit
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        pad_idx=pad_idx,
        # ...
        beam_width=beam_width,
        n_best=n_best,
        length_penalty=length_penalty,
    )  # Returns list of tensors

    # --- Decode Generated Tokens ---
    if not tgt_tokens_list:

    # ...

                )
                translations.append(translation)
            except Exception as e:
                logging.error(
                    f"Error decoding target tokens {tgt_tokens} for beam {i}: {e}",
                    exc_info=True,
                )
                translations.append("[Decoding Error]")
        else:
            logging.warning(
                f"Beam {i} result was empty or None for SMILES: {src_sentence}"
            )
            translations.append("[Decoding Error - Empty Tensor]")

    # Pad with error messages if fewer than n_best results were generated

    # ...

    return translations


# --- Model/Tokenizer Loading Function ---
def load_model_and_tokenizers():
    """Loads tokenizers, config, and model from Hugging Face Hub."""
    global model, smiles_tokenizer, iupac_tokenizer, device, config
    if model is not None:  # Already loaded
        logging.info("Model and tokenizers already loaded.")
        return

    # ...

    # Determine device - Use CPU for Gradio Spaces unless GPU is explicitly available and desired
    # For simplicity and broader compatibility on free tier Spaces, CPU is safer.
    if torch.cuda.is_available():
        logging.warning(
            "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
        )
        device = torch.device("cpu")
        # Uncomment below and comment above line to try using GPU if available
        # device = torch.device("cuda")
        # logging.info("CUDA available, using GPU.")
    else:
        device = torch.device("cpu")
        logging.info("CUDA not available, using CPU.")

    # Download files from HF Hub
    logging.info("Downloading files from Hugging Face Hub...")
    try:
        # Use cache directory for Spaces persistence if possible
        cache_dir = os.environ.get(
            "GRADIO_CACHE", "./hf_cache"
        )  # Gradio sets cache dir
        os.makedirs(cache_dir, exist_ok=True)
        logging.info(f"Using cache directory: {cache_dir}")

        # ...
            repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
        )
        smiles_tokenizer_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=SMILES_TOKENIZER_FILENAME,
            cache_dir=cache_dir,
        )
        iupac_tokenizer_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=IUPAC_TOKENIZER_FILENAME,
            cache_dir=cache_dir,
        )
        config_path = hf_hub_download(
            repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir

        # ...

        # Mappings might be needed if keys in config.json differ from these exact names
        required_keys = [
            # Need vocab sizes used during *training* for loading
            "actual_src_vocab_size",  # Assuming this was saved in hparams
            "actual_tgt_vocab_size",  # Assuming this was saved in hparams
            # Model architecture params
            "emb_size",
            "nhead",
            # ...
            "num_encoder_layers",
            "num_decoder_layers",
            "dropout",
            "max_len",  # Needed for generation limit and tokenizer setting
            # Special token IDs needed for generation
            # Assuming standard names, adjust if your config uses different keys
            "pad_token_id",  # Often 0
            "bos_token_id",  # Often 1 (used as SOS)
            "eos_token_id",  # Often 2
        ]
        # Remap keys if necessary (e.g., if config.json uses 'src_vocab_size' instead of 'actual_src_vocab_size')
        config_key_mapping = {
            "actual_src_vocab_size": config.get(
                "actual_src_vocab_size", config.get("src_vocab_size")
            ),
            "actual_tgt_vocab_size": config.get(
                "actual_tgt_vocab_size", config.get("tgt_vocab_size")
            ),
            "emb_size": config.get("emb_size"),
            "nhead": config.get("nhead"),
            "ffn_hid_dim": config.get("ffn_hid_dim"),
            # ...
            "num_decoder_layers": config.get("num_decoder_layers"),
            "dropout": config.get("dropout"),
            "max_len": config.get("max_len"),
            "pad_token_id": config.get(
                "pad_token_id"
            ),  # Use default if missing? Risky.
            "bos_token_id": config.get(
                "bos_token_id"
            ),  # Use default if missing? Risky.
            "eos_token_id": config.get(
                "eos_token_id"
            ),  # Use default if missing? Risky.
        }
        # Update config with potentially remapped values
        config.update(config_key_mapping)

        # ...

        # Re-check missing keys after attempting defaults
        missing_keys = [key for key in required_keys if config.get(key) is None]
        if missing_keys:
            raise ValueError(
                f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}. "
                f"Ensure these were saved in the hyperparameters during training."
            )
        else:
            logging.warning(
                f"Config file was missing keys, used defaults for: {defaults_used}. This might be incorrect!"
            )

        # Log the final config values being used
        logging.info(
            f"Using config values: src_vocab={config['actual_src_vocab_size']}, tgt_vocab={config['actual_tgt_vocab_size']}, "
            f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
            f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
        )

    except FileNotFoundError:
        logging.error(
            f"Config file not found locally after download attempt: {config_path}"
        )
        raise gr.Error(
            f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo."
        )
    except json.JSONDecodeError as e:
        logging.error(f"Error decoding JSON from config file {config_path}: {e}")
        raise gr.Error(
            f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}"
        )
    except ValueError as e:  # Catch our custom validation error
        logging.error(f"Config validation error: {e}")
        raise gr.Error(f"Config Error: {e}")
    except Exception as e:  # Catch other potential errors during config processing
        logging.error(
            f"Unexpected error loading or validating config: {e}", exc_info=True
        )
        raise gr.Error(
            f"Config Error: Unexpected error processing config. Check logs. Error: {e}"
        )

    # Load tokenizers
    logging.info("Loading tokenizers...")

    # ...

        iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
        logging.info("Tokenizers loaded.")

        # --- Validate Tokenizer Special Tokens Against Config ---
        pad_token = "<pad>"
        sos_token = "<sos>"
        eos_token = "<eos>"
        # ...

        issues = []
        if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
            issues.append(
                f"SMILES PAD ID mismatch (tokenizer={smiles_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
            )
        if smiles_tokenizer.token_to_id(unk_token) is None:
            issues.append("SMILES UNK token not found")

        if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
            issues.append(
                f"IUPAC PAD ID mismatch (tokenizer={iupac_tokenizer.token_to_id(pad_token)}, config={config['pad_token_id']})"
            )
        if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
            issues.append(
                f"IUPAC SOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(sos_token)}, config={config['bos_token_id']})"
            )
        if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
            issues.append(
                f"IUPAC EOS ID mismatch (tokenizer={iupac_tokenizer.token_to_id(eos_token)}, config={config['eos_token_id']})"
            )
        if iupac_tokenizer.token_to_id(unk_token) is None:
            issues.append("IUPAC UNK token not found")

        if issues:
            logging.warning(
                "Tokenizer validation issues detected: " + "; ".join(issues)
            )
            # Decide if this is fatal or just a warning
            # raise gr.Error("Tokenizer Error: Special token IDs mismatch config. Check tokenizers and config.json.") # Make it fatal if IDs must match

    except Exception as e:
        logging.error(

    # ...

            # Ensure these keys exist in your loaded 'config' dict after validation/mapping
            src_vocab_size=config["actual_src_vocab_size"],
            tgt_vocab_size=config["actual_tgt_vocab_size"],
            hparams_dict=config,  # Pass the loaded config as hparams
            map_location=device,  # Map model to the chosen device (CPU or CUDA)
            strict=False,  # Be less strict about matching keys, useful for PTL versions or minor changes
            # REMOVED invalid argument: device="cpu",
        )

        # Ensure model is on the correct device, in eval mode, and frozen
        model.to(device)
        model.eval()
        model.freeze()  # Disables gradient calculations
        logging.info(
            f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
        )

    except FileNotFoundError:
        logging.error(
            f"Checkpoint file not found locally after download attempt: {checkpoint_path}"
        )
        raise gr.Error(
            f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
        )
    except Exception as e:
        logging.error(
            f"Error loading model from checkpoint {checkpoint_path}: {e}",
            exc_info=True,
        )
        # Check for common errors
        if "size mismatch" in str(e):
            error_detail = (
                f"Potential size mismatch. Check if vocab sizes in config.json ({config.get('actual_src_vocab_size')}, "
                f"{config.get('actual_tgt_vocab_size')}) match the loaded checkpoint's embedding layers."
            )
            logging.error(error_detail)
            raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
        elif "memory" in str(e).lower():
            logging.warning("Potential Out-of-Memory error during model loading.")
            gc.collect()
            if device.type == "cuda":
                torch.cuda.empty_cache()
            raise gr.Error(
                f"Model Error: Out of memory loading model. Check Space resources. Error: {e}"
            )
        else:
            raise gr.Error(
                f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}"
            )

    except gr.Error:  # Re-raise Gradio errors to be displayed
        raise
    except Exception as e:  # Catch any other unexpected errors
        logging.error(
            f"Unexpected error during model/tokenizer loading: {e}", exc_info=True
        )
        raise gr.Error(
            f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}"
        )


# --- Inference Function for Gradio ---

# ...

        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
        logging.error(error_msg)
        # Try to determine n_best for error output formatting
        try:
            n_best_int = int(n_best_str)
        except:
            n_best_int = 1
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])

    if not smiles_string or not smiles_string.strip():
        error_msg = "Error: Please enter a valid SMILES string."
        try:
            n_best_int = int(n_best_str)
        except:
            n_best_int = 1
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best_int)])

    smiles_input = smiles_string.strip()

    # ...

        n_best = int(n_best_str)
        length_penalty = float(length_penalty_str)
        if beam_width < 1 or n_best < 1 or n_best > beam_width:
            raise ValueError(
                "Beam width and n_best must be >= 1, and n_best <= beam width."
            )
        if length_penalty < 0:
            logging.warning(
                f"Length penalty {length_penalty} is negative, using 0.0 instead."
            )
            length_penalty = 0.0
    except ValueError as e:
        error_msg = f"Error: Invalid input parameter ({e}). Please check beam width, n_best, and length penalty values."
        logging.error(error_msg)

    # ...

    try:
        # --- Call the core translation logic ---
        # Retrieve necessary IDs from the loaded config
        sos_idx = config["bos_token_id"]
        eos_idx = config["eos_token_id"]
        pad_idx = config["pad_token_id"]
        gen_max_len = config["max_len"]  # Max length for generation

        predicted_names = translate(
            model=model,
            # ...
            smiles_tokenizer=smiles_tokenizer,
            iupac_tokenizer=iupac_tokenizer,
            device=device,
            max_len=gen_max_len,  # Pass generation length limit
            sos_idx=sos_idx,
            eos_idx=eos_idx,
            pad_idx=pad_idx,

    # ...

        else:
            # Ensure we only display up to n_best results, even if translate returned more/fewer due to errors
            display_names = predicted_names[:n_best]
            output_text = (
                f"Input SMILES: {smiles_input}\n\n"
                f"Top {len(display_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
            )
            output_text += "\n".join(
                [f"{i + 1}. {name}" for i, name in enumerate(display_names)]
            )
            # Add a note if fewer results than requested were generated
            if len(display_names) < n_best:
                output_text += f"\n\nNote: Only {len(display_names)} result(s) generated successfully."

        return output_text

    # ...

        error_msg = f"Runtime Error during translation: {e}"
        if "memory" in str(e).lower():
            gc.collect()
            if device.type == "cuda":
                torch.cuda.empty_cache()
            error_msg += " (Potential OOM - try reducing beam width or input length)"
        # Return n_best error messages
        return "\n".join([f"{i + 1}. {error_msg}" for i in range(n_best)])

# ...

try:
    load_model_and_tokenizers()
except gr.Error as ge:
    logging.error(f"Gradio Initialization Error: {ge}")
    # Gradio handles displaying gr.Error, but we log it too.
    # We might want to display a placeholder UI or message if loading fails critically.
    pass  # Allow Gradio to potentially start with an error message
except Exception as e:
    # Catch any non-Gradio errors during the initial load sequence
    logging.error(
        f"Critical error during initial model loading sequence: {e}", exc_info=True
    )
    # Optionally raise gr.Error here too, although it might be too late if Gradio hasn't fully initialized.
    # raise gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")

# ...

# Define examples using the input types expected by the interface
examples = [
    ["CCO", 5, 3, 0.6],  # Ethanol
    ["C1=CC=CC=C1", 5, 3, 0.6],  # Benzene
    ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6],  # Aspirin
    ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6],  # Ibuprofen
    # Very complex example - might take time or fail on CPU/low memory
    # ["CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C)NC4=NC=C(C(=N4)C5=CC=CC=C5)C", 8, 1, 0.7], # Gleevec (Imatinib) - simplified SMILES structure
    ["INVALID_SMILES", 3, 1, 0.6],  # Example of invalid input
]

# Ensure input components match the `predict_iupac` function signature order and types

# ...
)
# Use number inputs for sliders if direct type casting is desired, but sliders often return float/int anyway
beam_width_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=5,
    step=1,
    label="Beam Width (k)",
    info="Number of sequences kept at each step (higher = more exploration, slower). Affects memory usage.",
)
n_best_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Number of Results (n_best)",
    info="How many top sequences to return (must be <= Beam Width).",
)
length_penalty_input = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.6,
    step=0.1,
    label="Length Penalty (alpha)",
    info="Controls preference for sequence length. >1 favors longer, <1 favors shorter, 0 no penalty.",
)
output_text = gr.Textbox(
    label="Predicted IUPAC Name(s)", lines=5, show_copy_button=True

# ...

# Create the interface instance
iface = gr.Interface(
    fn=predict_iupac,  # The function to call
    inputs=[  # List of input components
        smiles_input,
        beam_width_input,
        n_best_input,
        length_penalty_input,
    ],
    outputs=output_text,  # Output component
    title=title,
    description=description,
    examples=examples,  # Examples to populate the interface
    allow_flagging="never",  # Disable flagging
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),  # Optional theme
    article="""
**Limitations:** Translation quality depends heavily on the model size, training data, and the complexity of the SMILES input.
Very long or unusual SMILES strings may result in errors, timeouts, or inaccurate translations.

# ...

    # Set share=False or remove for deployment on Spaces.
    # Use server_name="0.0.0.0" to make it accessible on the network if running locally
    # Use auth=("username", "password") for basic authentication
    iface.launch()  # share=True is deprecated, use launch()
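
The pruning step in `beam_search_decode` above keeps at most `beam_width` unfinished hypotheses per step and moves anything ending in EOS into `finished_beams`. A toy, torch-free sketch of that bookkeeping, with invented token IDs and scores; it assumes the candidate list is already sorted best-first, which the selection comments in the file imply but do not show:

```python
# Toy illustration of the per-step beam bookkeeping -- made-up sequences and scores.
EOS = 2
beam_width = 2

# (token id sequence, cumulative log-probability), assumed sorted best-first
potential_next_beams = [
    ([1, 5, 7, 2], -1.2),  # ends in EOS -> finished
    ([1, 5, 7, 9], -1.5),
    ([1, 5, 8, 2], -2.0),  # ends in EOS -> finished
    ([1, 5, 8, 4], -2.3),
    ([1, 6, 3, 4], -2.9),  # dropped: the active list is already full
]

active_beams, finished_beams = [], []
for seq, score in potential_next_beams:
    if seq[-1] == EOS:
        if len(finished_beams) < beam_width:
            finished_beams.append((seq, score))
    elif len(active_beams) < beam_width:
        active_beams.append((seq, score))

print(active_beams)    # [([1, 5, 7, 9], -1.5), ([1, 5, 8, 4], -2.3)]
print(finished_beams)  # [([1, 5, 7, 2], -1.2), ([1, 5, 8, 2], -2.0)]
```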
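
For reference, `get_score_with_penalty` above ranks finished beams by dividing each beam's cumulative log-probability by `((5.0 + length) / 6.0) ** length_penalty`, the same functional form as the GNMT length penalty. A minimal, torch-free sketch of that normalization; the guard for trivial cases is an assumption here, and the numbers are made up:

```python
# Illustrative sketch only -- mirrors the penalty formula used in app.py above.
def length_penalized_score(cum_log_prob: float, seq_len: int, alpha: float) -> float:
    """Normalize a cumulative log-probability with the ((5 + len) / 6) ** alpha penalty."""
    if alpha <= 0.0 or seq_len <= 1:  # assumed guard for trivial cases
        return cum_log_prob
    penalty = ((5.0 + float(seq_len)) / 6.0) ** alpha
    return cum_log_prob / penalty


# With alpha > 0 the divisor grows with length, so a longer hypothesis with a
# slightly worse raw score can still rank first (less negative is better).
print(length_penalized_score(-6.0, 8, 0.6))  # ~ -3.77
print(length_penalized_score(-5.5, 4, 0.6))  # ~ -4.31
```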
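
The config validation in `load_model_and_tokenizers` requires the keys listed in `required_keys`, with `src_vocab_size`/`tgt_vocab_size` accepted as fallbacks for the `actual_*` names. A hypothetical example of the expected shape of `config.json`; every value below is a placeholder for illustration, not the real hyperparameters of this checkpoint:

```python
import json

# Placeholder values only -- they illustrate which keys load_model_and_tokenizers()
# checks for, not the actual training configuration.
example_config = {
    "actual_src_vocab_size": 10000,  # or provide "src_vocab_size"
    "actual_tgt_vocab_size": 12000,  # or provide "tgt_vocab_size"
    "emb_size": 512,
    "nhead": 8,
    "ffn_hid_dim": 2048,
    "num_encoder_layers": 6,
    "num_decoder_layers": 6,
    "dropout": 0.1,
    "max_len": 256,
    "pad_token_id": 0,
    "bos_token_id": 1,
    "eos_token_id": 2,
}

print(json.dumps(example_config, indent=2))  # what a config.json of this shape looks like
```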