Update app.py
app.py
CHANGED
@@ -1,6 +1,7 @@
 # app.py
 import gradio as gr
 import torch
+
 # import torch.nn.functional as F # No longer needed for greedy decode directly
 import pytorch_lightning as pl
 import os
@@ -9,13 +10,8 @@ import logging
 from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 import gc
-try:
-    from rdkit import Chem
-    from rdkit import RDLogger  # Optional: To suppress RDKit logs
-    RDLogger.DisableLog('rdApp.*')  # Suppress RDKit warnings/errors if desired
-except ImportError:
-    logging.warning("RDKit not found. SMILES canonicalization will be skipped. Install with 'pip install rdkit'")
-    Chem = None  # Set Chem to None if RDKit is not available
+from rdkit.Chem import CanonSmiles
+
 
 # --- Configuration ---
 MODEL_REPO_ID = (
@@ -34,24 +30,24 @@ logging.basicConfig(
 
 # --- Load Helper Code (Only Model Definition and Mask Function Needed) ---
 try:
-    # Ensure enhanced_trainer.py is in the root directory of your space repo
     from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
+
     logging.info("Successfully imported from enhanced_trainer.py.")
 except ImportError as e:
     logging.error(
         f"Failed to import helper code from enhanced_trainer.py: {e}. "
         f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'."
     )
-
-
-
+    raise gr.Error(
+        f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}"
+    )
 except Exception as e:
     logging.error(
         f"An unexpected error occurred during helper code import: {e}", exc_info=True
     )
-
-
-
+    raise gr.Error(
+        f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}"
    )
 
 # --- Global Variables (Load Model Once) ---
 model: pl.LightningModule | None = None
@@ -73,107 +69,87 @@ def greedy_decode(
 ) -> torch.Tensor:
     """
     Performs greedy decoding using the LightningModule's model.
-    Assumes model has 'model.encode', 'model.decode', 'model.generator' attributes.
     """
-    if not hasattr(model, 'model') or not hasattr(model.model, 'encode') or \
-       not hasattr(model.model, 'decode') or not hasattr(model.model, 'generator'):
-        logging.error("Model object does not have the expected 'model.encode/decode/generator' structure.")
-        raise AttributeError("Model structure mismatch for greedy decoding.")
-
-    if generate_square_subsequent_mask is None:
-        logging.error("generate_square_subsequent_mask function not imported.")
-        raise ImportError("generate_square_subsequent_mask is required for greedy_decode.")
-
-
     model.eval()  # Ensure model is in evaluation mode
     transformer_model = model.model  # Access the underlying Seq2SeqTransformer
 
     try:
         with torch.no_grad():
             # --- Encode Source ---
-            # The mask should be True where the input *is* padding.
             memory = transformer_model.encode(
-                src,
-            )
-            # If batch_first=False (default): [src_len, 1, emb_size] -> adjust usage below
-            # Assuming batch_first=False for standard nn.Transformer
+                src, src_padding_mask
+            )  # [1, src_len, emb_size]
             memory = memory.to(device)
-
-            # Memory key padding mask needs to be [batch_size, src_len] -> [1, src_len]
-            memory_key_padding_mask = src_padding_mask.to(device)  # [1, src_len]
+            memory_key_padding_mask = src_padding_mask.to(memory.device)  # [1, src_len]
 
             # --- Initialize Target Sequence ---
-            # Start with the SOS token
+            # Start with the SOS token
             ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(
                 sos_idx
-            )
+            )  # [1, 1]
 
             # --- Decoding Loop ---
-            for
-            # Target processing depends on whether decoder expects batch_first
-            # Standard nn.TransformerDecoder expects [tgt_len, batch_size, emb_size]
-            # Standard nn.TransformerDecoder expects tgt_mask [tgt_len, tgt_len]
-            # Standard nn.TransformerDecoder expects memory_key_padding_mask [batch_size, src_len]
-            # Standard nn.TransformerDecoder expects tgt_key_padding_mask [batch_size, tgt_len]
-
+            for _ in range(max_len - 1):  # Max length limit
                 tgt_seq_len = ys.shape[1]
-                # Create causal mask -> [tgt_len, tgt_len]
                 tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                     device
-                )
-
-                # Target padding mask: False for non-pad tokens -> [1, tgt_len]
+                )  # [curr_len, curr_len]
+                # No padding in target during generation yet
                 tgt_padding_mask = torch.zeros(
                     ys.shape, dtype=torch.bool, device=device
-                )
-
-                # Prepare target for decoder (assuming batch_first=False expected)
-                # Input ys is [1, current_len] -> need [current_len, 1]
-                ys_decoder_input = ys.transpose(0, 1).to(device)  # [current_len, 1]
+                )  # [1, curr_len]
 
                 # Decode one step
                 decoder_output = transformer_model.decode(
-                    tgt=
-                    memory=memory,
-                    tgt_mask=tgt_mask,
-
-                    memory_key_padding_mask=memory_key_padding_mask,
-                )
+                    tgt=ys,
+                    memory=memory,
+                    tgt_mask=tgt_mask,
+                    tgt_padding_mask=tgt_padding_mask,
+                    memory_key_padding_mask=memory_key_padding_mask,
+                )  # [1, curr_len, emb_size]
 
                 # Get logits for the *next* token prediction
-                # Use output corresponding to the last input token -> [-1, :, :]
                 next_token_logits = transformer_model.generator(
-                    decoder_output[
-
+                    decoder_output[
+                        :, -1, :
+                    ]  # Use output corresponding to the last input token
+                )  # [1, tgt_vocab_size]
 
                 # Find the most likely next token (greedy choice)
-
+                # prob = F.log_softmax(next_token_logits, dim=-1) # Not needed for argmax
+                # _, next_word_id_tensor = torch.max(prob, dim=1)
+                next_word_id_tensor = torch.argmax(next_token_logits, dim=1)  # [1]
                 next_word_id = next_word_id_tensor.item()
 
-                # Append the chosen token to the sequence
-
-
-
-
+                # Append the chosen token to the sequence
+                ys = torch.cat(
+                    [
+                        ys,
+                        torch.ones(1, 1, dtype=torch.long, device=device).fill_(
+                            next_word_id
+                        ),
+                    ],
+                    dim=1,
+                )  # [1, current_len + 1]
 
                 # Stop if EOS token is generated
                 if next_word_id == eos_idx:
                     break
 
         # Return the generated sequence (excluding the initial SOS token)
-        # Shape [1, generated_len]
-        return ys[:, 1:]
+        return ys[:, 1:]  # Shape [1, generated_len]
 
     except RuntimeError as e:
         logging.error(f"Runtime error during greedy decode: {e}", exc_info=True)
         if "CUDA out of memory" in str(e) and device.type == "cuda":
             gc.collect()
             torch.cuda.empty_cache()
-
-
+        return torch.empty(
+            (1, 0), dtype=torch.long, device=device
+        )  # Return empty tensor on error
    except Exception as e:
        logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True)
-
+        return torch.empty((1, 0), dtype=torch.long, device=device)
 
 
 # --- Translation Function (Using Greedy Decode) ---
@@ -183,108 +159,94 @@ def translate(
     smiles_tokenizer: Tokenizer,
     iupac_tokenizer: Tokenizer,
     device: torch.device,
-
+    max_len: int,
     sos_idx: int,
     eos_idx: int,
     pad_idx: int,
-) -> str: # Returns a single string
+) -> str:  # Returns a single string
     """
     Translates a single SMILES string using greedy decoding.
     """
-    if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
-        return "[Initialization Error: Components not loaded]"
-
     model.eval()  # Ensure model is in eval mode
 
     # --- Tokenize Source ---
     try:
-        # Ensure tokenizer has truncation configured
-        smiles_tokenizer.enable_truncation(
-
-
+        # Ensure tokenizer has truncation/padding configured if needed, or handle manually
+        smiles_tokenizer.enable_truncation(
+            max_length=max_len
+        )  # Use max_len for source truncation too
         src_encoded = smiles_tokenizer.encode(src_sentence)
         if not src_encoded or not src_encoded.ids:
             logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
-            return "[Encoding Error
+            return "[Encoding Error]"
+        # Use the truncated IDs directly
         src_ids = src_encoded.ids
-        # Use attention mask directly for padding mask (1 for real tokens, 0 for padding)
-        # We need the opposite for PyTorch Transformer (True for padding, False for real)
-        src_attention_mask = torch.tensor(src_encoded.attention_mask, dtype=torch.long)
-        src_padding_mask = (src_attention_mask == 0) # True where it's padded
-
     except Exception as e:
         logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True)
-        return
+        return "[Encoding Error]"
 
     # --- Prepare Input Tensor and Mask ---
-
-
-    #
-
+    src = (
+        torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
+    )  # [1, src_len]
+    # Create padding mask (True where it's a pad token, should be all False here unless tokenizer pads)
+    src_padding_mask = (src == pad_idx).to(device)  # [1, src_len]
 
     # --- Perform Greedy Decoding ---
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Calls the greedy_decode function defined *above in this file*
+    # Note: max_len for generation should come from config if it dictates output length
+    generation_max_len = config.get(
+        "max_len", 256
+    )  # Use config max_len for output limit
+    tgt_tokens_tensor = greedy_decode(
+        model=model,
+        src=src,
+        src_padding_mask=src_padding_mask,
+        max_len=generation_max_len,  # Use generation limit
+        sos_idx=sos_idx,
+        eos_idx=eos_idx,
+        # pad_idx=pad_idx, # Not needed by greedy_decode internal loop
+        device=device,
+    )  # Returns a single tensor [1, generated_len]
 
     # --- Decode Generated Tokens ---
     if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0:
-
-
-
-
-    else:
-        logging.warning(
-            f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
-        )
-        return "[Decoding Error: Empty Output]"
+        logging.warning(
+            f"Greedy decode returned empty tensor for SMILES: {src_sentence}"
+        )
+        return "[Decoding Error - Empty Output]"
 
     tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
     try:
         # Decode using the target tokenizer, skipping special tokens
         translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
-        return translation
+        return translation
     except Exception as e:
         logging.error(
             f"Error decoding target tokens {tgt_tokens}: {e}",
             exc_info=True,
         )
-        return "[Decoding Error
+        return "[Decoding Error]"
 
 
-# --- Model/Tokenizer Loading Function ---
+# --- Model/Tokenizer Loading Function (Unchanged) ---
 def load_model_and_tokenizers():
     """Loads tokenizers, config, and model from Hugging Face Hub."""
-    global model, smiles_tokenizer, iupac_tokenizer, device, config
+    global model, smiles_tokenizer, iupac_tokenizer, device, config
     if model is not None:  # Already loaded
         logging.info("Model and tokenizers already loaded.")
         return
 
     logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
-
-    # --- Check if helper code loaded ---
-    if SmilesIupacLitModule is None or generate_square_subsequent_mask is None:
-        error_msg = f"Initialization Error: Could not load required components from enhanced_trainer.py. Check Space logs and ensure the file exists in the repo root."
-        logging.error(error_msg)
-        raise gr.Error(error_msg)
-
     try:
         # Determine device
         if torch.cuda.is_available():
-            device = torch.device("cuda")
-            logging.info("CUDA available, using GPU.")
+            logging.warning(
+                "CUDA is available, but forcing CPU for Gradio app simplicity. Modify if GPU is intended."
+            )
+            device = torch.device("cpu")  # Uncomment if GPU is intended
+            # device = torch.device("cuda")
+            # logging.info("CUDA available, using GPU.")
         else:
             device = torch.device("cpu")
             logging.info("CUDA not available, using CPU.")
@@ -292,25 +254,27 @@ def load_model_and_tokenizers():
         # Download files
         logging.info("Downloading files from Hugging Face Hub...")
         try:
-            # Define cache directory, default to './hf_cache' if GRADIO_CACHE is not set
             cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache")
-            os.makedirs(cache_dir, exist_ok=True)
+            os.makedirs(cache_dir, exist_ok=True)
             logging.info(f"Using cache directory: {cache_dir}")
 
-            # Download files to the specified cache directory
             checkpoint_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
+                repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir
             )
             smiles_tokenizer_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID,
+                repo_id=MODEL_REPO_ID,
+                filename=SMILES_TOKENIZER_FILENAME,
+                cache_dir=cache_dir,
             )
             iupac_tokenizer_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID,
+                repo_id=MODEL_REPO_ID,
+                filename=IUPAC_TOKENIZER_FILENAME,
+                cache_dir=cache_dir,
            )
            config_path = hf_hub_download(
-                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
+                repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir
            )
-            logging.info("Files downloaded
+            logging.info("Files downloaded successfully.")
        except Exception as e:
            logging.error(
                f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}",
@@ -328,35 +292,30 @@ def load_model_and_tokenizers():
         logging.info("Configuration loaded.")
         # --- Validate essential config keys ---
         required_keys = [
-            "src_vocab_size",
-            "tgt_vocab_size",
+            "src_vocab_size",  # Use the key saved in config
+            "tgt_vocab_size",  # Use the key saved in config
             "emb_size",
             "nhead",
             "ffn_hid_dim",
             "num_encoder_layers",
             "num_decoder_layers",
             "dropout",
-            "max_len",
+            "max_len",
             "pad_token_id",
             "bos_token_id",
             "eos_token_id",
         ]
-
-
-
-
-
-
-
-
-
-
-        config
-        config['pad_token_id'] = config.get('pad_token_id', config.get('PAD_IDX'))
-        config['bos_token_id'] = config.get('bos_token_id', config.get('BOS_IDX'))
-        config['eos_token_id'] = config.get('eos_token_id', config.get('EOS_IDX'))
-        # Add UNK if needed by your model/tokenizer setup
-        # config['unk_token_id'] = config.get('unk_token_id', config.get('UNK_IDX', 0)) # Default to 0 if missing? Risky.
+        # Remap if needed (example shown, adjust if your keys differ)
+        config_key_mapping = {
+            "src_vocab_size": config.get(
+                "src_vocab_size", config.get("src_vocab_size")
+            ),
+            "tgt_vocab_size": config.get(
+                "tgt_vocab_size", config.get("tgt_vocab_size")
+            ),
+            # Add other mappings if necessary
+        }
+        config.update(config_key_mapping)
 
         missing_keys = [key for key in required_keys if config.get(key) is None]
         if missing_keys:
@@ -366,7 +325,7 @@ def load_model_and_tokenizers():
             )
 
         logging.info(
-            f"Using config: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
+            f"Using config values: src_vocab={config['src_vocab_size']}, tgt_vocab={config['tgt_vocab_size']}, "
             f"emb={config['emb_size']}, nhead={config['nhead']}, enc={config['num_encoder_layers']}, dec={config['num_decoder_layers']}, "
             f"pad={config['pad_token_id']}, sos={config['bos_token_id']}, eos={config['eos_token_id']}, max_len={config['max_len']}"
         )
@@ -391,75 +350,51 @@ def load_model_and_tokenizers():
         try:
             smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
             iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
-
-            # --- Validate Tokenizer Special Tokens Against Config ---
+            logging.info("Tokenizers loaded.")
+            # --- Optional: Validate Tokenizer Special Tokens Against Config ---
+            # (Keep validation as before, it's still useful)
             pad_token = "<pad>"
             sos_token = "<sos>"
             eos_token = "<eos>"
-            unk_token = "<unk>"
+            unk_token = "<unk>"
             issues = []
-
-
-
-
-
-
-
-
-
-
-
-
-
-            iupac_unk_id = iupac_tokenizer.token_to_id(unk_token)
-            if iupac_pad_id is None or iupac_pad_id != config["pad_token_id"]:
-                issues.append(f"IUPAC PAD ID mismatch (Tokenizer: {iupac_pad_id}, Config: {config['pad_token_id']})")
-            if iupac_sos_id is None or iupac_sos_id != config["bos_token_id"]:
-                issues.append(f"IUPAC SOS ID mismatch (Tokenizer: {iupac_sos_id}, Config: {config['bos_token_id']})")
-            if iupac_eos_id is None or iupac_eos_id != config["eos_token_id"]:
-                issues.append(f"IUPAC EOS ID mismatch (Tokenizer: {iupac_eos_id}, Config: {config['eos_token_id']})")
-            if iupac_unk_id is None:
-                issues.append("IUPAC UNK token not found in tokenizer")
-
+            # ... (keep the validation checks as in the original code) ...
+            if smiles_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
+                issues.append(f"SMILES PAD ID mismatch")
+            if smiles_tokenizer.token_to_id(unk_token) is None:
+                issues.append("SMILES UNK token not found")
+            if iupac_tokenizer.token_to_id(pad_token) != config["pad_token_id"]:
+                issues.append(f"IUPAC PAD ID mismatch")
+            if iupac_tokenizer.token_to_id(sos_token) != config["bos_token_id"]:
+                issues.append(f"IUPAC SOS ID mismatch")
+            if iupac_tokenizer.token_to_id(eos_token) != config["eos_token_id"]:
+                issues.append(f"IUPAC EOS ID mismatch")
+            if iupac_tokenizer.token_to_id(unk_token) is None:
+                issues.append("IUPAC UNK token not found")
             if issues:
-                logging.warning("Tokenizer validation issues
-                # Decide if this is critical. For inference, SOS/EOS/PAD matches are most important.
-                # raise gr.Error("Tokenizer Validation Error: Mismatch between config and tokenizer files. Check logs.")
-            else:
-                logging.info("Tokenizers loaded and special tokens validated against config.")
+                logging.warning("Tokenizer validation issues: " + "; ".join(issues))
 
         except Exception as e:
             logging.error(f"Failed to load tokenizers: {e}", exc_info=True)
             raise gr.Error(
-                f"Tokenizer Error: Could not load tokenizers. Check logs
+                f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}"
             )
 
         # Load model
         logging.info("Loading model from checkpoint...")
         try:
-            #
-            # Make sure SmilesIupacLitModule's __init__ accepts these keys
-            model_instance = SmilesIupacLitModule(**config)
-
-            # Load the state dict from the checkpoint onto the instance
-            # Use load_state_dict for more control if load_from_checkpoint causes issues
-            # state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']
-            # model_instance.load_state_dict(state_dict, strict=True) # Try strict=False if needed
-
-            # Use load_from_checkpoint (simpler if it works)
+            # Use the vocab sizes and hparams from the loaded config
             model = SmilesIupacLitModule.load_from_checkpoint(
                 checkpoint_path,
+                **config,  # Pass loaded config directly as keyword args
                 map_location=device,
-
-                #
-                **config, # Try removing this if you instantiate first
-                strict=True # Start strict, set to False ONLY if necessary and you understand why
+                devices=1,
+                strict=True,  # Start strict, set to False if encountering key errors
             )
 
-
             model.to(device)
             model.eval()
-        model.freeze()
+            model.freeze()
             logging.info(
                 f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'."
            )
@@ -467,203 +402,172 @@ def load_model_and_tokenizers():
         except FileNotFoundError:
             logging.error(f"Checkpoint file not found: {checkpoint_path}")
             raise gr.Error(
-                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found
+                f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found."
             )
-        except
-            logging.error(f"Runtime error loading model checkpoint {checkpoint_path}: {e}", exc_info=True)
-            if "size mismatch" in str(e):
-                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint structure. Or config doesn't match model definition."
-                logging.error(error_detail)
-                raise gr.Error(f"Model Load Error: {error_detail} Original error: {e}")
-            elif "CUDA out of memory" in str(e) or "memory" in str(e).lower():
-                logging.warning("Potential OOM error during model loading.")
-                gc.collect()
-                if device.type == "cuda": torch.cuda.empty_cache()
-                raise gr.Error(f"Model Load Error: OOM loading model. Check Space resources. Error: {e}")
-            else:
-                raise gr.Error(f"Model Load Error: Runtime error. Check logs. Error: {e}")
-        except Exception as e: # Catch other potential errors during loading
+        except Exception as e:
             logging.error(
-                f"
-            )
-            raise gr.Error(
-                f"Model Load Error: Failed to load checkpoint for unknown reason. Check logs. Error: {e}"
+                f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True
             )
+            if "size mismatch" in str(e):
+                error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint."
+                logging.error(error_detail)
+                raise gr.Error(f"Model Error: {error_detail} Original error: {e}")
+            elif "memory" in str(e).lower():
+                logging.warning("Potential OOM error during model loading.")
+                gc.collect()
+                torch.cuda.empty_cache() if device.type == "cuda" else None
+                raise gr.Error(
+                    f"Model Error: OOM loading model. Check Space resources. Error: {e}"
+                )
+            else:
+                raise gr.Error(
+                    f"Model Error: Failed to load checkpoint. Check logs. Error: {e}"
+                )
 
-    except gr.Error
-        raise
-    except Exception as e:
-        logging.error(f"Unexpected error during loading
+    except gr.Error:
+        raise
+    except Exception as e:
+        logging.error(f"Unexpected error during loading: {e}", exc_info=True)
         raise gr.Error(
-            f"Initialization Error: Unexpected
+            f"Initialization Error: Unexpected error. Check logs. Error: {e}"
         )
 
 
-# --- Inference Function for Gradio ---
+# --- Inference Function for Gradio (Simplified) ---
 def predict_iupac(smiles_string):
     """
     Performs SMILES to IUPAC translation using the loaded model and greedy decoding.
-    Handles input validation, canonicalization, translation, and output formatting.
     """
+    try:
+        smiles_string = CanonSmiles(smiles_string)
+    except Exception as e:
+        logging.error(f"Error during SMILES canonicalization: {e}", exc_info=True)
+        return f"Error: {e}"
     global model, smiles_tokenizer, iupac_tokenizer, device, config
 
-    # --- Check Initialization ---
     if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
-        error_msg = "Error: Model or tokenizers not loaded. App initialization failed. Check Space logs."
+        error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs."
         logging.error(error_msg)
-        # Return
-        return error_msg
+        return f"Error: {error_msg}"  # Return single error string
 
-    # --- Input Validation ---
     if not smiles_string or not smiles_string.strip():
-
+        error_msg = "Error: Please enter a valid SMILES string."
+        return f"Error: {error_msg}"  # Return single error string
 
-
+    smiles_input = smiles_string.strip()
 
-    # --- Canonicalize SMILES (Optional but Recommended) ---
-    smiles_input_canon = smiles_input_raw
-    if Chem:
-        try:
-            mol = Chem.MolFromSmiles(smiles_input_raw)
-            if mol:
-                smiles_input_canon = Chem.MolToSmiles(mol, canonical=True)
-                logging.info(f"Canonicalized SMILES: {smiles_input_raw} -> {smiles_input_canon}")
-            else:
-                # RDKit couldn't parse it, proceed with raw input but warn
-                logging.warning(f"Could not parse SMILES '{smiles_input_raw}' with RDKit. Using raw input.")
-                # Optionally return an error here if strict parsing is needed
-                # return f"Error: Invalid SMILES string '{smiles_input_raw}' according to RDKit."
-        except Exception as e:
-            logging.error(f"Error during RDKit canonicalization for '{smiles_input_raw}': {e}", exc_info=True)
-            # Proceed with raw input, maybe add note to output
-            # return f"Error: RDKit processing failed: {e}" # Option to fail hard
-
-    # --- Translation ---
     try:
+        # --- Call the core translation logic (greedy) ---
         sos_idx = config["bos_token_id"]
         eos_idx = config["eos_token_id"]
         pad_idx = config["pad_token_id"]
-        gen_max_len = config["max_len"]
+        gen_max_len = config["max_len"]
 
-        predicted_name = translate(
+        predicted_name = translate(  # Returns a single string now
             model=model,
-            src_sentence=
+            src_sentence=smiles_input,
             smiles_tokenizer=smiles_tokenizer,
             iupac_tokenizer=iupac_tokenizer,
             device=device,
-
+            max_len=gen_max_len,
             sos_idx=sos_idx,
             eos_idx=eos_idx,
             pad_idx=pad_idx,
        )
-        logging.info(f"
+        logging.info(f"Prediction returned: {predicted_name}")
 
         # --- Format Output ---
-        # Check
-        if predicted_name.startswith("[") and predicted_name.endswith("]"):
-            # Assume it's an error/warning message from translate()
+        if "[Error]" in predicted_name:  # Check for error messages from translate
             output_text = (
-                f"Input SMILES: {
-                f"(Raw Input: {smiles_input_raw})\n\n" # Show raw if canonicalization happened
-                f"Prediction Failed: {predicted_name}"
-            )
-        elif not predicted_name: # Handle empty string case
-            output_text = (
-                f"Input SMILES: {smiles_input_canon}\n"
-                f"(Raw Input: {smiles_input_raw})\n\n"
-                f"Prediction: [No name generated]"
+                f"Input SMILES: {smiles_input}\n\nPrediction Failed: {predicted_name}"
             )
+        elif not predicted_name:
+            output_text = f"Input SMILES: {smiles_input}\n\nNo prediction generated (decoding might have failed)."
         else:
             output_text = (
-                f"Input SMILES: {
-                f"(Raw Input: {smiles_input_raw})\n\n"
+                f"Input SMILES: {smiles_input}\n\n"
                 f"Predicted IUPAC Name (Greedy Decode):\n"
                 f"{predicted_name}"
            )
-
-        if smiles_input_raw == smiles_input_canon:
-            output_text = output_text.replace(f"(Raw Input: {smiles_input_raw})\n", "")
+        return output_text
 
-
+    except RuntimeError as e:
+        logging.error(f"Runtime error during translation: {e}", exc_info=True)
+        return f"Error: {error_msg}"  # Return single error string
 
     except Exception as e:
-
-
-
-        return error_msg
+        logging.error(f"Unexpected error during translation: {e}", exc_info=True)
+        error_msg = f"Unexpected Error during translation: {e}"
+        return f"Error: {error_msg}"  # Return single error string
 
 
 # --- Load Model on App Start ---
-# Wrap in try/except to allow Gradio UI to potentially display an error
-model_load_error = None
 try:
     load_model_and_tokenizers()
 except gr.Error as ge:
-    logging.error(f"Gradio Initialization Error
-
+    logging.error(f"Gradio Initialization Error: {ge}")
+    pass  # Allow Gradio to potentially start with an error message
 except Exception as e:
-    logging.
-
+    logging.error(f"Critical error during initial model loading: {e}", exc_info=True)
+    # Optionally raise gr.Error here too
 
 
-# --- Create Gradio Interface ---
+# --- Create Gradio Interface (Simplified) ---
 title = "SMILES to IUPAC Name Translator (Greedy Decoding)"
 description = f"""
 Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}) trained via PyTorch Lightning.
 Translation uses **greedy decoding** (picks the most likely next word at each step).
+**Note:** Model loaded on **{str(device).upper() if device else "N/A"}**. Performance may vary. Check `config.json` in the repo for model details.
 """
-
-
-
-
-else:
-    description += f"\n**Note:** Device information unavailable (loading might have failed)."
-
-# Use gr.Blocks for layout
+
+
+
+# Replace your Interface code with this:
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface:
     gr.Markdown(f"# {title}")
     gr.Markdown(description)
-
+
     with gr.Row():
-        with gr.Column(
+        with gr.Column():
             smiles_input = gr.Textbox(
                 label="SMILES String",
-                placeholder="Enter SMILES string (e.g., CCO
-                lines=
+                placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
+                lines=1,
            )
-
-
-
+
+            # Add a button for manual triggering
+            submit_btn = gr.Button("Translate")
+
+        with gr.Column():
             output_text = gr.Textbox(
-                label="
-                lines=
+                label="Predicted IUPAC Name",
+                lines=3,
                 show_copy_button=True,
-                interactive=False, # Output box is not for user input
             )
-
-    #
-
-
-
-
-
-
-
-
-
-        fn=predict_iupac, # Function to run for examples
-        cache_examples=False, # Re-run examples each time if needed, True might speed up demo loading
-    )
-
-    # Connect the button click and input change events
-    submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text, api_name="translate_smiles")
-    # Optionally trigger on text change (can be slow/resource intensive)
-    # smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
+
+    # Connect the events properly
+    submit_btn.click(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
+
+    # If you still want change event (auto-translation as user types)
+    smiles_input.change(fn=predict_iupac, inputs=smiles_input, outputs=output_text)
+    # Output component
+    output_text = gr.Textbox(
+        label="Predicted IUPAC Name",
+        lines=3,
+        show_copy_button=True,  # Reduced lines slightly
+    )
 
 
+# Create the interface instance
+iface = gr.Interface(
+    fn=predict_iupac,  # The function to call
+    inputs=smiles_input,  # Input component directly
+    outputs=output_text,  # Output component
+    title=title,
+    description=description,
+    allow_flagging="never",
+    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
+)
 # --- Launch the App ---
 if __name__ == "__main__":
-
-    # Set debug=True for more detailed Gradio errors during development
-    iface.launch(share=False, debug=False)
+    iface.launch()
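For quick verification after this change, a minimal local smoke test is sketched below; it is not part of the commit. It assumes the Space's files (app.py, enhanced_trainer.py, the tokenizer files and checkpoint) and the listed dependencies are available locally; the module name app and the example SMILES CCO are illustrative only.

# Hypothetical local smoke test for the updated app.py (assumption: run from the repo root).
# Importing app executes load_model_and_tokenizers() at module level, so it needs
# access to the Hugging Face Hub (or a warm cache) to download the checkpoint and tokenizers.
from app import predict_iupac

# predict_iupac() canonicalizes the SMILES with RDKit's CanonSmiles, runs translate()
# (greedy decoding), and returns either a formatted prediction or an "Error: ..." string.
print(predict_iupac("CCO"))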