|
|
|
import gradio as gr |
|
import torch |
|
import torch.nn.functional as F |
|
import pytorch_lightning as pl |
|
import os |
|
import json |
|
import logging |
|
from tokenizers import Tokenizer |
|
from huggingface_hub import hf_hub_download |
|
import gc |
|
from rdkit.Chem import CanonSmiles, MolFromSmiles |
|
import spaces |
|
import heapq
|
|
|
|
|
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
|
CHECKPOINT_FILENAME = "last.ckpt" |
|
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json" |
|
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json" |
|
CONFIG_FILENAME = "config.json" |
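# All four artifacts above are expected at the root of the Hugging Face model
# repo: the Lightning checkpoint, the two trained tokenizer JSON files, and a
# config.json holding the architecture hyperparameters used at training time.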
|
|
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" |
|
) |
|
|
|
|
|
try: |
|
|
|
from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask |
|
|
|
logging.info("Successfully imported from enhanced_trainer.py.") |
|
except ImportError as e: |
|
logging.error( |
|
f"Failed to import helper code from enhanced_trainer.py: {e}. " |
|
f"Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'." |
|
) |
|
raise gr.Error( |
|
f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}" |
|
) |
|
except Exception as e: |
|
logging.error( |
|
f"An unexpected error occurred during helper code import: {e}", exc_info=True |
|
) |
|
raise gr.Error( |
|
f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}" |
|
) |
|
|
|
|
|
model: pl.LightningModule | None = None |
|
smiles_tokenizer: Tokenizer | None = None |
|
iupac_tokenizer: Tokenizer | None = None |
|
device: torch.device | None = None |
|
config: dict | None = None |
|
|
|
|
|
|
|
def greedy_decode( |
|
model: pl.LightningModule, |
|
src: torch.Tensor, |
|
src_padding_mask: torch.Tensor, |
|
max_len: int, |
|
sos_idx: int, |
|
eos_idx: int, |
|
device: torch.device, |
|
) -> torch.Tensor: |
|
""" |
|
Performs greedy decoding using the LightningModule's model. |
|
Returns a tensor of shape [1, sequence_length]. |
|
""" |
|
model.eval() |
|
transformer_model = model.model |
|
|
|
try: |
|
with torch.no_grad(): |
|
memory = transformer_model.encode(src, src_padding_mask) |
|
memory = memory.to(device) |
|
memory_key_padding_mask = src_padding_mask.to(memory.device) |
|
|
|
ys = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) |
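            # ys holds the generated target ids so far, seeded with the SOS token;
            # each loop iteration appends the argmax token until EOS or max_len.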
|
|
|
for _ in range(max_len - 1): |
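                # Build a causal mask so the decoder cannot attend to future
                # positions; the target padding mask stays all False because
                # generated tokens are never padding.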
|
tgt_seq_len = ys.shape[1] |
|
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) |
|
tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device) |
|
|
|
decoder_output = transformer_model.decode( |
|
tgt=ys, |
|
memory=memory, |
|
tgt_mask=tgt_mask, |
|
tgt_padding_mask=tgt_padding_mask, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
) |
|
|
|
next_token_logits = transformer_model.generator(decoder_output[:, -1, :]) |
|
next_word_id = torch.argmax(next_token_logits, dim=1).item() |
|
|
|
ys = torch.cat( |
|
[ |
|
ys, |
|
torch.ones(1, 1, dtype=torch.long, device=device).fill_(next_word_id), |
|
], |
|
dim=1, |
|
) |
|
|
|
if next_word_id == eos_idx: |
|
break |
|
|
|
return ys[:, 1:] |
|
|
|
except RuntimeError as e: |
|
logging.error(f"Runtime error during greedy decode: {e}", exc_info=True) |
|
if "CUDA out of memory" in str(e) and device.type == "cuda": |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
return torch.empty((1, 0), dtype=torch.long, device=device) |
|
except Exception as e: |
|
logging.error(f"Unexpected error during greedy decode: {e}", exc_info=True) |
|
return torch.empty((1, 0), dtype=torch.long, device=device) |
|
|
|
|
|
|
|
def beam_search_decode( |
|
model: pl.LightningModule, |
|
src: torch.Tensor, |
|
src_padding_mask: torch.Tensor, |
|
max_len: int, |
|
sos_idx: int, |
|
eos_idx: int, |
|
pad_idx: int, |
|
device: torch.device, |
|
beam_width: int, |
|
num_return_sequences: int = 1, |
|
length_penalty_alpha: float = 0.6, |
|
) -> list[tuple[torch.Tensor, float]]: |
|
""" |
|
Performs beam search decoding. |
|
Returns a list of tuples: (sequence_tensor [1, seq_len], score) |
|
""" |
|
model.eval() |
|
transformer_model = model.model |
|
num_return_sequences = min(beam_width, num_return_sequences) |
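    # Active beams carry cumulative log-probabilities; finished hypotheses are
    # length-normalized as score / len(seq) ** length_penalty_alpha so longer
    # names are not unfairly penalized.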
|
|
|
try: |
|
with torch.no_grad(): |
|
|
|
memory = transformer_model.encode(src, src_padding_mask) |
|
memory = memory.to(device) |
|
memory_key_padding_mask = src_padding_mask.to(memory.device) |
|
|
|
|
|
|
|
initial_beam = ( |
|
torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx), |
|
0.0, |
|
) |
|
beams = [initial_beam] |
|
finished_hypotheses = [] |
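            # finished_hypotheses is a min-heap keyed by normalized score, so the
            # worst kept hypothesis is always at index 0 and is cheap to prune.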
|
|
|
|
|
for step in range(max_len - 1): |
|
if not beams: |
|
break |
|
|
|
|
|
|
|
candidates = [] |
|
|
|
|
|
|
|
for beam_idx, (current_seq, current_score) in enumerate(beams): |
|
if current_seq[0, -1].item() == eos_idx: |
|
|
|
                        penalty = current_seq.shape[1] ** length_penalty_alpha
                        final_score = current_score / penalty if penalty > 0 else current_score
                        # Include id() as a tie-breaker so heapq never has to compare
                        # tensors when two hypotheses share the same score.
                        heapq.heappush(finished_hypotheses, (final_score, id(current_seq), current_seq))
|
|
|
while len(finished_hypotheses) > beam_width: |
|
heapq.heappop(finished_hypotheses) |
|
continue |
|
|
|
|
|
ys = current_seq |
|
tgt_seq_len = ys.shape[1] |
|
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) |
|
tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool, device=device) |
|
|
|
|
|
|
|
decoder_output = transformer_model.decode( |
|
tgt=ys, |
|
memory=memory, |
|
tgt_mask=tgt_mask, |
|
tgt_padding_mask=tgt_padding_mask, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
) |
|
|
|
|
|
next_token_logits = transformer_model.generator( |
|
decoder_output[:, -1, :] |
|
) |
|
|
|
|
|
log_probs = F.log_softmax(next_token_logits, dim=-1) |
|
|
|
|
|
|
|
top_k_log_probs, top_k_indices = torch.topk(log_probs + current_score, beam_width, dim=1) |
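                    # topk over (log_probs + current_score) keeps the beam_width best
                    # cumulative-score expansions of this particular beam.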
|
|
|
|
|
for i in range(beam_width): |
|
token_id = top_k_indices[0, i].item() |
|
score = top_k_log_probs[0, i].item() |
|
|
|
heapq.heappush(candidates, (-score, token_id, beam_idx)) |
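                        # candidates stores (-score, token, beam); negating the score
                        # makes heapq.nsmallest return the highest-scoring expansions.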
|
|
|
|
|
|
|
|
|
                new_beams = []

                # Keep only the overall best `beam_width` expansions across all beams.
                top_candidates = heapq.nsmallest(beam_width, candidates)
|
|
|
added_sequences = set() |
|
|
|
for neg_score, token_id, beam_idx in top_candidates: |
|
original_seq, _ = beams[beam_idx] |
|
new_seq = torch.cat( |
|
[ |
|
original_seq, |
|
torch.ones(1, 1, dtype=torch.long, device=device).fill_(token_id), |
|
], |
|
dim=1, |
|
) |
|
|
|
|
|
seq_tuple = tuple(new_seq.flatten().tolist()) |
|
if seq_tuple not in added_sequences: |
|
new_beams.append((new_seq, -neg_score)) |
|
added_sequences.add(seq_tuple) |
|
|
|
beams = new_beams |
|
|
|
|
|
if finished_hypotheses: |
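                    # Early stop once enough hypotheses are finished and even the best
                    # still-active beam cannot beat the worst kept finished score.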
|
|
|
best_active_score = -heapq.nsmallest(1, candidates)[0][0] if candidates else -float('inf') |
|
worst_finished_score = finished_hypotheses[0][0] |
|
if len(finished_hypotheses) >= num_return_sequences and best_active_score < worst_finished_score: |
|
logging.debug(f"Beam search early stopping at step {step}") |
|
break |
|
|
|
|
|
|
|
|
|
            # Any beam still active at max_len is added as a finished hypothesis
            # (length-normalized) so there is always something to return.
            for seq, score in beams:
                if seq[0, -1].item() != eos_idx:
                    penalty = seq.shape[1] ** length_penalty_alpha
                    final_score = score / penalty if penalty > 0 else score
                    heapq.heappush(finished_hypotheses, (final_score, id(seq), seq))
                while len(finished_hypotheses) > beam_width:
                    heapq.heappop(finished_hypotheses)
|
|
|
|
|
|
|
            # Best hypotheses first; drop the tie-breaker id and strip the SOS token.
            top_hypotheses = heapq.nlargest(num_return_sequences, finished_hypotheses)
            return [(seq[:, 1:], score) for score, _, seq in top_hypotheses]
|
|
|
except RuntimeError as e: |
|
logging.error(f"Runtime error during beam search: {e}", exc_info=True) |
|
if "CUDA out of memory" in str(e) and device.type == "cuda": |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
return [] |
|
except Exception as e: |
|
logging.error(f"Unexpected error during beam search: {e}", exc_info=True) |
|
return [] |
|
|
|
|
|
|
|
def translate( |
|
model: pl.LightningModule, |
|
src_sentence: str, |
|
smiles_tokenizer: Tokenizer, |
|
iupac_tokenizer: Tokenizer, |
|
device: torch.device, |
|
max_len: int, |
|
sos_idx: int, |
|
eos_idx: int, |
|
pad_idx: int, |
|
decoding_strategy: str = "Greedy", |
|
beam_width: int = 5, |
|
num_return_sequences: int = 1, |
|
length_penalty_alpha: float = 0.6, |
|
) -> list[tuple[str, float]]: |
|
""" |
|
Translates a single SMILES string using the specified decoding strategy. |
|
""" |
|
model.eval() |
|
|
|
|
|
try: |
|
smiles_tokenizer.enable_truncation(max_length=max_len) |
|
src_encoded = smiles_tokenizer.encode(src_sentence) |
|
if not src_encoded or not src_encoded.ids: |
|
logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}") |
|
return [("[Encoding Error]", 0.0)] |
|
src_ids = src_encoded.ids |
|
except Exception as e: |
|
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}", exc_info=True) |
|
return [("[Encoding Error]", 0.0)] |
|
|
|
|
|
src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) |
|
src_padding_mask = (src == pad_idx).to(device) |
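    # The padding mask is True at pad positions so the encoder ignores them.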
|
|
|
|
|
generation_max_len = config.get("max_len", 256) |
|
results = [] |
|
|
|
if decoding_strategy == "Greedy": |
|
tgt_tokens_tensor = greedy_decode( |
|
model=model, |
|
src=src, |
|
src_padding_mask=src_padding_mask, |
|
max_len=generation_max_len, |
|
sos_idx=sos_idx, |
|
eos_idx=eos_idx, |
|
device=device, |
|
) |
|
if tgt_tokens_tensor is not None and tgt_tokens_tensor.numel() > 0: |
|
results = [(tgt_tokens_tensor, 0.0)] |
|
else: |
|
logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}") |
|
return [("[Decoding Error - Empty Output]", 0.0)] |
|
|
|
elif decoding_strategy == "Beam Search": |
|
results = beam_search_decode( |
|
model=model, |
|
src=src, |
|
src_padding_mask=src_padding_mask, |
|
max_len=generation_max_len, |
|
sos_idx=sos_idx, |
|
eos_idx=eos_idx, |
|
pad_idx=pad_idx, |
|
device=device, |
|
beam_width=beam_width, |
|
num_return_sequences=num_return_sequences, |
|
length_penalty_alpha=length_penalty_alpha, |
|
) |
|
if not results: |
|
logging.warning(f"Beam search returned no results for SMILES: {src_sentence}") |
|
return [("[Decoding Error - Empty Output]", 0.0)] |
|
else: |
|
logging.error(f"Unknown decoding strategy: {decoding_strategy}") |
|
return [("[Error: Unknown Strategy]", 0.0)] |
|
|
|
|
|
|
|
translations = [] |
|
for tgt_tokens_tensor, score in results: |
|
if tgt_tokens_tensor is None or tgt_tokens_tensor.numel() == 0: |
|
translations.append(("[Decoding Error - Empty Sequence]", score)) |
|
continue |
|
|
|
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist() |
|
try: |
|
|
|
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True) |
|
translations.append((translation, score)) |
|
except Exception as e: |
|
logging.error( |
|
f"Error decoding target tokens {tgt_tokens}: {e}", |
|
exc_info=True, |
|
) |
|
translations.append(("[Decoding Error]", score)) |
|
|
|
return translations |
|
|
|
|
|
|
|
def load_model_and_tokenizers(): |
|
"""Loads tokenizers, config, and model from Hugging Face Hub.""" |
|
global model, smiles_tokenizer, iupac_tokenizer, device, config |
|
if model is not None: |
|
logging.info("Model and tokenizers already loaded.") |
|
return |
|
|
|
logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...") |
|
try: |
|
|
|
|
|
|
|
|
|
|
|
device = torch.device("cpu") |
|
logging.info("Using CPU. Modify code to enable GPU if available and desired.") |
|
|
|
|
|
logging.info("Downloading files from Hugging Face Hub...") |
|
cache_dir = os.environ.get("GRADIO_CACHE", "./hf_cache") |
|
os.makedirs(cache_dir, exist_ok=True) |
|
logging.info(f"Using cache directory: {cache_dir}") |
|
|
|
try: |
|
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME, cache_dir=cache_dir) |
|
smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME, cache_dir=cache_dir) |
|
iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME, cache_dir=cache_dir) |
|
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME, cache_dir=cache_dir) |
|
|
|
try: |
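                # Note: enhanced_trainer.py was already imported at the top of this
                # file, so this download only refreshes the local copy for later runs.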
|
hf_hub_download(repo_id=MODEL_REPO_ID, filename="enhanced_trainer.py", cache_dir=cache_dir, local_dir=".") |
|
logging.info("Downloaded enhanced_trainer.py") |
|
except Exception as download_err: |
|
if os.path.exists("enhanced_trainer.py"): |
|
logging.warning(f"Could not download enhanced_trainer.py (maybe private?), but found local file. Using local. Error: {download_err}") |
|
else: |
|
raise download_err |
|
|
|
logging.info("Files downloaded successfully.") |
|
except Exception as e: |
|
logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True) |
|
raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}") |
|
|
|
|
|
logging.info("Loading configuration...") |
|
try: |
|
with open(config_path, "r") as f: |
|
config = json.load(f) |
|
logging.info("Configuration loaded.") |
|
required_keys = ["src_vocab_size", "tgt_vocab_size", "emb_size", "nhead", "ffn_hid_dim", "num_encoder_layers", "num_decoder_layers", "dropout", "max_len", "pad_token_id", "bos_token_id", "eos_token_id"] |
|
missing_keys = [key for key in required_keys if config.get(key) is None] |
|
if missing_keys: |
|
raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}.") |
|
logging.info(f"Using config: { {k: config.get(k) for k in required_keys} }") |
|
except Exception as e: |
|
logging.error(f"Error loading or validating config: {e}", exc_info=True) |
|
raise gr.Error(f"Config Error: {e}") |
|
|
|
|
|
|
|
logging.info("Loading tokenizers...") |
|
try: |
|
smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path) |
|
iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path) |
|
|
|
if smiles_tokenizer.get_vocab_size() != config['src_vocab_size']: |
|
logging.warning(f"SMILES vocab size mismatch: Tokenizer={smiles_tokenizer.get_vocab_size()}, Config={config['src_vocab_size']}") |
|
if iupac_tokenizer.get_vocab_size() != config['tgt_vocab_size']: |
|
logging.warning(f"IUPAC vocab size mismatch: Tokenizer={iupac_tokenizer.get_vocab_size()}, Config={config['tgt_vocab_size']}") |
|
logging.info("Tokenizers loaded.") |
|
except Exception as e: |
|
logging.error(f"Failed to load tokenizers: {e}", exc_info=True) |
|
raise gr.Error(f"Tokenizer Error: Could not load tokenizers. Check logs. Error: {e}") |
|
|
|
|
|
logging.info("Loading model from checkpoint...") |
|
try: |
|
|
|
|
|
model_hparams = config.copy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # Only pass config keys that are real constructor parameters of the
            # LitModule; co_varnames also lists local variables, so slice to the
            # declared argument count.
            init_code = SmilesIupacLitModule.__init__.__code__
            expected_args = init_code.co_varnames[: init_code.co_argcount]
            hparams_to_pass = {k: v for k, v in model_hparams.items() if k in expected_args}
|
logging.info(f"Passing hparams to LitModule: {hparams_to_pass.keys()}") |
|
|
|
|
|
model = SmilesIupacLitModule.load_from_checkpoint( |
|
checkpoint_path, |
|
map_location=device, |
|
|
|
strict=False, |
|
**hparams_to_pass |
|
) |
|
|
|
model.to(device) |
|
model.eval() |
|
model.freeze() |
|
logging.info(f"Model loaded successfully from {checkpoint_path}, set to eval mode, frozen, and moved to device '{device}'.") |
|
|
|
except FileNotFoundError: |
|
logging.error(f"Checkpoint file not found: {checkpoint_path}") |
|
raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.") |
|
except Exception as e: |
|
logging.error(f"Error loading model checkpoint {checkpoint_path}: {e}", exc_info=True) |
|
if "size mismatch" in str(e): |
|
error_detail = f"Potential size mismatch. Check vocab sizes in config.json (src={config.get('src_vocab_size')}, tgt={config.get('tgt_vocab_size')}) vs checkpoint." |
|
logging.error(error_detail) |
|
raise gr.Error(f"Model Error: {error_detail} Original error: {e}") |
|
elif "unexpected keyword argument" in str(e) or "missing 1 required positional argument" in str(e): |
|
error_detail = f"Mismatch between config.json keys and SmilesIupacLitModule constructor arguments. Check enhanced_trainer.py and config.json. Error: {e}" |
|
logging.error(error_detail) |
|
raise gr.Error(f"Model Error: {error_detail}") |
|
elif "memory" in str(e).lower(): |
|
logging.warning("Potential OOM error during model loading.") |
|
gc.collect() |
|
                if device.type == "cuda":
                    torch.cuda.empty_cache()
|
raise gr.Error(f"Model Error: OOM loading model. Check Space resources. Error: {e}") |
|
else: |
|
raise gr.Error(f"Model Error: Failed to load checkpoint. Check logs. Error: {e}") |
|
|
|
except gr.Error: |
|
raise |
|
except Exception as e: |
|
logging.error(f"Unexpected error during loading: {e}", exc_info=True) |
|
raise gr.Error(f"Initialization Error: Unexpected error. Check logs. Error: {e}") |
|
|
|
|
|
|
|
@spaces.GPU |
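# On Hugging Face Spaces, @spaces.GPU requests a GPU (ZeroGPU) allocation for the
# duration of this call; the model itself is loaded on CPU in load_model_and_tokenizers().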
|
def predict_iupac(smiles_string, decoding_strategy, num_beams, num_return_sequences): |
|
""" |
|
Performs SMILES to IUPAC translation using the loaded model and selected strategy. |
|
""" |
|
global model, smiles_tokenizer, iupac_tokenizer, device, config |
|
|
|
|
|
if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]): |
|
error_msg = "Error: Model or tokenizers not loaded properly. App initialization might have failed. Check Space logs." |
|
logging.error(error_msg) |
|
return f"Initialization Error: {error_msg}" |
|
|
|
if not smiles_string or not smiles_string.strip(): |
|
return "Error: Please enter a valid SMILES string." |
|
|
|
smiles_input = smiles_string.strip() |
|
|
|
|
|
try: |
|
mol = MolFromSmiles(smiles_input) |
|
if mol is None: |
|
return f"Error: Invalid SMILES string provided: '{smiles_input}'" |
|
smiles_input = CanonSmiles(smiles_input) |
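        # Canonicalization makes equivalent SMILES strings map to the same model input.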
|
logging.info(f"Canonical SMILES: {smiles_input}") |
|
except Exception as e: |
|
logging.error(f"Error during SMILES validation/canonicalization: {e}", exc_info=True) |
|
return f"Error: Could not process SMILES string '{smiles_input}'. RDKit error: {e}" |
|
|
|
|
|
if decoding_strategy == "Beam Search": |
|
if not isinstance(num_beams, int) or num_beams <= 0: |
|
return "Error: Beam width must be a positive integer." |
|
if not isinstance(num_return_sequences, int) or num_return_sequences <= 0: |
|
return "Error: Number of return sequences must be a positive integer." |
|
if num_return_sequences > num_beams: |
|
return f"Error: Number of return sequences ({num_return_sequences}) cannot exceed beam width ({num_beams})." |
|
else: |
|
|
|
num_beams = 1 |
|
num_return_sequences = 1 |
|
|
|
|
|
try: |
|
|
|
sos_idx = config["bos_token_id"] |
|
eos_idx = config["eos_token_id"] |
|
pad_idx = config["pad_token_id"] |
|
gen_max_len = config["max_len"] |
|
|
|
length_penalty = 0.6 |
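        # length_penalty is the alpha exponent used by beam search scoring;
        # larger values favor longer generated names.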
|
|
|
predicted_results = translate( |
|
model=model, |
|
src_sentence=smiles_input, |
|
smiles_tokenizer=smiles_tokenizer, |
|
iupac_tokenizer=iupac_tokenizer, |
|
device=device, |
|
max_len=gen_max_len, |
|
sos_idx=sos_idx, |
|
eos_idx=eos_idx, |
|
pad_idx=pad_idx, |
|
decoding_strategy=decoding_strategy, |
|
beam_width=num_beams, |
|
num_return_sequences=num_return_sequences, |
|
length_penalty_alpha=length_penalty, |
|
) |
|
logging.info(f"Prediction returned {len(predicted_results)} result(s). Strategy: {decoding_strategy}, Beams: {num_beams}, Return: {num_return_sequences}") |
|
|
|
|
|
output_lines = [] |
|
output_lines.append(f"Input SMILES: {smiles_input}") |
|
output_lines.append(f"Decoding Strategy: {decoding_strategy}") |
|
if decoding_strategy == "Beam Search": |
|
output_lines.append(f"Beam Width: {num_beams}") |
|
output_lines.append(f"Returned Sequences: {len(predicted_results)}") |
|
output_lines.append(f"Length Penalty Alpha: {length_penalty:.2f}") |
|
|
|
|
|
output_lines.append("\n--- Predictions ---") |
|
|
|
if not predicted_results: |
|
output_lines.append("No predictions generated.") |
|
else: |
|
for i, (name, score) in enumerate(predicted_results): |
|
if "[Error]" in name or not name: |
|
output_lines.append(f"{i+1}. Prediction Failed: {name}") |
|
else: |
|
score_info = f"(Score: {score:.4f})" if decoding_strategy == "Beam Search" else "" |
|
output_lines.append(f"{i+1}. {name} {score_info}") |
|
|
|
return "\n".join(output_lines) |
|
|
|
except RuntimeError as e: |
|
logging.error(f"Runtime error during translation: {e}", exc_info=True) |
|
gc.collect() |
|
if device.type == 'cuda': torch.cuda.empty_cache() |
|
return f"Runtime Error during translation: {e}. Check logs." |
|
except Exception as e: |
|
logging.error(f"Unexpected error during translation: {e}", exc_info=True) |
|
return f"Unexpected Error during translation: {e}. Check logs." |
|
|
|
|
|
|
|
try: |
|
load_model_and_tokenizers() |
|
except gr.Error as ge: |
|
|
|
logging.error(f"Gradio Initialization Error during load: {ge}") |
|
|
|
|
|
except Exception as e: |
|
logging.error(f"Critical error during initial model loading: {e}", exc_info=True) |
|
|
|
|
|
|
|
|
|
title = "SMILES to IUPAC Name Translator" |
|
description = f""" |
|
Translate a SMILES string into its IUPAC chemical name using a Transformer model ({MODEL_REPO_ID}). |
|
Choose between **Greedy Decoding** (fastest; picks the single most likely next token at each step) and **Beam Search Decoding** (explores several candidate sequences in parallel; often more accurate, but slower).
|
**Note:** Model loaded on **{str(device).upper() if device else 'N/A'}**. Beam search can be slow, especially with larger beam widths. |
|
Check `config.json` in the repo for model details. SMILES input will be canonicalized using RDKit. |
|
""" |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as iface: |
|
gr.Markdown(f"# {title}") |
|
gr.Markdown(description) |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
smiles_input = gr.Textbox( |
|
label="SMILES String", |
|
placeholder="Enter SMILES string (e.g., CCO, c1ccccc1)", |
|
lines=2, |
|
) |
|
with gr.Accordion("Decoding Options", open=False): |
|
decode_strategy = gr.Radio( |
|
["Greedy", "Beam Search"], |
|
label="Decoding Strategy", |
|
value="Greedy", |
|
info="Greedy is faster, Beam Search may be more accurate." |
|
) |
|
beam_width_slider = gr.Slider( |
|
minimum=1, |
|
maximum=20, |
|
step=1, |
|
value=5, |
|
label="Beam Width", |
|
info="Number of beams to explore (Beam Search only)", |
|
visible=False |
|
) |
|
num_seq_slider = gr.Slider( |
|
minimum=1, |
|
maximum=5, |
|
step=1, |
|
value=1, |
|
label="Number of Results", |
|
info="How many sequences to return (Beam Search only)", |
|
visible=False |
|
) |
|
|
|
submit_btn = gr.Button("Translate", variant="primary") |
|
|
|
|
|
def update_beam_options(strategy): |
|
is_beam = strategy == "Beam Search" |
|
return { |
|
beam_width_slider: gr.update(visible=is_beam), |
|
num_seq_slider: gr.update(visible=is_beam) |
|
} |
|
|
|
decode_strategy.change( |
|
fn=update_beam_options, |
|
inputs=decode_strategy, |
|
outputs=[beam_width_slider, num_seq_slider] |
|
) |
|
|
|
|
|
with gr.Column(scale=2): |
|
output_text = gr.Textbox( |
|
label="Translation Results", |
|
lines=10, |
|
show_copy_button=True, |
|
|
|
) |
|
|
|
|
|
submit_btn.click( |
|
fn=predict_iupac, |
|
inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], |
|
outputs=output_text, |
|
api_name="translate_smiles" |
|
) |
|
|
|
|
|
smiles_input.submit( |
|
fn=predict_iupac, |
|
inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], |
|
outputs=output_text |
|
) |
|
|
|
|
|
gr.Examples( |
|
examples=[ |
|
["CCO", "Greedy", 1, 1], |
|
["c1ccccc1", "Greedy", 1, 1], |
|
["CC(C)Br", "Beam Search", 5, 3], |
|
["C[C@H](O)c1ccccc1", "Beam Search", 10, 5], |
|
["INVALID_SMILES", "Greedy", 1, 1], |
|
["N#CC(C)(C)OC(=O)C(C)=C", "Beam Search", 8, 2] |
|
], |
|
inputs=[smiles_input, decode_strategy, beam_width_slider, num_seq_slider], |
|
outputs=output_text, |
|
fn=predict_iupac, |
|
cache_examples=False, |
|
label="Example SMILES & Settings" |
|
) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
iface.launch() |