Upload folder using huggingface_hub
Browse files- README.md +2 -8
- app.py +503 -0
- enhanced_trainer.py +1421 -0
- requirements.txt +16 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.25.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: smi2iupac
|
3 |
+
app_file: app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.25.2
|
|
|
|
|
6 |
---
|
|
|
|
app.py
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F # <--- Added import
|
5 |
+
import pytorch_lightning as pl # <--- Added import (needed for type hints, model access)
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
from tokenizers import Tokenizer
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
import gc # For garbage collection on potential OOM
|
12 |
+
import math # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)
|
13 |
+
|
14 |
+
# --- Configuration ---
|
15 |
+
MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
|
16 |
+
CHECKPOINT_FILENAME = "last.ckpt"
|
17 |
+
SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
|
18 |
+
IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
|
19 |
+
CONFIG_FILENAME = "config.json"
|
20 |
+
# --- End Configuration ---
|
21 |
+
|
22 |
+
# --- Logging ---
|
23 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
24 |
+
|
25 |
+
# --- Load Helper Code (Only Model Definition Needed) ---
|
26 |
+
try:
|
27 |
+
# We only need the LightningModule definition and the mask function now
|
28 |
+
from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
|
29 |
+
logging.info("Successfully imported from enhanced_trainer.py.")
|
30 |
+
|
31 |
+
# We will define beam_search_decode and translate locally in this file
|
32 |
+
# REMOVED: from test_ckpt import beam_search_decode, translate
|
33 |
+
|
34 |
+
except ImportError as e:
|
35 |
+
logging.error(f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'.")
|
36 |
+
gr.Error(f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}")
|
37 |
+
exit()
|
38 |
+
except Exception as e:
|
39 |
+
logging.error(f"An unexpected error occurred during helper code import: {e}", exc_info=True)
|
40 |
+
gr.Error(f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}")
|
41 |
+
exit()
|
42 |
+
|
43 |
+
# --- Global Variables (Load Model Once) ---
|
44 |
+
model: pl.LightningModule | None = None # Added type hint
|
45 |
+
smiles_tokenizer: Tokenizer | None = None
|
46 |
+
iupac_tokenizer: Tokenizer | None = None
|
47 |
+
device: torch.device | None = None
|
48 |
+
config: dict | None = None
|
49 |
+
|
50 |
+
# --- Beam Search Decoding Logic (Moved from test_ckpt.py) ---
|
51 |
+
|
52 |
+
def beam_search_decode(
|
53 |
+
model: pl.LightningModule,
|
54 |
+
src: torch.Tensor,
|
55 |
+
src_padding_mask: torch.Tensor,
|
56 |
+
max_len: int,
|
57 |
+
sos_idx: int,
|
58 |
+
eos_idx: int,
|
59 |
+
pad_idx: int, # Needed for padding mask check if src has padding
|
60 |
+
device: torch.device,
|
61 |
+
beam_width: int = 5,
|
62 |
+
n_best: int = 5, # Number of top sequences to return
|
63 |
+
length_penalty: float = 0.6 # Alpha for length normalization (0=no penalty, 1=full penalty)
|
64 |
+
) -> list[torch.Tensor]:
|
65 |
+
"""
|
66 |
+
Performs beam search decoding using the LightningModule's model.
|
67 |
+
(Code copied and pasted from test_ckpt.py)
|
68 |
+
"""
|
69 |
+
# Ensure model is in eval mode (redundant if called after model.eval(), but safe)
|
70 |
+
model.eval()
|
71 |
+
transformer_model = model.model # Access the underlying Seq2SeqTransformer
|
72 |
+
n_best = min(n_best, beam_width) # Cannot return more than beam_width sequences
|
73 |
+
|
74 |
+
try:
|
75 |
+
with torch.no_grad():
|
76 |
+
# --- Encode Source ---
|
77 |
+
memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size]
|
78 |
+
memory = memory.to(device)
|
79 |
+
# Ensure memory_key_padding_mask is also on the correct device for decode
|
80 |
+
memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
|
81 |
+
|
82 |
+
# --- Initialize Beams ---
|
83 |
+
initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) # [1, 1]
|
84 |
+
initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
|
85 |
+
active_beams = [(initial_beam_seq, initial_beam_score)]
|
86 |
+
finished_beams = []
|
87 |
+
|
88 |
+
# --- Decoding Loop ---
|
89 |
+
for step in range(max_len - 1):
|
90 |
+
if not active_beams:
|
91 |
+
break
|
92 |
+
|
93 |
+
potential_next_beams = []
|
94 |
+
for current_seq, current_score in active_beams:
|
95 |
+
if current_seq[0, -1].item() == eos_idx:
|
96 |
+
finished_beams.append((current_seq, current_score))
|
97 |
+
continue
|
98 |
+
|
99 |
+
tgt_input = current_seq # [1, current_len]
|
100 |
+
tgt_seq_len = tgt_input.shape[1]
|
101 |
+
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) # [curr_len, curr_len]
|
102 |
+
tgt_padding_mask = torch.zeros(tgt_input.shape, dtype=torch.bool, device=device) # [1, curr_len]
|
103 |
+
|
104 |
+
decoder_output = transformer_model.decode(
|
105 |
+
tgt=tgt_input,
|
106 |
+
memory=memory,
|
107 |
+
tgt_mask=tgt_mask,
|
108 |
+
tgt_padding_mask=tgt_padding_mask,
|
109 |
+
memory_key_padding_mask=memory_key_padding_mask
|
110 |
+
) # [1, curr_len, emb_size]
|
111 |
+
|
112 |
+
next_token_logits = transformer_model.generator(decoder_output[:, -1, :]) # [1, tgt_vocab_size]
|
113 |
+
log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size]
|
114 |
+
|
115 |
+
topk_log_probs, topk_indices = torch.topk(log_probs + current_score, beam_width, dim=-1)
|
116 |
+
|
117 |
+
for i in range(beam_width):
|
118 |
+
next_token_id = topk_indices[0, i].item()
|
119 |
+
next_score = topk_log_probs[0, i].reshape(1) # Keep as tensor [1]
|
120 |
+
next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=device) # [1, 1]
|
121 |
+
new_seq = torch.cat([current_seq, next_token_tensor], dim=1) # [1, current_len + 1]
|
122 |
+
potential_next_beams.append((new_seq, next_score))
|
123 |
+
|
124 |
+
potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)
|
125 |
+
|
126 |
+
active_beams = []
|
127 |
+
added_count = 0
|
128 |
+
for seq, score in potential_next_beams:
|
129 |
+
is_finished = seq[0, -1].item() == eos_idx
|
130 |
+
if is_finished:
|
131 |
+
finished_beams.append((seq, score))
|
132 |
+
elif added_count < beam_width:
|
133 |
+
active_beams.append((seq, score))
|
134 |
+
added_count += 1
|
135 |
+
elif added_count >= beam_width:
|
136 |
+
break
|
137 |
+
|
138 |
+
finished_beams.extend(active_beams)
|
139 |
+
|
140 |
+
# Apply length penalty and sort
|
141 |
+
# Handle potential division by zero if sequence length is 1 (or 0?)
|
142 |
+
def get_score(beam_tuple):
|
143 |
+
seq, score = beam_tuple
|
144 |
+
seq_len = seq.shape[1]
|
145 |
+
if length_penalty == 0.0 or seq_len <= 1:
|
146 |
+
return score.item()
|
147 |
+
else:
|
148 |
+
# Ensure seq_len is float for pow
|
149 |
+
return score.item() / (float(seq_len) ** length_penalty)
|
150 |
+
|
151 |
+
finished_beams.sort(key=get_score, reverse=True) # Higher score is better
|
152 |
+
|
153 |
+
top_sequences = [seq[:, 1:] for seq, score in finished_beams[:n_best]] # seq shape [1, len] -> [1, len-1]
|
154 |
+
return top_sequences
|
155 |
+
|
156 |
+
except RuntimeError as e:
|
157 |
+
logging.error(f"Runtime error during beam search decode: {e}")
|
158 |
+
if "CUDA out of memory" in str(e):
|
159 |
+
gc.collect(); torch.cuda.empty_cache()
|
160 |
+
return [] # Return empty list on error
|
161 |
+
except Exception as e:
|
162 |
+
logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
|
163 |
+
return []
|
164 |
+
|
165 |
+
# --- Translation Function (Moved from test_ckpt.py) ---
|
166 |
+
|
167 |
+
def translate(
|
168 |
+
model: pl.LightningModule,
|
169 |
+
src_sentence: str,
|
170 |
+
smiles_tokenizer: Tokenizer,
|
171 |
+
iupac_tokenizer: Tokenizer,
|
172 |
+
device: torch.device,
|
173 |
+
max_len: int,
|
174 |
+
sos_idx: int,
|
175 |
+
eos_idx: int,
|
176 |
+
pad_idx: int,
|
177 |
+
beam_width: int = 5,
|
178 |
+
n_best: int = 5,
|
179 |
+
length_penalty: float = 0.6
|
180 |
+
) -> list[str]:
|
181 |
+
"""
|
182 |
+
Translates a single SMILES string using beam search.
|
183 |
+
(Code copied and pasted from test_ckpt.py)
|
184 |
+
"""
|
185 |
+
model.eval() # Ensure model is in eval mode
|
186 |
+
translations = []
|
187 |
+
|
188 |
+
# --- Tokenize Source ---
|
189 |
+
try:
|
190 |
+
src_encoded = smiles_tokenizer.encode(src_sentence)
|
191 |
+
if not src_encoded or not src_encoded.ids:
|
192 |
+
logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
|
193 |
+
return ["[Encoding Error]"] * n_best
|
194 |
+
src_ids = src_encoded.ids[:max_len] # Truncate source
|
195 |
+
if not src_ids:
|
196 |
+
logging.warning(f"Source empty after truncation: {src_sentence}")
|
197 |
+
return ["[Encoding Error - Empty Src]"] * n_best
|
198 |
+
except Exception as e:
|
199 |
+
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
|
200 |
+
return ["[Encoding Error]"] * n_best
|
201 |
+
|
202 |
+
# --- Prepare Input Tensor and Mask ---
|
203 |
+
src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len]
|
204 |
+
src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
|
205 |
+
|
206 |
+
# --- Perform Beam Search Decoding ---
|
207 |
+
# Calls the beam_search_decode function defined above in this file
|
208 |
+
tgt_tokens_list = beam_search_decode(
|
209 |
+
model=model,
|
210 |
+
src=src,
|
211 |
+
src_padding_mask=src_padding_mask,
|
212 |
+
max_len=max_len,
|
213 |
+
sos_idx=sos_idx,
|
214 |
+
eos_idx=eos_idx,
|
215 |
+
pad_idx=pad_idx,
|
216 |
+
device=device,
|
217 |
+
beam_width=beam_width,
|
218 |
+
n_best=n_best,
|
219 |
+
length_penalty=length_penalty
|
220 |
+
) # Returns list of tensors
|
221 |
+
|
222 |
+
# --- Decode Generated Tokens ---
|
223 |
+
if not tgt_tokens_list:
|
224 |
+
logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
|
225 |
+
return ["[Decoding Error - Empty Output]"] * n_best
|
226 |
+
|
227 |
+
for tgt_tokens_tensor in tgt_tokens_list:
|
228 |
+
if tgt_tokens_tensor.numel() > 0:
|
229 |
+
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
|
230 |
+
try:
|
231 |
+
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
|
232 |
+
translations.append(translation)
|
233 |
+
except Exception as e:
|
234 |
+
logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
|
235 |
+
translations.append("[Decoding Error]")
|
236 |
+
else:
|
237 |
+
translations.append("[Decoding Error - Empty Tensor]")
|
238 |
+
|
239 |
+
# Pad with error messages if fewer than n_best results were generated
|
240 |
+
while len(translations) < n_best:
|
241 |
+
translations.append("[Decoding Error - Fewer Results]")
|
242 |
+
|
243 |
+
return translations
|
244 |
+
|
245 |
+
|
246 |
+
# --- Model/Tokenizer Loading Function (Unchanged) ---
|
247 |
+
def load_model_and_tokenizers():
|
248 |
+
"""Loads tokenizers, config, and model from Hugging Face Hub."""
|
249 |
+
global model, smiles_tokenizer, iupac_tokenizer, device, config
|
250 |
+
if model is not None: # Already loaded
|
251 |
+
logging.info("Model and tokenizers already loaded.")
|
252 |
+
return
|
253 |
+
|
254 |
+
logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
|
255 |
+
try:
|
256 |
+
device = torch.device("cpu")
|
257 |
+
logging.info(f"Using device: {device}")
|
258 |
+
|
259 |
+
# Download files from HF Hub
|
260 |
+
logging.info("Downloading files from Hugging Face Hub...")
|
261 |
+
try:
|
262 |
+
checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
|
263 |
+
smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME)
|
264 |
+
iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME)
|
265 |
+
config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME)
|
266 |
+
logging.info("Files downloaded successfully.")
|
267 |
+
except Exception as e:
|
268 |
+
logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
|
269 |
+
raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")
|
270 |
+
|
271 |
+
# Load config
|
272 |
+
logging.info("Loading configuration...")
|
273 |
+
try:
|
274 |
+
with open(config_path, 'r') as f:
|
275 |
+
config = json.load(f)
|
276 |
+
logging.info("Configuration loaded.")
|
277 |
+
# --- Validate essential config keys ---
|
278 |
+
required_keys = [
|
279 |
+
'src_vocab_size', 'tgt_vocab_size', 'emb_size', 'nhead',
|
280 |
+
'ffn_hid_dim', 'num_encoder_layers', 'num_decoder_layers',
|
281 |
+
'dropout', 'max_len', 'bos_token_id', 'eos_token_id', 'pad_token_id'
|
282 |
+
]
|
283 |
+
missing_keys = [key for key in required_keys if key not in config]
|
284 |
+
if missing_keys:
|
285 |
+
raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}")
|
286 |
+
# --- End Validation ---
|
287 |
+
except FileNotFoundError:
|
288 |
+
logging.error(f"Config file not found locally after download attempt: {config_path}")
|
289 |
+
raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
|
290 |
+
except json.JSONDecodeError as e:
|
291 |
+
logging.error(f"Error decoding JSON from config file {config_path}: {e}")
|
292 |
+
raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
|
293 |
+
except ValueError as e:
|
294 |
+
logging.error(f"Config validation error: {e}")
|
295 |
+
raise gr.Error(f"Config Error: {e}")
|
296 |
+
|
297 |
+
|
298 |
+
# Load tokenizers
|
299 |
+
logging.info("Loading tokenizers...")
|
300 |
+
try:
|
301 |
+
smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
|
302 |
+
iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
|
303 |
+
logging.info("Tokenizers loaded.")
|
304 |
+
# --- Validate Tokenizer Special Tokens ---
|
305 |
+
# Add more robust checks if necessary
|
306 |
+
if smiles_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
|
307 |
+
smiles_tokenizer.token_to_id("<unk>") is None:
|
308 |
+
logging.warning("SMILES tokenizer special tokens might not match config or are missing.")
|
309 |
+
if iupac_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
|
310 |
+
iupac_tokenizer.token_to_id("<sos>") != config['bos_token_id'] or \
|
311 |
+
iupac_tokenizer.token_to_id("<eos>") != config['eos_token_id'] or \
|
312 |
+
iupac_tokenizer.token_to_id("<unk>") is None:
|
313 |
+
logging.warning("IUPAC tokenizer special tokens might not match config or are missing.")
|
314 |
+
# --- End Validation ---
|
315 |
+
except Exception as e:
|
316 |
+
logging.error(f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}", exc_info=True)
|
317 |
+
raise gr.Error(f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}")
|
318 |
+
|
319 |
+
# Load model
|
320 |
+
logging.info("Loading model from checkpoint...")
|
321 |
+
try:
|
322 |
+
model = SmilesIupacLitModule.load_from_checkpoint(
|
323 |
+
checkpoint_path,
|
324 |
+
src_vocab_size=config['src_vocab_size'],
|
325 |
+
tgt_vocab_size=config['tgt_vocab_size'],
|
326 |
+
map_location=device,
|
327 |
+
hparams_dict=config,
|
328 |
+
strict=False,
|
329 |
+
device="cpu"
|
330 |
+
)
|
331 |
+
model.to(device)
|
332 |
+
model.eval()
|
333 |
+
model.freeze()
|
334 |
+
logging.info("Model loaded successfully, set to eval mode, frozen, and moved to device.")
|
335 |
+
|
336 |
+
except FileNotFoundError:
|
337 |
+
logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
|
338 |
+
raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
|
339 |
+
except Exception as e:
|
340 |
+
logging.error(f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True)
|
341 |
+
if "memory" in str(e).lower():
|
342 |
+
gc.collect()
|
343 |
+
if device == torch.device("cuda"):
|
344 |
+
torch.cuda.empty_cache()
|
345 |
+
raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")
|
346 |
+
|
347 |
+
except gr.Error:
|
348 |
+
raise
|
349 |
+
except Exception as e:
|
350 |
+
logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
|
351 |
+
raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")
|
352 |
+
|
353 |
+
|
354 |
+
# --- Inference Function for Gradio (Unchanged, calls local translate) ---
|
355 |
+
def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
|
356 |
+
"""
|
357 |
+
Performs SMILES to IUPAC translation using the loaded model and beam search.
|
358 |
+
"""
|
359 |
+
global model, smiles_tokenizer, iupac_tokenizer, device, config
|
360 |
+
|
361 |
+
if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
|
362 |
+
error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
|
363 |
+
# Ensure n_best is int for range, default to 1 if conversion fails early
|
364 |
+
try: n_best_int = int(n_best)
|
365 |
+
except: n_best_int = 1
|
366 |
+
return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])
|
367 |
+
|
368 |
+
if not smiles_string or not smiles_string.strip():
|
369 |
+
error_msg = "Error: Please enter a valid SMILES string."
|
370 |
+
try: n_best_int = int(n_best)
|
371 |
+
except: n_best_int = 1
|
372 |
+
return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])
|
373 |
+
|
374 |
+
smiles_input = smiles_string.strip()
|
375 |
+
try:
|
376 |
+
beam_width = int(beam_width)
|
377 |
+
n_best = int(n_best)
|
378 |
+
length_penalty = float(length_penalty)
|
379 |
+
except ValueError as e:
|
380 |
+
error_msg = f"Error: Invalid input parameter type ({e})."
|
381 |
+
return f"1. {error_msg}" # Cannot determine n_best here
|
382 |
+
|
383 |
+
logging.info(f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})")
|
384 |
+
|
385 |
+
try:
|
386 |
+
# Calls the translate function defined *above in this file*
|
387 |
+
predicted_names = translate(
|
388 |
+
model=model,
|
389 |
+
src_sentence=smiles_input,
|
390 |
+
smiles_tokenizer=smiles_tokenizer,
|
391 |
+
iupac_tokenizer=iupac_tokenizer,
|
392 |
+
device=device,
|
393 |
+
max_len=config['max_len'],
|
394 |
+
sos_idx=config['bos_token_id'],
|
395 |
+
eos_idx=config['eos_token_id'],
|
396 |
+
pad_idx=config['pad_token_id'],
|
397 |
+
beam_width=beam_width,
|
398 |
+
n_best=n_best,
|
399 |
+
length_penalty=length_penalty
|
400 |
+
)
|
401 |
+
logging.info(f"Predictions returned: {predicted_names}")
|
402 |
+
|
403 |
+
if not predicted_names:
|
404 |
+
output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
|
405 |
+
else:
|
406 |
+
output_text = f"Input SMILES: {smiles_input}\n\nTop {len(predicted_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
|
407 |
+
output_text += "\n".join([f"{i+1}. {name}" for i, name in enumerate(predicted_names)])
|
408 |
+
|
409 |
+
return output_text
|
410 |
+
|
411 |
+
except RuntimeError as e:
|
412 |
+
logging.error(f"Runtime error during translation: {e}", exc_info=True)
|
413 |
+
error_msg = f"Runtime Error during translation: {e}"
|
414 |
+
if "memory" in str(e).lower():
|
415 |
+
gc.collect()
|
416 |
+
if device == torch.device("cuda"):
|
417 |
+
torch.cuda.empty_cache()
|
418 |
+
error_msg += " (Potential OOM)"
|
419 |
+
return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
|
420 |
+
|
421 |
+
except Exception as e:
|
422 |
+
logging.error(f"Unexpected error during translation: {e}", exc_info=True)
|
423 |
+
error_msg = f"Unexpected Error during translation: {e}"
|
424 |
+
return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
|
425 |
+
|
426 |
+
|
427 |
+
# --- Load Model on App Start (Unchanged) ---
|
428 |
+
try:
|
429 |
+
load_model_and_tokenizers()
|
430 |
+
except gr.Error:
|
431 |
+
pass # Error already raised for Gradio UI
|
432 |
+
except Exception as e:
|
433 |
+
logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
|
434 |
+
gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
|
435 |
+
|
436 |
+
|
437 |
+
# --- Create Gradio Interface (Unchanged) ---
|
438 |
+
title = "SMILES to IUPAC Name Translator"
|
439 |
+
description = f"""
|
440 |
+
Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
|
441 |
+
Model repository: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>{MODEL_REPO_ID}</a>.
|
442 |
+
Adjust beam search parameters below. Higher beam width explores more possibilities but is slower. Length penalty influences the preference for shorter/longer names.
|
443 |
+
"""
|
444 |
+
|
445 |
+
examples = [
|
446 |
+
["CCO", 5, 3, 0.6],
|
447 |
+
["C1=CC=CC=C1", 5, 3, 0.6],
|
448 |
+
["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
|
449 |
+
["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
|
450 |
+
["CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C", 5, 1, 0.6], # Complex example
|
451 |
+
["INVALID_SMILES", 5, 1, 0.6],
|
452 |
+
]
|
453 |
+
|
454 |
+
smiles_input = gr.Textbox(
|
455 |
+
label="SMILES String",
|
456 |
+
placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
|
457 |
+
lines=1
|
458 |
+
)
|
459 |
+
beam_width_input = gr.Slider(
|
460 |
+
minimum=1,
|
461 |
+
maximum=10,
|
462 |
+
value=5,
|
463 |
+
step=1,
|
464 |
+
label="Beam Width (k)",
|
465 |
+
info="Number of sequences to keep at each decoding step (higher = more exploration, slower)."
|
466 |
+
)
|
467 |
+
n_best_input = gr.Slider(
|
468 |
+
minimum=1,
|
469 |
+
maximum=10,
|
470 |
+
value=3,
|
471 |
+
step=1,
|
472 |
+
label="Number of Results (n_best)",
|
473 |
+
info="How many top-scoring sequences to return (must be <= Beam Width)."
|
474 |
+
)
|
475 |
+
length_penalty_input = gr.Slider(
|
476 |
+
minimum=0.0,
|
477 |
+
maximum=2.0,
|
478 |
+
value=0.6,
|
479 |
+
step=0.1,
|
480 |
+
label="Length Penalty (alpha)",
|
481 |
+
info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty."
|
482 |
+
)
|
483 |
+
output_text = gr.Textbox(
|
484 |
+
label="Predicted IUPAC Name(s)",
|
485 |
+
lines=5,
|
486 |
+
show_copy_button=True
|
487 |
+
)
|
488 |
+
|
489 |
+
iface = gr.Interface(
|
490 |
+
fn=predict_iupac,
|
491 |
+
inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
|
492 |
+
outputs=output_text,
|
493 |
+
title=title,
|
494 |
+
description=description,
|
495 |
+
examples=examples,
|
496 |
+
allow_flagging="never",
|
497 |
+
theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
|
498 |
+
article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results."
|
499 |
+
)
|
500 |
+
|
501 |
+
# --- Launch the App (Unchanged) ---
|
502 |
+
if __name__ == "__main__":
|
503 |
+
iface.launch(share=True)
|
enhanced_trainer.py
ADDED
@@ -0,0 +1,1421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Transformer
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from torch.nn.utils.rnn import pad_sequence
|
7 |
+
import pytorch_lightning as pl # Import PyTorch Lightning
|
8 |
+
from pytorch_lightning.loggers import WandbLogger # Import WandbLogger
|
9 |
+
from pytorch_lightning.callbacks import (
|
10 |
+
ModelCheckpoint,
|
11 |
+
EarlyStopping,
|
12 |
+
) # Import Callbacks
|
13 |
+
import math
|
14 |
+
import os
|
15 |
+
import pandas as pd
|
16 |
+
from sklearn.model_selection import train_test_split
|
17 |
+
import time
|
18 |
+
import wandb # Import wandb
|
19 |
+
|
20 |
+
|
21 |
+
from tokenizers import (
|
22 |
+
Tokenizer,
|
23 |
+
models,
|
24 |
+
pre_tokenizers,
|
25 |
+
decoders,
|
26 |
+
trainers,
|
27 |
+
)
|
28 |
+
|
29 |
+
import logging
|
30 |
+
import gc
|
31 |
+
|
32 |
+
# --- Basic Logging Setup ---
|
33 |
+
logging.basicConfig(
|
34 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
35 |
+
)
|
36 |
+
|
37 |
+
# --- 1. Configuration & Hyperparameters ---
|
38 |
+
|
39 |
+
# Model Hyperparameters (Scaled up for H100s - ADJUST AS NEEDED based on memory)
|
40 |
+
# Note: BPE might benefit from a slightly larger vocab size than the regex approach
|
41 |
+
SRC_VOCAB_SIZE_ESTIMATE = 10000 # Increased estimate for SMILES BPE
|
42 |
+
TGT_VOCAB_SIZE_ESTIMATE = 14938 # Increased estimate for IUPAC
|
43 |
+
EMB_SIZE = 2048 # Embedding dimension (d_model) - Increased significantly
|
44 |
+
NHEAD = 8 # Number of attention heads (must divide EMB_SIZE) - Increased
|
45 |
+
FFN_HID_DIM = (
|
46 |
+
4096 # Feedforward network hidden dimension (e.g., 4 * EMB_SIZE) - Increased
|
47 |
+
)
|
48 |
+
NUM_ENCODER_LAYERS = 12 # Number of layers in Encoder - Increased
|
49 |
+
NUM_DECODER_LAYERS = 12 # Number of layers in Decoder - Increased
|
50 |
+
DROPOUT = 0.1 # Dropout rate (can sometimes be reduced slightly for larger models)
|
51 |
+
MAX_LEN = 384 # Maximum sequence length (consider increasing if needed/possible)
|
52 |
+
|
53 |
+
# Training Hyperparameters
|
54 |
+
ACCELERATOR = "gpu"
|
55 |
+
DEVICES = 6 # Number of H100 GPUs to use
|
56 |
+
STRATEGY = "ddp" # Distributed Data Parallel Strategy
|
57 |
+
PRECISION = "16-mixed" # Use mixed precision for speed and memory saving on H100s
|
58 |
+
BATCH_SIZE_PER_GPU = 48 # Adjust based on H100 GPU memory (e.g., 32, 48, 64) - Effective BS = BATCH_SIZE_PER_GPU * DEVICES
|
59 |
+
ACCUMULATE_GRAD_BATCHES = (
|
60 |
+
1 # Increase if BATCH_SIZE_PER_GPU needs to be smaller due to memory
|
61 |
+
)
|
62 |
+
NUM_EPOCHS = 50 # Increase for potentially longer training needed for larger models
|
63 |
+
LEARNING_RATE = 5e-5 # Might need adjustment for larger models/batch sizes
|
64 |
+
WEIGHT_DECAY = 1e-2
|
65 |
+
GRAD_CLIP_NORM = 1.0
|
66 |
+
VALIDATION_SPLIT = 0.05 # Use a smaller validation split if the dataset is huge
|
67 |
+
RANDOM_SEED = 42
|
68 |
+
PATIENCE = 5 # Early stopping patience
|
69 |
+
NUM_WORKERS = 8 # Adjust based on CPU cores and system capabilities
|
70 |
+
|
71 |
+
# Special Token Indices
|
72 |
+
PAD_IDX = 0
|
73 |
+
SOS_IDX = 1
|
74 |
+
EOS_IDX = 2
|
75 |
+
UNK_IDX = 3
|
76 |
+
|
77 |
+
# File Paths
|
78 |
+
# *** CHANGED SMILES TOKENIZER FILENAME ***
|
79 |
+
SMILES_TOKENIZER_FILE = "smiles_bytelevel_bpe_tokenizer_scaled.json"
|
80 |
+
IUPAC_TOKENIZER_FILE = "iupac_unigram_tokenizer_scaled.json"
|
81 |
+
INPUT_CSV_FILE = "data_clean.csv" # <--- Your input CSV file path
|
82 |
+
|
83 |
+
# Output files for data splits
|
84 |
+
TRAIN_SMILES_FILE = "train.smi"
|
85 |
+
TRAIN_IUPAC_FILE = "train.iupac"
|
86 |
+
VAL_SMILES_FILE = "val.smi"
|
87 |
+
VAL_IUPAC_FILE = "val.iupac"
|
88 |
+
CHECKPOINT_DIR = "checkpoints" # Directory to save model checkpoints
|
89 |
+
BEST_MODEL_FILENAME = (
|
90 |
+
"smiles-to-iupac-transformer-best" # Filename format for checkpoints
|
91 |
+
)
|
92 |
+
|
93 |
+
# WandB Configuration
|
94 |
+
WANDB_PROJECT = "SMILES-to-IUPAC-Large-BPE" # Updated project name slightly
|
95 |
+
WANDB_ENTITY = (
|
96 |
+
"adrianmirza" # Replace with your WandB entity (username or team name) if desired
|
97 |
+
)
|
98 |
+
WANDB_RUN_NAME = f"transformer_BPE_E{EMB_SIZE}_H{NHEAD}_L{NUM_ENCODER_LAYERS}_BS{BATCH_SIZE_PER_GPU * DEVICES}_LR{LEARNING_RATE}"
|
99 |
+
|
100 |
+
# Store hparams for logging
|
101 |
+
hparams = {
|
102 |
+
"src_tokenizer_type": "ByteLevelBPE", # Added tokenizer type info
|
103 |
+
"tgt_tokenizer_type": "Unigram",
|
104 |
+
"src_vocab_size_estimate": SRC_VOCAB_SIZE_ESTIMATE,
|
105 |
+
"tgt_vocab_size_estimate": TGT_VOCAB_SIZE_ESTIMATE,
|
106 |
+
"emb_size": EMB_SIZE,
|
107 |
+
"nhead": NHEAD,
|
108 |
+
"ffn_hid_dim": FFN_HID_DIM,
|
109 |
+
"num_encoder_layers": NUM_ENCODER_LAYERS,
|
110 |
+
"num_decoder_layers": NUM_DECODER_LAYERS,
|
111 |
+
"dropout": DROPOUT,
|
112 |
+
"max_len": MAX_LEN,
|
113 |
+
"batch_size_per_gpu": BATCH_SIZE_PER_GPU,
|
114 |
+
"effective_batch_size": BATCH_SIZE_PER_GPU * DEVICES * ACCUMULATE_GRAD_BATCHES,
|
115 |
+
"num_epochs": NUM_EPOCHS,
|
116 |
+
"learning_rate": LEARNING_RATE,
|
117 |
+
"weight_decay": WEIGHT_DECAY,
|
118 |
+
"grad_clip_norm": GRAD_CLIP_NORM,
|
119 |
+
"validation_split": VALIDATION_SPLIT,
|
120 |
+
"random_seed": RANDOM_SEED,
|
121 |
+
"patience": PATIENCE,
|
122 |
+
"precision": PRECISION,
|
123 |
+
"gpus": DEVICES,
|
124 |
+
"strategy": STRATEGY,
|
125 |
+
"num_workers": NUM_WORKERS,
|
126 |
+
}
|
127 |
+
|
128 |
+
# --- 2. Token izers (Modified SMILES Tokenizer) ---
|
129 |
+
|
130 |
+
|
131 |
+
# --- 2.a SMILES ByteLevel BPE Tokenizer (Replaced WordLevel Regex) ---
|
132 |
+
def get_smiles_tokenizer(
|
133 |
+
train_files=None,
|
134 |
+
vocab_size=30000,
|
135 |
+
min_frequency=2,
|
136 |
+
tokenizer_path=SMILES_TOKENIZER_FILE,
|
137 |
+
):
|
138 |
+
"""Creates or loads a Byte-Level BPE tokenizer for SMILES."""
|
139 |
+
if os.path.exists(tokenizer_path):
|
140 |
+
logging.info(f"Loading existing SMILES tokenizer from {tokenizer_path}")
|
141 |
+
try:
|
142 |
+
tokenizer = Tokenizer.from_file(tokenizer_path)
|
143 |
+
# Verify special tokens after loading
|
144 |
+
if (
|
145 |
+
tokenizer.token_to_id("<pad>") != PAD_IDX
|
146 |
+
or tokenizer.token_to_id("<sos>") != SOS_IDX
|
147 |
+
or tokenizer.token_to_id("<eos>") != EOS_IDX
|
148 |
+
or tokenizer.token_to_id("<unk>") != UNK_IDX
|
149 |
+
):
|
150 |
+
logging.warning(
|
151 |
+
"Special token ID mismatch after loading SMILES tokenizer. Re-check config."
|
152 |
+
)
|
153 |
+
# Check if it's actually a BPE model (basic check)
|
154 |
+
if not isinstance(tokenizer.model, models.BPE):
|
155 |
+
logging.warning(
|
156 |
+
f"Loaded tokenizer from {tokenizer_path} is not a BPE model. Retraining."
|
157 |
+
)
|
158 |
+
raise TypeError("Incorrect tokenizer model type loaded.")
|
159 |
+
return tokenizer
|
160 |
+
except Exception as e:
|
161 |
+
logging.error(f"Failed to load SMILES tokenizer: {e}. Retraining...")
|
162 |
+
|
163 |
+
logging.info("Creating and training SMILES Byte-Level BPE tokenizer...")
|
164 |
+
# Use BPE model
|
165 |
+
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
|
166 |
+
|
167 |
+
# Use ByteLevel pre-tokenizer - this handles any character sequence
|
168 |
+
# add_prefix_space=False is generally suitable for SMILES as it doesn't rely on spaces
|
169 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
170 |
+
# Use ByteLevel decoder
|
171 |
+
tokenizer.decoder = decoders.ByteLevel()
|
172 |
+
|
173 |
+
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
|
174 |
+
# Use BpeTrainer
|
175 |
+
trainer = trainers.BpeTrainer(
|
176 |
+
vocab_size=vocab_size,
|
177 |
+
min_frequency=min_frequency,
|
178 |
+
special_tokens=special_tokens,
|
179 |
+
# BPE specific options can be added here if needed, e.g.:
|
180 |
+
# initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), # Usually inferred
|
181 |
+
# show_progress=True,
|
182 |
+
)
|
183 |
+
|
184 |
+
if train_files and all(os.path.exists(f) for f in train_files):
|
185 |
+
logging.info(f"Training SMILES BPE tokenizer on: {train_files}")
|
186 |
+
tokenizer.train(files=train_files, trainer=trainer)
|
187 |
+
logging.info(
|
188 |
+
f"SMILES BPE tokenizer trained. Final Vocab size: {tokenizer.get_vocab_size()}"
|
189 |
+
)
|
190 |
+
# Verify special token IDs after training
|
191 |
+
if (
|
192 |
+
tokenizer.token_to_id("<pad>") != PAD_IDX
|
193 |
+
or tokenizer.token_to_id("<sos>") != SOS_IDX
|
194 |
+
or tokenizer.token_to_id("<eos>") != EOS_IDX
|
195 |
+
or tokenizer.token_to_id("<unk>") != UNK_IDX
|
196 |
+
):
|
197 |
+
logging.warning(
|
198 |
+
"Special token ID mismatch after training SMILES BPE tokenizer. Check trainer setup."
|
199 |
+
)
|
200 |
+
try:
|
201 |
+
tokenizer.save(tokenizer_path)
|
202 |
+
logging.info(f"SMILES BPE tokenizer saved to {tokenizer_path}")
|
203 |
+
except Exception as e:
|
204 |
+
logging.error(f"Failed to save SMILES BPE tokenizer: {e}")
|
205 |
+
else:
|
206 |
+
logging.error(
|
207 |
+
"Training files not provided or not found for SMILES tokenizer. Cannot train."
|
208 |
+
)
|
209 |
+
# Manually add special tokens if training fails, so basic encoding/decoding might work
|
210 |
+
tokenizer.add_special_tokens(special_tokens)
|
211 |
+
|
212 |
+
return tokenizer
|
213 |
+
|
214 |
+
|
215 |
+
# --- 2.b IUPAC Unigram Tokenizer (No changes needed here) ---
|
216 |
+
def get_iupac_tokenizer(
|
217 |
+
train_files=None,
|
218 |
+
vocab_size=30000,
|
219 |
+
min_frequency=2,
|
220 |
+
tokenizer_path=IUPAC_TOKENIZER_FILE,
|
221 |
+
):
|
222 |
+
"""Creates or loads a Unigram tokenizer for IUPAC names."""
|
223 |
+
if os.path.exists(tokenizer_path):
|
224 |
+
logging.info(f"Loading existing IUPAC tokenizer from {tokenizer_path}")
|
225 |
+
try:
|
226 |
+
tokenizer = Tokenizer.from_file(tokenizer_path)
|
227 |
+
if (
|
228 |
+
tokenizer.token_to_id("<pad>") != PAD_IDX
|
229 |
+
or tokenizer.token_to_id("<sos>") != SOS_IDX
|
230 |
+
or tokenizer.token_to_id("<eos>") != EOS_IDX
|
231 |
+
or tokenizer.token_to_id("<unk>") != UNK_IDX
|
232 |
+
):
|
233 |
+
logging.warning(
|
234 |
+
"Special token ID mismatch after loading IUPAC tokenizer. Re-check config."
|
235 |
+
)
|
236 |
+
return tokenizer
|
237 |
+
except Exception as e:
|
238 |
+
logging.error(f"Failed to load IUPAC tokenizer: {e}. Retraining...")
|
239 |
+
|
240 |
+
logging.info("Creating and training IUPAC Unigram tokenizer...")
|
241 |
+
tokenizer = Tokenizer(models.Unigram())
|
242 |
+
# Using Sequence of pre-tokenizers for IUPAC is reasonable
|
243 |
+
pre_tokenizer_list = [
|
244 |
+
pre_tokenizers.WhitespaceSplit(), # Split by whitespace first
|
245 |
+
pre_tokenizers.Punctuation(), # Split punctuation
|
246 |
+
pre_tokenizers.Digits(individual_digits=True), # Split digits
|
247 |
+
]
|
248 |
+
# Consider adding Metaspace if Unigram struggles with word boundaries after splits
|
249 |
+
# tokenizer.pre_tokenizer = pre_tokenizers.Metaspace() # Alternative
|
250 |
+
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(pre_tokenizer_list)
|
251 |
+
tokenizer.decoder = (
|
252 |
+
decoders.Metaspace()
|
253 |
+
) # Metaspace decoder often works well with Unigram/BPE
|
254 |
+
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
|
255 |
+
trainer = trainers.UnigramTrainer(
|
256 |
+
vocab_size=vocab_size,
|
257 |
+
special_tokens=special_tokens,
|
258 |
+
unk_token="<unk>",
|
259 |
+
# Unigram specific options can be added here
|
260 |
+
# shrinking_factor=0.75,
|
261 |
+
# n_sub_iterations=2,
|
262 |
+
)
|
263 |
+
|
264 |
+
if train_files and all(os.path.exists(f) for f in train_files):
|
265 |
+
logging.info(f"Training IUPAC tokenizer on: {train_files}")
|
266 |
+
tokenizer.train(files=train_files, trainer=trainer)
|
267 |
+
logging.info(
|
268 |
+
f"IUPAC tokenizer trained. Final Vocab size: {tokenizer.get_vocab_size()}"
|
269 |
+
)
|
270 |
+
# Verify special token IDs after training
|
271 |
+
if (
|
272 |
+
tokenizer.token_to_id("<pad>") != PAD_IDX
|
273 |
+
or tokenizer.token_to_id("<sos>") != SOS_IDX
|
274 |
+
or tokenizer.token_to_id("<eos>") != EOS_IDX
|
275 |
+
or tokenizer.token_to_id("<unk>") != UNK_IDX
|
276 |
+
):
|
277 |
+
logging.warning(
|
278 |
+
"Special token ID mismatch after training IUPAC tokenizer. Check trainer setup."
|
279 |
+
)
|
280 |
+
try:
|
281 |
+
tokenizer.save(tokenizer_path)
|
282 |
+
logging.info(f"IUPAC tokenizer saved to {tokenizer_path}")
|
283 |
+
except Exception as e:
|
284 |
+
logging.error(f"Failed to save IUPAC tokenizer: {e}")
|
285 |
+
else:
|
286 |
+
logging.error(
|
287 |
+
"Training files not provided or not found for IUPAC tokenizer. Cannot train."
|
288 |
+
)
|
289 |
+
tokenizer.add_special_tokens(special_tokens)
|
290 |
+
|
291 |
+
return tokenizer
|
292 |
+
|
293 |
+
|
294 |
+
# --- 3. Model Definition (No changes needed) ---
|
295 |
+
class PositionalEncoding(nn.Module):
|
296 |
+
"""Injects positional information into the input embeddings."""
|
297 |
+
|
298 |
+
def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
|
299 |
+
super().__init__()
|
300 |
+
den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
|
301 |
+
pos = torch.arange(0, maxlen).reshape(maxlen, 1)
|
302 |
+
pos_embedding = torch.zeros((maxlen, emb_size))
|
303 |
+
pos_embedding[:, 0::2] = torch.sin(pos * den)
|
304 |
+
pos_embedding[:, 1::2] = torch.cos(pos * den)
|
305 |
+
pos_embedding = pos_embedding.unsqueeze(
|
306 |
+
0
|
307 |
+
) # Add batch dimension for broadcasting
|
308 |
+
self.dropout = nn.Dropout(dropout)
|
309 |
+
self.register_buffer(
|
310 |
+
"pos_embedding", pos_embedding
|
311 |
+
) # Shape [1, maxlen, emb_size]
|
312 |
+
|
313 |
+
def forward(self, token_embedding: torch.Tensor):
|
314 |
+
# token_embedding: Expected shape [batch_size, seq_len, emb_size]
|
315 |
+
seq_len = token_embedding.size(1)
|
316 |
+
# Slicing pos_embedding: [1, seq_len, emb_size]
|
317 |
+
# Handle cases where seq_len might exceed buffer's maxlen during inference/edge cases
|
318 |
+
if seq_len > self.pos_embedding.size(1):
|
319 |
+
logging.warning(
|
320 |
+
f"Input sequence length ({seq_len}) exceeds PositionalEncoding maxlen ({self.pos_embedding.size(1)}). Truncating positional encoding."
|
321 |
+
)
|
322 |
+
pos_to_add = self.pos_embedding[:, : self.pos_embedding.size(1), :]
|
323 |
+
# Pad token_embedding if needed? Or error out? For now, just use available encoding.
|
324 |
+
# This scenario shouldn't happen if MAX_LEN config is respected.
|
325 |
+
output = token_embedding[:, : self.pos_embedding.size(1), :] + pos_to_add
|
326 |
+
else:
|
327 |
+
pos_to_add = self.pos_embedding[:, :seq_len, :]
|
328 |
+
output = token_embedding + pos_to_add
|
329 |
+
|
330 |
+
return self.dropout(output)
|
331 |
+
|
332 |
+
|
333 |
+
class TokenEmbedding(nn.Module):
|
334 |
+
"""Converts token indices to embeddings."""
|
335 |
+
|
336 |
+
def __init__(self, vocab_size: int, emb_size):
|
337 |
+
super().__init__()
|
338 |
+
self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=PAD_IDX)
|
339 |
+
self.emb_size = emb_size
|
340 |
+
|
341 |
+
def forward(self, tokens: torch.Tensor):
|
342 |
+
return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
|
343 |
+
|
344 |
+
|
345 |
+
class Seq2SeqTransformer(nn.Module):
|
346 |
+
"""The main Encoder-Decoder Transformer model."""
|
347 |
+
|
348 |
+
def __init__(
|
349 |
+
self,
|
350 |
+
num_encoder_layers: int,
|
351 |
+
num_decoder_layers: int,
|
352 |
+
emb_size: int,
|
353 |
+
nhead: int,
|
354 |
+
src_vocab_size: int,
|
355 |
+
tgt_vocab_size: int,
|
356 |
+
dim_feedforward: int,
|
357 |
+
dropout: float = 0.1,
|
358 |
+
max_len: int = MAX_LEN,
|
359 |
+
): # Use MAX_LEN from config
|
360 |
+
super().__init__()
|
361 |
+
|
362 |
+
if emb_size % nhead != 0:
|
363 |
+
raise ValueError(
|
364 |
+
f"Embedding size ({emb_size}) must be divisible by the number of heads ({nhead})"
|
365 |
+
)
|
366 |
+
|
367 |
+
self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
|
368 |
+
self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
|
369 |
+
|
370 |
+
# Ensure PositionalEncoding maxlen is sufficient
|
371 |
+
pe_maxlen = max(
|
372 |
+
max_len, 5000
|
373 |
+
) # Use config MAX_LEN or default 5000, whichever is larger
|
374 |
+
self.positional_encoding = PositionalEncoding(
|
375 |
+
emb_size, dropout=dropout, maxlen=pe_maxlen
|
376 |
+
)
|
377 |
+
|
378 |
+
self.transformer = Transformer(
|
379 |
+
d_model=emb_size,
|
380 |
+
nhead=nhead,
|
381 |
+
num_encoder_layers=num_encoder_layers,
|
382 |
+
num_decoder_layers=num_decoder_layers,
|
383 |
+
dim_feedforward=dim_feedforward,
|
384 |
+
dropout=dropout,
|
385 |
+
batch_first=True,
|
386 |
+
) # Use batch_first=True
|
387 |
+
|
388 |
+
self.generator = nn.Linear(emb_size, tgt_vocab_size)
|
389 |
+
self._init_weights()
|
390 |
+
|
391 |
+
def _init_weights(self):
|
392 |
+
for p in self.parameters():
|
393 |
+
if p.dim() > 1:
|
394 |
+
nn.init.xavier_uniform_(p)
|
395 |
+
|
396 |
+
def forward(
|
397 |
+
self,
|
398 |
+
src: torch.Tensor, # Input sequence (batch_size, src_len)
|
399 |
+
trg: torch.Tensor, # Target sequence (batch_size, tgt_len)
|
400 |
+
tgt_mask: torch.Tensor, # Target causal mask (tgt_len, tgt_len)
|
401 |
+
src_padding_mask: torch.Tensor, # Source padding mask (batch_size, src_len)
|
402 |
+
tgt_padding_mask: torch.Tensor, # Target padding mask (batch_size, tgt_len)
|
403 |
+
memory_key_padding_mask: torch.Tensor,
|
404 |
+
): # Memory padding mask (batch_size, src_len)
|
405 |
+
# --- Ensure masks have correct dtype and device ---
|
406 |
+
# Pytorch Transformer expects boolean masks where True indicates masking
|
407 |
+
src_padding_mask = src_padding_mask.to(src.device)
|
408 |
+
tgt_padding_mask = tgt_padding_mask.to(trg.device)
|
409 |
+
memory_key_padding_mask = memory_key_padding_mask.to(src.device)
|
410 |
+
# tgt_mask needs to be float for '-inf' filling, keep on target device
|
411 |
+
tgt_mask = tgt_mask.to(trg.device)
|
412 |
+
|
413 |
+
src_emb = self.positional_encoding(
|
414 |
+
self.src_tok_emb(src)
|
415 |
+
) # [batch, src_len, dim]
|
416 |
+
tgt_emb = self.positional_encoding(
|
417 |
+
self.tgt_tok_emb(trg)
|
418 |
+
) # [batch, tgt_len, dim]
|
419 |
+
|
420 |
+
outs = self.transformer(
|
421 |
+
src=src_emb,
|
422 |
+
tgt=tgt_emb,
|
423 |
+
src_mask=None, # Not typically needed for encoder unless custom masking
|
424 |
+
tgt_mask=tgt_mask, # Causal mask for decoder self-attn
|
425 |
+
memory_mask=None, # Not typically needed unless masking specific memory parts
|
426 |
+
src_key_padding_mask=src_padding_mask, # Mask padding in src K,V
|
427 |
+
tgt_key_padding_mask=tgt_padding_mask, # Mask padding in tgt Q
|
428 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
429 |
+
) # Mask padding in memory K,V for cross-attn
|
430 |
+
# outs: [batch_size, tgt_len, emb_size]
|
431 |
+
return self.generator(outs) # [batch_size, tgt_len, tgt_vocab_size]
|
432 |
+
|
433 |
+
def encode(self, src: torch.Tensor, src_padding_mask: torch.Tensor):
|
434 |
+
src_padding_mask = src_padding_mask.to(
|
435 |
+
src.device
|
436 |
+
) # Ensure mask is on correct device
|
437 |
+
src_emb = self.positional_encoding(
|
438 |
+
self.src_tok_emb(src)
|
439 |
+
) # [batch, src_len, dim]
|
440 |
+
memory = self.transformer.encoder(
|
441 |
+
src_emb, mask=None, src_key_padding_mask=src_padding_mask
|
442 |
+
)
|
443 |
+
return memory # Returns memory: [batch_size, src_len, emb_size]
|
444 |
+
|
445 |
+
def decode(
|
446 |
+
self,
|
447 |
+
tgt: torch.Tensor,
|
448 |
+
memory: torch.Tensor,
|
449 |
+
tgt_mask: torch.Tensor,
|
450 |
+
tgt_padding_mask: torch.Tensor,
|
451 |
+
memory_key_padding_mask: torch.Tensor,
|
452 |
+
):
|
453 |
+
# Ensure masks are on correct device
|
454 |
+
tgt_mask = tgt_mask.to(tgt.device)
|
455 |
+
tgt_padding_mask = tgt_padding_mask.to(tgt.device)
|
456 |
+
memory_key_padding_mask = memory_key_padding_mask.to(memory.device)
|
457 |
+
|
458 |
+
tgt_emb = self.positional_encoding(
|
459 |
+
self.tgt_tok_emb(tgt)
|
460 |
+
) # [batch, tgt_len, dim]
|
461 |
+
output = self.transformer.decoder(
|
462 |
+
tgt=tgt_emb,
|
463 |
+
memory=memory,
|
464 |
+
tgt_mask=tgt_mask,
|
465 |
+
memory_mask=None,
|
466 |
+
tgt_key_padding_mask=tgt_padding_mask,
|
467 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
468 |
+
)
|
469 |
+
return output # Returns decoder output: [batch_size, tgt_len, emb_size]
|
470 |
+
|
471 |
+
|
472 |
+
# --- Helper function for mask creation (No changes needed) ---
|
473 |
+
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
|
474 |
+
"""Generates an upper-triangular matrix for causal masking."""
|
475 |
+
mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
|
476 |
+
mask = (
|
477 |
+
mask.float()
|
478 |
+
.masked_fill(mask == 0, float("-inf"))
|
479 |
+
.masked_fill(mask == 1, float(0.0))
|
480 |
+
)
|
481 |
+
return mask # Shape [sz, sz]
|
482 |
+
|
483 |
+
|
484 |
+
def create_masks(
|
485 |
+
src: torch.Tensor, tgt: torch.Tensor, pad_idx: int, device: torch.device
|
486 |
+
):
|
487 |
+
"""
|
488 |
+
Creates all necessary masks for the Transformer model.
|
489 |
+
Assumes src and tgt are inputs to the forward pass (tgt includes SOS, excludes EOS).
|
490 |
+
Returns boolean masks where True indicates the position should be masked (ignored).
|
491 |
+
"""
|
492 |
+
src_seq_len = src.shape[1]
|
493 |
+
tgt_seq_len = tgt.shape[1]
|
494 |
+
|
495 |
+
# Causal mask for decoder self-attention (float mask for PyTorch Transformer)
|
496 |
+
tgt_mask = generate_square_subsequent_mask(
|
497 |
+
tgt_seq_len, device
|
498 |
+
) # [tgt_len, tgt_len]
|
499 |
+
|
500 |
+
# Padding masks (boolean, True where padded)
|
501 |
+
src_padding_mask = src == pad_idx # [batch_size, src_len]
|
502 |
+
tgt_padding_mask = tgt == pad_idx # [batch_size, tgt_len]
|
503 |
+
memory_key_padding_mask = (
|
504 |
+
src_padding_mask # Used in decoder cross-attention [batch_size, src_len]
|
505 |
+
)
|
506 |
+
|
507 |
+
return tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask
|
508 |
+
|
509 |
+
|
510 |
+
# --- 4. Data Handling (Dataset and Collate Function - No changes needed) ---
|
511 |
+
class SmilesIupacDataset(Dataset):
|
512 |
+
"""Dataset class for SMILES-IUPAC pairs, reading from pre-split files."""
|
513 |
+
|
514 |
+
def __init__(self, smiles_file: str, iupac_file: str):
|
515 |
+
logging.info(f"Loading data from {smiles_file} and {iupac_file}")
|
516 |
+
try:
|
517 |
+
with open(smiles_file, "r", encoding="utf-8") as f_smi:
|
518 |
+
self.smiles = [line.strip() for line in f_smi if line.strip()]
|
519 |
+
with open(iupac_file, "r", encoding="utf-8") as f_iupac:
|
520 |
+
self.iupac = [line.strip() for line in f_iupac if line.strip()]
|
521 |
+
|
522 |
+
if len(self.smiles) != len(self.iupac):
|
523 |
+
logging.warning(
|
524 |
+
f"Mismatch in number of lines: {smiles_file} ({len(self.smiles)}) vs {iupac_file} ({len(self.iupac)}). Trimming."
|
525 |
+
)
|
526 |
+
min_len = min(len(self.smiles), len(self.iupac))
|
527 |
+
self.smiles = self.smiles[:min_len]
|
528 |
+
self.iupac = self.iupac[:min_len]
|
529 |
+
|
530 |
+
logging.info(
|
531 |
+
f"Loaded {len(self.smiles)} pairs from {smiles_file}/{iupac_file}."
|
532 |
+
)
|
533 |
+
if len(self.smiles) == 0:
|
534 |
+
logging.warning(f"Loaded 0 data pairs. Check files.")
|
535 |
+
|
536 |
+
except FileNotFoundError:
|
537 |
+
logging.error(
|
538 |
+
f"Error: One or both files not found: {smiles_file}, {iupac_file}"
|
539 |
+
)
|
540 |
+
raise
|
541 |
+
except Exception as e:
|
542 |
+
logging.error(f"Error loading data: {e}")
|
543 |
+
raise
|
544 |
+
|
545 |
+
def __len__(self):
|
546 |
+
return len(self.smiles)
|
547 |
+
|
548 |
+
def __getitem__(self, idx):
|
549 |
+
return self.smiles[idx], self.iupac[idx]
|
550 |
+
|
551 |
+
|
552 |
+
def collate_fn(
|
553 |
+
batch, smiles_tokenizer, iupac_tokenizer, pad_idx, sos_idx, eos_idx, max_len
|
554 |
+
):
|
555 |
+
"""Collates data samples into batches."""
|
556 |
+
src_batch, tgt_batch = [], []
|
557 |
+
skipped_count = 0
|
558 |
+
for src_sample, tgt_sample in batch:
|
559 |
+
try:
|
560 |
+
# Encode source (SMILES)
|
561 |
+
src_encoded = smiles_tokenizer.encode(src_sample)
|
562 |
+
# Truncate source if needed (including potential special tokens if added by encode)
|
563 |
+
src_ids = src_encoded.ids[:max_len]
|
564 |
+
if not src_ids: # Skip if encoding results in empty sequence
|
565 |
+
skipped_count += 1
|
566 |
+
continue
|
567 |
+
src_tensor = torch.tensor(src_ids, dtype=torch.long)
|
568 |
+
|
569 |
+
# Encode target (IUPAC)
|
570 |
+
tgt_encoded = iupac_tokenizer.encode(tgt_sample)
|
571 |
+
# Truncate target allowing space for SOS and EOS
|
572 |
+
tgt_ids = tgt_encoded.ids[: max_len - 2]
|
573 |
+
if (
|
574 |
+
not tgt_ids
|
575 |
+
): # Skip if encoding results in empty sequence (after truncation)
|
576 |
+
skipped_count += 1
|
577 |
+
continue
|
578 |
+
# Add SOS and EOS tokens
|
579 |
+
tgt_tensor = torch.tensor([sos_idx] + tgt_ids + [eos_idx], dtype=torch.long)
|
580 |
+
|
581 |
+
src_batch.append(src_tensor)
|
582 |
+
tgt_batch.append(tgt_tensor)
|
583 |
+
except Exception as e:
|
584 |
+
# Log infrequent warnings for skipping
|
585 |
+
# if skipped_count < 5: # Log only the first few skips per batch
|
586 |
+
# logging.warning(f"Skipping sample due to error during tokenization/tensor creation: {e}. SMILES: '{src_sample[:50]}...', IUPAC: '{tgt_sample[:50]}...'")
|
587 |
+
skipped_count += 1
|
588 |
+
continue
|
589 |
+
|
590 |
+
# if skipped_count > 0:
|
591 |
+
# logging.debug(f"Skipped {skipped_count} samples in this batch during collation.")
|
592 |
+
|
593 |
+
if not src_batch or not tgt_batch:
|
594 |
+
# Return empty tensors if the whole batch was skipped
|
595 |
+
return torch.tensor([]), torch.tensor([])
|
596 |
+
|
597 |
+
try:
|
598 |
+
# Pad sequences
|
599 |
+
src_batch_padded = pad_sequence(
|
600 |
+
src_batch, batch_first=True, padding_value=pad_idx
|
601 |
+
)
|
602 |
+
tgt_batch_padded = pad_sequence(
|
603 |
+
tgt_batch, batch_first=True, padding_value=pad_idx
|
604 |
+
)
|
605 |
+
except Exception as e:
|
606 |
+
logging.error(
|
607 |
+
f"Error during padding: {e}. Src lengths: {[len(s) for s in src_batch]}, Tgt lengths: {[len(t) for t in tgt_batch]}"
|
608 |
+
)
|
609 |
+
# Return empty tensors on padding error
|
610 |
+
return torch.tensor([]), torch.tensor([])
|
611 |
+
|
612 |
+
return src_batch_padded, tgt_batch_padded
|
613 |
+
|
614 |
+
|
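As a concrete illustration of what collate_fn produces (a sketch only; the token IDs are made up and depend entirely on the trained tokenizers, and pad_idx = 0, sos_idx = 1, eos_idx = 2 are assumed):

# Two samples whose SMILES encodings are [7, 8, 9] and [7, 5], and whose IUPAC
# encodings are [12, 13] and [14, 15, 16]:
# src_batch_padded -> tensor([[7, 8, 9],
#                             [7, 5, 0]])            # padded to the longest source
# tgt_batch_padded -> tensor([[1, 12, 13,  2, 0],
#                             [1, 14, 15, 16, 2]])   # SOS ... EOS, then padding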
615 |
+
# --- 5. PyTorch Lightning Module (No changes needed) ---
|
616 |
+
class SmilesIupacLitModule(pl.LightningModule):
|
617 |
+
def __init__(
|
618 |
+
self, src_vocab_size: int, tgt_vocab_size: int, hparams_dict: dict
|
619 |
+
): # Pass hparams dictionary
|
620 |
+
super().__init__()
|
621 |
+
# Use save_hyperparameters() to automatically save args to self.hparams
|
622 |
+
# and make them accessible in checkpoints and loggers
|
623 |
+
self.save_hyperparameters(hparams_dict)
|
624 |
+
|
625 |
+
self.model = Seq2SeqTransformer(
|
626 |
+
num_encoder_layers=self.hparams.num_encoder_layers,
|
627 |
+
num_decoder_layers=self.hparams.num_decoder_layers,
|
628 |
+
emb_size=self.hparams.emb_size,
|
629 |
+
nhead=self.hparams.nhead,
|
630 |
+
src_vocab_size=src_vocab_size, # Pass actual vocab size
|
631 |
+
tgt_vocab_size=tgt_vocab_size, # Pass actual vocab size
|
632 |
+
dim_feedforward=self.hparams.ffn_hid_dim,
|
633 |
+
dropout=self.hparams.dropout,
|
634 |
+
max_len=self.hparams.max_len, # Pass max_len here
|
635 |
+
)
|
636 |
+
|
637 |
+
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
|
638 |
+
|
639 |
+
# --- Count Parameters --- (Done once at initialization)
|
640 |
+
total_params = sum(p.numel() for p in self.model.parameters())
|
641 |
+
trainable_params = sum(
|
642 |
+
p.numel() for p in self.model.parameters() if p.requires_grad
|
643 |
+
)
|
644 |
+
logging.info(f"Model Initialized:")
|
645 |
+
logging.info(f" Total Parameters: {total_params / 1_000_000:.2f} M")
|
646 |
+
logging.info(f" Trainable Parameters: {trainable_params / 1_000_000:.2f} M")
|
647 |
+
# Log params to wandb hparams if logger is available
|
648 |
+
# self.hparams are automatically logged by WandbLogger if passed to Trainer
|
649 |
+
# We can add them explicitly if needed, but save_hyperparameters usually handles it.
|
650 |
+
self.hparams.total_params_M = round(total_params / 1_000_000, 2)
|
651 |
+
self.hparams.trainable_params_M = round(trainable_params / 1_000_000, 2)
|
652 |
+
|
653 |
+
def forward(self, src, tgt):
|
654 |
+
# This is the main forward pass used for inference/prediction if needed
|
655 |
+
# For training/validation, we call the model directly in step methods
|
656 |
+
# to handle mask creation explicitly.
|
657 |
+
tgt_input = tgt[:, :-1] # Prepare target input (remove EOS)
|
658 |
+
tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
|
659 |
+
create_masks(
|
660 |
+
src,
|
661 |
+
tgt_input,
|
662 |
+
PAD_IDX,
|
663 |
+
self.device, # Use self.device provided by Lightning
|
664 |
+
)
|
665 |
+
)
|
666 |
+
logits = self.model(
|
667 |
+
src,
|
668 |
+
tgt_input,
|
669 |
+
tgt_mask,
|
670 |
+
src_padding_mask,
|
671 |
+
tgt_padding_mask,
|
672 |
+
memory_key_padding_mask,
|
673 |
+
)
|
674 |
+
return logits
|
675 |
+
|
676 |
+
def training_step(self, batch, batch_idx):
|
677 |
+
src, tgt = batch
|
678 |
+
if src.numel() == 0 or tgt.numel() == 0:
|
679 |
+
# logging.debug(f"Skipping empty batch {batch_idx} in training.")
|
680 |
+
return None # Skip empty batches
|
681 |
+
|
682 |
+
tgt_input = tgt[:, :-1] # Exclude EOS for input
|
683 |
+
tgt_out = tgt[:, 1:] # Exclude SOS for target labels
|
684 |
+
|
685 |
+
# Create masks on the current device
|
686 |
+
tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
|
687 |
+
create_masks(src, tgt_input, PAD_IDX, self.device)
|
688 |
+
)
|
689 |
+
|
690 |
+
try:
|
691 |
+
logits = self.model(
|
692 |
+
src=src,
|
693 |
+
trg=tgt_input,
|
694 |
+
tgt_mask=tgt_mask,
|
695 |
+
src_padding_mask=src_padding_mask,
|
696 |
+
tgt_padding_mask=tgt_padding_mask,
|
697 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
698 |
+
)
|
699 |
+
# logits: [batch_size, tgt_len-1, tgt_vocab_size]
|
700 |
+
|
701 |
+
# Calculate loss
|
702 |
+
# Reshape logits to [batch_size * (tgt_len-1), tgt_vocab_size]
|
703 |
+
# Reshape tgt_out to [batch_size * (tgt_len-1)]
|
704 |
+
loss = self.criterion(
|
705 |
+
logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)
|
706 |
+
)
|
707 |
+
|
708 |
+
# Check for NaN/Inf loss (important with mixed precision)
|
709 |
+
if not torch.isfinite(loss):
|
710 |
+
logging.warning(
|
711 |
+
f"Non-finite loss encountered in training step {batch_idx}: {loss.item()}. Skipping update."
|
712 |
+
)
|
713 |
+
# Manually skip optimizer step if using manual optimization,
|
714 |
+
# otherwise returning None might be sufficient for automatic opt.
|
715 |
+
return None # Returning None should prevent optimizer step
|
716 |
+
|
717 |
+
# Log training loss
|
718 |
+
# sync_dist=True is important for DDP to average loss across GPUs
|
719 |
+
self.log(
|
720 |
+
"train_loss",
|
721 |
+
loss,
|
722 |
+
on_step=True,
|
723 |
+
on_epoch=True,
|
724 |
+
prog_bar=True,
|
725 |
+
logger=True,
|
726 |
+
sync_dist=True,
|
727 |
+
batch_size=src.size(0),
|
728 |
+
)
|
729 |
+
|
730 |
+
return loss
|
731 |
+
|
732 |
+
except RuntimeError as e:
|
733 |
+
if "CUDA out of memory" in str(e):
|
734 |
+
logging.warning(
|
735 |
+
f"CUDA OOM error during training step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch."
|
736 |
+
)
|
737 |
+
gc.collect()
|
738 |
+
torch.cuda.empty_cache()
|
739 |
+
return None # Skip update
|
740 |
+
else:
|
741 |
+
logging.error(f"Runtime error during training step {batch_idx}: {e}")
|
742 |
+
# Optionally log shapes for debugging other runtime errors
|
743 |
+
logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}")
|
744 |
+
return None # Skip update
|
745 |
+
|
746 |
+
def validation_step(self, batch, batch_idx):
|
747 |
+
src, tgt = batch
|
748 |
+
if src.numel() == 0 or tgt.numel() == 0:
|
749 |
+
# logging.debug(f"Skipping empty batch {batch_idx} in validation.")
|
750 |
+
return None
|
751 |
+
|
752 |
+
tgt_input = tgt[:, :-1]
|
753 |
+
tgt_out = tgt[:, 1:]
|
754 |
+
|
755 |
+
tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
|
756 |
+
create_masks(src, tgt_input, PAD_IDX, self.device)
|
757 |
+
)
|
758 |
+
|
759 |
+
try:
|
760 |
+
logits = self.model(
|
761 |
+
src,
|
762 |
+
tgt_input,
|
763 |
+
tgt_mask,
|
764 |
+
src_padding_mask,
|
765 |
+
tgt_padding_mask,
|
766 |
+
memory_key_padding_mask,
|
767 |
+
)
|
768 |
+
loss = self.criterion(
|
769 |
+
logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)
|
770 |
+
)
|
771 |
+
|
772 |
+
if torch.isfinite(loss):
|
773 |
+
# Log validation loss (accumulated across batches and synced across GPUs at epoch end)
|
774 |
+
# sync_dist=True ensures correct aggregation in DDP
|
775 |
+
self.log(
|
776 |
+
"val_loss",
|
777 |
+
loss,
|
778 |
+
on_step=False,
|
779 |
+
on_epoch=True,
|
780 |
+
prog_bar=True,
|
781 |
+
logger=True,
|
782 |
+
sync_dist=True,
|
783 |
+
batch_size=src.size(0),
|
784 |
+
)
|
785 |
+
else:
|
786 |
+
logging.warning(
|
787 |
+
f"Non-finite loss encountered during validation step {batch_idx}: {loss.item()}."
|
788 |
+
)
|
789 |
+
# PTL aggregates logged values automatically for the epoch
|
790 |
+
# Returning the loss value itself isn't strictly necessary when using self.log
|
791 |
+
# return loss
|
792 |
+
|
793 |
+
except RuntimeError as e:
|
794 |
+
# Don't crash validation if one batch fails (e.g., OOM on a particularly long sequence)
|
795 |
+
logging.error(f"Runtime error during validation step {batch_idx}: {e}")
|
796 |
+
if "CUDA out of memory" in str(e):
|
797 |
+
logging.warning(
|
798 |
+
f"CUDA OOM error during validation step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch."
|
799 |
+
)
|
800 |
+
gc.collect()
|
801 |
+
torch.cuda.empty_cache()
|
802 |
+
else:
|
803 |
+
logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}")
|
804 |
+
# Return None or a placeholder if needed by some aggregation logic,
|
805 |
+
# but self.log should handle the metric correctly even if some steps fail.
|
806 |
+
return None
|
807 |
+
|
808 |
+
def configure_optimizers(self):
|
809 |
+
optimizer = torch.optim.AdamW(
|
810 |
+
self.parameters(), # self.parameters() includes all model parameters
|
811 |
+
lr=self.hparams.learning_rate,
|
812 |
+
weight_decay=self.hparams.weight_decay,
|
813 |
+
)
|
814 |
+
|
815 |
+
# --- Add Learning Rate Scheduler ---
|
816 |
+
# Use linear warmup followed by linear decay (common for transformers)
|
817 |
+
# Requires the 'transformers' library: pip install transformers
|
818 |
+
try:
|
819 |
+
from transformers import get_linear_schedule_with_warmup
|
820 |
+
|
821 |
+
# Estimate the total number of optimizer steps so the schedule knows when decay should end.
|
822 |
+
# trainer.estimated_stepping_batches is PTL's estimate of the optimizer steps for the whole
|
823 |
+
# run (it already accounts for devices, epochs and gradient accumulation).
|
824 |
+
# Accessing self.trainer inside configure_optimizers can fail if the trainer is not fully
|
825 |
+
# set up yet, so guard the access and fall back to a large fixed step count when the
|
826 |
+
# estimate is unavailable; an overestimate only stretches the decay phase, it does not
|
827 |
+
# break training.
|
828 |
+
# (Deriving the exact count from dataset size and epochs would also work, but the dataset size is not available here.)
|
829 |
+
try:
|
830 |
+
# This attribute is available after trainer setup, might work here.
|
831 |
+
num_training_steps = self.trainer.estimated_stepping_batches
|
832 |
+
logging.info(
|
833 |
+
f"Estimated stepping batches for LR schedule: {num_training_steps}"
|
834 |
+
)
|
835 |
+
if num_training_steps is None or num_training_steps <= 0:
|
836 |
+
logging.warning(
|
837 |
+
"Could not estimate stepping batches, using fallback for LR schedule."
|
838 |
+
)
|
839 |
+
# Fallback: the exact count would be (dataset_size / effective_batch_size) * num_epochs,
|
840 |
+
# but the dataset size is not available inside the module, so fall back to a large fixed
|
841 |
+
# number of steps instead.
|
842 |
+
# The schedule then still warms up normally and simply decays more slowly than it would
|
843 |
+
# with the true step count.
|
844 |
+
num_training_steps = 1_000_000 # Adjust this large number if needed
|
845 |
+
except AttributeError:
|
846 |
+
logging.warning(
|
847 |
+
"self.trainer not available yet in configure_optimizers. Using fallback step count for LR schedule."
|
848 |
+
)
|
849 |
+
num_training_steps = 1_000_000 # Adjust this large number if needed
|
850 |
+
|
851 |
+
# Set warmup steps (e.g., 5% of total steps)
|
852 |
+
num_warmup_steps = int(0.05 * num_training_steps)
|
853 |
+
logging.info(
|
854 |
+
f"LR Scheduler: Total steps ~{num_training_steps}, Warmup steps: {num_warmup_steps}"
|
855 |
+
)
|
856 |
+
|
857 |
+
scheduler = get_linear_schedule_with_warmup(
|
858 |
+
optimizer,
|
859 |
+
num_warmup_steps=num_warmup_steps,
|
860 |
+
num_training_steps=num_training_steps,
|
861 |
+
)
|
862 |
+
|
863 |
+
lr_scheduler_config = {
|
864 |
+
"scheduler": scheduler,
|
865 |
+
"interval": "step", # Call scheduler after each training step
|
866 |
+
"frequency": 1,
|
867 |
+
"name": "linear_warmup_decay_lr", # Optional: Name for logging
|
868 |
+
}
|
869 |
+
logging.info("Using Linear Warmup/Decay LR Scheduler.")
|
870 |
+
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
|
871 |
+
|
872 |
+
except ImportError:
|
873 |
+
logging.warning(
|
874 |
+
"'transformers' library not found. Cannot create linear warmup scheduler. Using constant LR."
|
875 |
+
)
|
876 |
+
return optimizer
|
877 |
+
except Exception as e:
|
878 |
+
logging.error(
|
879 |
+
f"Error setting up LR scheduler: {e}. Using constant LR.", exc_info=True
|
880 |
+
)
|
881 |
+
return optimizer
|
882 |
+
|
883 |
+
|
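To make the warmup arithmetic in configure_optimizers concrete (illustrative numbers, not taken from an actual run): if estimated_stepping_batches reports 100,000 optimizer steps, then num_warmup_steps = int(0.05 * 100_000) = 5,000, so the learning rate rises linearly from 0 to the configured learning_rate over the first 5,000 steps and then decays linearly towards 0 over the remaining 95,000.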
884 |
+
# --- 6. Inference (Translation) (No changes needed) ---
|
885 |
+
# These functions remain largely the same but will take the LightningModule instance
|
886 |
+
|
887 |
+
|
888 |
+
def greedy_decode(
|
889 |
+
model: pl.LightningModule, # Takes the LightningModule
|
890 |
+
src: torch.Tensor,
|
891 |
+
src_padding_mask: torch.Tensor,
|
892 |
+
max_len: int,
|
893 |
+
sos_idx: int,
|
894 |
+
eos_idx: int,
|
895 |
+
device: torch.device,
|
896 |
+
) -> torch.Tensor:
|
897 |
+
"""Performs greedy decoding using the LightningModule's model."""
|
898 |
+
# model.eval() # Lightning handles eval mode during inference/testing
|
899 |
+
transformer_model = model.model # Access the underlying Seq2SeqTransformer
|
900 |
+
|
901 |
+
try:
|
902 |
+
with torch.no_grad():
|
903 |
+
# Use the model's encode/decode methods
|
904 |
+
memory = transformer_model.encode(
|
905 |
+
src, src_padding_mask
|
906 |
+
) # [1, src_len, emb_size]
|
907 |
+
memory = memory.to(device)
|
908 |
+
# Ensure memory_key_padding_mask is also on the correct device for decode
|
909 |
+
memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
|
910 |
+
|
911 |
+
ys = (
|
912 |
+
torch.ones(1, 1).fill_(sos_idx).type(torch.long).to(device)
|
913 |
+
) # [1, 1] (Batch size 1)
|
914 |
+
|
915 |
+
for i in range(max_len - 1):
|
916 |
+
tgt_seq_len = ys.shape[1]
|
917 |
+
# Create masks for the current decoded sequence length
|
918 |
+
tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
|
919 |
+
device
|
920 |
+
) # [curr_len, curr_len]
|
921 |
+
# No padding in target during greedy decode yet
|
922 |
+
tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool).to(
|
923 |
+
device
|
924 |
+
) # [1, curr_len]
|
925 |
+
|
926 |
+
# Use the model's decode method
|
927 |
+
out = transformer_model.decode(
|
928 |
+
ys, memory, tgt_mask, tgt_padding_mask, memory_key_padding_mask
|
929 |
+
)
|
930 |
+
# out: [1, curr_len, emb_size]
|
931 |
+
|
932 |
+
# Get the logits for the last token generated
|
933 |
+
last_token_logits = transformer_model.generator(
|
934 |
+
out[:, -1, :]
|
935 |
+
) # [1, tgt_vocab_size]
|
936 |
+
prob = last_token_logits # Use logits directly for argmax
|
937 |
+
_, next_word = torch.max(prob, dim=1)
|
938 |
+
next_word = next_word.item()
|
939 |
+
|
940 |
+
# Append the predicted token ID
|
941 |
+
ys = torch.cat(
|
942 |
+
[ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1
|
943 |
+
)
|
944 |
+
|
945 |
+
# Stop if EOS token is generated
|
946 |
+
if next_word == eos_idx:
|
947 |
+
break
|
948 |
+
# Return the generated sequence, excluding the initial SOS token
|
949 |
+
return ys[:, 1:]
|
950 |
+
|
951 |
+
except RuntimeError as e:
|
952 |
+
logging.error(f"Runtime error during greedy decode: {e}")
|
953 |
+
if "CUDA out of memory" in str(e):
|
954 |
+
gc.collect()
|
955 |
+
torch.cuda.empty_cache()
|
956 |
+
# Return an empty tensor on error
|
957 |
+
return torch.tensor([[]], dtype=torch.long, device=device)
|
958 |
+
|
959 |
+
|
960 |
+
def translate(
|
961 |
+
model: pl.LightningModule, # Takes the LightningModule
|
962 |
+
src_sentence: str,
|
963 |
+
smiles_tokenizer,
|
964 |
+
iupac_tokenizer,
|
965 |
+
device: torch.device,
|
966 |
+
max_len: int,
|
967 |
+
sos_idx: int,
|
968 |
+
eos_idx: int,
|
969 |
+
pad_idx: int,
|
970 |
+
) -> str:
|
971 |
+
"""Translates a single SMILES string using the LightningModule."""
|
972 |
+
model.eval() # Ensure model is in eval mode for inference
|
973 |
+
|
974 |
+
try:
|
975 |
+
src_encoded = smiles_tokenizer.encode(src_sentence)
|
976 |
+
if not src_encoded or len(src_encoded.ids) == 0:
|
977 |
+
logging.warning(f"Encoding failed for SMILES: {src_sentence}")
|
978 |
+
return "[Encoding Error]"
|
979 |
+
# Truncate source sequence if needed before creating tensor
|
980 |
+
src_ids = src_encoded.ids[:max_len]
|
981 |
+
if not src_ids:
|
982 |
+
logging.warning(
|
983 |
+
f"Source sequence empty after truncation for SMILES: {src_sentence}"
|
984 |
+
)
|
985 |
+
return "[Encoding Error - Empty Src]"
|
986 |
+
|
987 |
+
except Exception as e:
|
988 |
+
logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
|
989 |
+
return "[Encoding Error]"
|
990 |
+
|
991 |
+
# Create tensor and move to device
|
992 |
+
src = (
|
993 |
+
torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
|
994 |
+
) # Add batch dimension
|
995 |
+
# Create the source padding mask (boolean, True where padded).
|
996 |
+
# For single-sentence inference there is normally no padding, but the model still expects
|
997 |
+
# a boolean mask of shape [1, src_len]; with no <pad> tokens present it is simply all
|
998 |
+
# False, so computing it as src == pad_idx covers both cases.
|
999 |
+
src_padding_mask = src == pad_idx # [1, src_len]
|
1000 |
+
|
1001 |
+
# Perform greedy decoding
|
1002 |
+
tgt_tokens_tensor = greedy_decode(
|
1003 |
+
model=model, # Pass the LightningModule
|
1004 |
+
src=src,
|
1005 |
+
src_padding_mask=src_padding_mask,
|
1006 |
+
max_len=max_len, # Use the configured max_len for generation limit
|
1007 |
+
sos_idx=sos_idx,
|
1008 |
+
eos_idx=eos_idx,
|
1009 |
+
device=device,
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
# Decode the generated token IDs
|
1013 |
+
if tgt_tokens_tensor.numel() > 0:
|
1014 |
+
tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
|
1015 |
+
try:
|
1016 |
+
# Decode using the target tokenizer, skipping special tokens like <pad>, <sos>, <eos>
|
1017 |
+
translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
|
1018 |
+
return translation
|
1019 |
+
except Exception as e:
|
1020 |
+
logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
|
1021 |
+
return "[Decoding Error]"
|
1022 |
+
else:
|
1023 |
+
# Log if decoding returned an empty tensor (might happen on error in greedy_decode)
|
1024 |
+
# logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
|
1025 |
+
return "[Decoding Error - Empty Output]"
|
1026 |
+
|
1027 |
+
|
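A minimal usage sketch for the inference helpers above (assumptions: a trained SmilesIupacLitModule instance named lit_model and the two fitted tokenizers are already in memory; the variable names and the max_len value are illustrative, not part of the script):

# Illustrative only -- lit_model, smiles_tokenizer and iupac_tokenizer are assumed to exist.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device).eval()
prediction = translate(
    model=lit_model,
    src_sentence="CCO",                  # ethanol
    smiles_tokenizer=smiles_tokenizer,
    iupac_tokenizer=iupac_tokenizer,
    device=device,
    max_len=256,                         # in the script this comes from hparams["max_len"]
    sos_idx=SOS_IDX,
    eos_idx=EOS_IDX,
    pad_idx=PAD_IDX,
)
print(prediction)                        # e.g. "ethanol" once the model has learned the mapping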
1028 |
+
# --- 7. Main Execution Script (Minor updates for clarity) ---
|
1029 |
+
if __name__ == "__main__":
|
1030 |
+
pl.seed_everything(RANDOM_SEED, workers=True) # Seed everything for reproducibility
|
1031 |
+
|
1032 |
+
# --- Create Checkpoint Directory ---
|
1033 |
+
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
1034 |
+
|
1035 |
+
# --- Load Data from CSV and Split ---
|
1036 |
+
# (Keep this data preparation step outside the Lightning Module)
|
1037 |
+
logging.info(f"Loading and splitting data from {INPUT_CSV_FILE}...")
|
1038 |
+
# (Re-using the data loading and splitting logic from the original script)
|
1039 |
+
try:
|
1040 |
+
# Load with dtype specification for potentially large files
|
1041 |
+
df = pd.read_csv(INPUT_CSV_FILE, dtype={"SMILES": str, "Systematic": str})
|
1042 |
+
logging.info(f"Initial rows loaded: {len(df)}")
|
1043 |
+
if "SMILES" not in df.columns:
|
1044 |
+
raise ValueError("CSV must contain 'SMILES' column.")
|
1045 |
+
if "Systematic" not in df.columns:
|
1046 |
+
raise ValueError("CSV must contain 'Systematic' (IUPAC name) column.")
|
1047 |
+
df.rename(columns={"Systematic": "IUPAC"}, inplace=True)
|
1048 |
+
|
1049 |
+
initial_rows = len(df)
|
1050 |
+
df.dropna(subset=["SMILES", "IUPAC"], inplace=True)
|
1051 |
+
rows_after_na = len(df)
|
1052 |
+
if initial_rows > rows_after_na:
|
1053 |
+
logging.info(
|
1054 |
+
f"Dropped {initial_rows - rows_after_na} rows with missing values."
|
1055 |
+
)
|
1056 |
+
# Strip whitespace and filter empty strings more efficiently
|
1057 |
+
df = df[df["SMILES"].str.strip().astype(bool)]
|
1058 |
+
df = df[df["IUPAC"].str.strip().astype(bool)]
|
1059 |
+
df["SMILES"] = df["SMILES"].str.strip()
|
1060 |
+
df["IUPAC"] = df["IUPAC"].str.strip()
|
1061 |
+
rows_after_empty = len(df)
|
1062 |
+
if rows_after_na > rows_after_empty:
|
1063 |
+
logging.info(
|
1064 |
+
f"Dropped {rows_after_na - rows_after_empty} rows with empty strings after stripping."
|
1065 |
+
)
|
1066 |
+
|
1067 |
+
smiles_data = df["SMILES"].tolist()
|
1068 |
+
iupac_data = df["IUPAC"].tolist()
|
1069 |
+
logging.info(f"Loaded {len(smiles_data)} valid pairs from CSV.")
|
1070 |
+
del df
|
1071 |
+
gc.collect() # Free memory
|
1072 |
+
|
1073 |
+
if len(smiles_data) < 10:
|
1074 |
+
raise ValueError(
|
1075 |
+
f"Not enough valid data ({len(smiles_data)}) for split. Need at least 10."
|
1076 |
+
)
|
1077 |
+
|
1078 |
+
train_smi, val_smi, train_iupac, val_iupac = train_test_split(
|
1079 |
+
smiles_data,
|
1080 |
+
iupac_data,
|
1081 |
+
test_size=VALIDATION_SPLIT,
|
1082 |
+
random_state=RANDOM_SEED,
|
1083 |
+
)
|
1084 |
+
logging.info(f"Split: {len(train_smi)} train, {len(val_smi)} validation.")
|
1085 |
+
del smiles_data, iupac_data
|
1086 |
+
gc.collect() # Free memory
|
1087 |
+
|
1088 |
+
logging.info("Writing split data to files...")
|
1089 |
+
with open(TRAIN_SMILES_FILE, "w", encoding="utf-8") as f:
|
1090 |
+
f.write("\n".join(train_smi))
|
1091 |
+
with open(TRAIN_IUPAC_FILE, "w", encoding="utf-8") as f:
|
1092 |
+
f.write("\n".join(train_iupac))
|
1093 |
+
with open(VAL_SMILES_FILE, "w", encoding="utf-8") as f:
|
1094 |
+
f.write("\n".join(val_smi))
|
1095 |
+
with open(VAL_IUPAC_FILE, "w", encoding="utf-8") as f:
|
1096 |
+
f.write("\n".join(val_iupac))
|
1097 |
+
logging.info(
|
1098 |
+
f"Split files written: {TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}, {VAL_SMILES_FILE}, {VAL_IUPAC_FILE}"
|
1099 |
+
)
|
1100 |
+
del train_smi, val_smi, train_iupac, val_iupac
|
1101 |
+
gc.collect() # Free memory
|
1102 |
+
|
1103 |
+
except FileNotFoundError:
|
1104 |
+
logging.error(f"Fatal error: Input CSV file not found at {INPUT_CSV_FILE}")
|
1105 |
+
exit(1)
|
1106 |
+
except ValueError as ve:
|
1107 |
+
logging.error(f"Fatal error during data preparation: {ve}")
|
1108 |
+
exit(1)
|
1109 |
+
except Exception as e:
|
1110 |
+
logging.error(f"Fatal error during data preparation: {e}", exc_info=True)
|
1111 |
+
exit(1)
|
1112 |
+
# --- End Data Preparation ---
|
1113 |
+
|
1114 |
+
# --- Initialize Tokenizers ---
|
1115 |
+
logging.info("Initializing Tokenizers...")
|
1116 |
+
# Ensure training files exist before attempting to train tokenizers
|
1117 |
+
if not os.path.exists(TRAIN_SMILES_FILE) or not os.path.exists(TRAIN_IUPAC_FILE):
|
1118 |
+
logging.error(
|
1119 |
+
f"Training files ({TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}) not found. Cannot train tokenizers."
|
1120 |
+
)
|
1121 |
+
exit(1)
|
1122 |
+
|
1123 |
+
smiles_tokenizer = get_smiles_tokenizer(
|
1124 |
+
train_files=[TRAIN_SMILES_FILE],
|
1125 |
+
vocab_size=SRC_VOCAB_SIZE_ESTIMATE,
|
1126 |
+
tokenizer_path=SMILES_TOKENIZER_FILE,
|
1127 |
+
)
|
1128 |
+
iupac_tokenizer = get_iupac_tokenizer(
|
1129 |
+
train_files=[TRAIN_IUPAC_FILE],
|
1130 |
+
vocab_size=TGT_VOCAB_SIZE_ESTIMATE,
|
1131 |
+
tokenizer_path=IUPAC_TOKENIZER_FILE,
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
ACTUAL_SRC_VOCAB_SIZE = smiles_tokenizer.get_vocab_size()
|
1135 |
+
ACTUAL_TGT_VOCAB_SIZE = iupac_tokenizer.get_vocab_size()
|
1136 |
+
logging.info(f"Actual SMILES Vocab Size: {ACTUAL_SRC_VOCAB_SIZE}")
|
1137 |
+
logging.info(f"Actual IUPAC Vocab Size: {ACTUAL_TGT_VOCAB_SIZE}")
|
1138 |
+
# Update hparams with actual sizes (will be logged by WandbLogger)
|
1139 |
+
hparams["actual_src_vocab_size"] = ACTUAL_SRC_VOCAB_SIZE
|
1140 |
+
hparams["actual_tgt_vocab_size"] = ACTUAL_TGT_VOCAB_SIZE
|
1141 |
+
|
1142 |
+
# --- Setup WandB Logger ---
|
1143 |
+
# Ensure WANDB_ENTITY is set if required, otherwise it uses default
|
1144 |
+
if WANDB_ENTITY is None:
|
1145 |
+
logging.warning(
|
1146 |
+
"WANDB_ENTITY not set. WandB will log to your default entity. Set WANDB_ENTITY='your_username_or_team' to specify."
|
1147 |
+
)
|
1148 |
+
|
1149 |
+
wandb_logger = WandbLogger(
|
1150 |
+
project=WANDB_PROJECT,
|
1151 |
+
entity=WANDB_ENTITY, # Set your entity here or leave as None
|
1152 |
+
name=WANDB_RUN_NAME,
|
1153 |
+
config=hparams, # Log hyperparameters defined above
|
1154 |
+
# log_model='all' # Log model checkpoints to WandB (can consume significant storage)
|
1155 |
+
# log_model=True # Log best model checkpoint based on monitor
|
1156 |
+
)
|
1157 |
+
|
1158 |
+
# --- Initialize Datasets and DataLoaders ---
|
1159 |
+
logging.info("Creating Datasets and DataLoaders...")
|
1160 |
+
try:
|
1161 |
+
train_dataset = SmilesIupacDataset(TRAIN_SMILES_FILE, TRAIN_IUPAC_FILE)
|
1162 |
+
val_dataset = SmilesIupacDataset(VAL_SMILES_FILE, VAL_IUPAC_FILE)
|
1163 |
+
if len(train_dataset) == 0 or len(val_dataset) == 0:
|
1164 |
+
logging.error(
|
1165 |
+
"Training or validation dataset is empty. Check data splitting and file content."
|
1166 |
+
)
|
1167 |
+
exit(1)
|
1168 |
+
except Exception as e:
|
1169 |
+
logging.error(f"Error creating Datasets: {e}", exc_info=True)
|
1170 |
+
exit(1)
|
1171 |
+
|
1172 |
+
# Create partial function for collate_fn to pass tokenizers and params
|
1173 |
+
def collate_fn_partial(batch):
|
1174 |
+
return collate_fn(
|
1175 |
+
batch,
|
1176 |
+
smiles_tokenizer,
|
1177 |
+
iupac_tokenizer,
|
1178 |
+
PAD_IDX,
|
1179 |
+
SOS_IDX,
|
1180 |
+
EOS_IDX,
|
1181 |
+
hparams["max_len"],
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
# Use persistent_workers=True if num_workers > 0 for efficiency, especially with DDP
|
1185 |
+
persistent_workers = NUM_WORKERS > 0 and STRATEGY == "ddp" # Recommended for DDP
|
1186 |
+
|
1187 |
+
train_dataloader = DataLoader(
|
1188 |
+
train_dataset,
|
1189 |
+
batch_size=BATCH_SIZE_PER_GPU,
|
1190 |
+
shuffle=True,
|
1191 |
+
collate_fn=collate_fn_partial,
|
1192 |
+
num_workers=NUM_WORKERS,
|
1193 |
+
pin_memory=True,
|
1194 |
+
persistent_workers=persistent_workers,
|
1195 |
+
drop_last=True,
|
1196 |
+
) # Drop last incomplete batch in training for DDP consistency
|
1197 |
+
val_dataloader = DataLoader(
|
1198 |
+
val_dataset,
|
1199 |
+
batch_size=BATCH_SIZE_PER_GPU, # Use same batch size for validation
|
1200 |
+
shuffle=False,
|
1201 |
+
collate_fn=collate_fn_partial,
|
1202 |
+
num_workers=NUM_WORKERS,
|
1203 |
+
pin_memory=True,
|
1204 |
+
persistent_workers=persistent_workers,
|
1205 |
+
drop_last=False,
|
1206 |
+
) # Keep all validation batches
|
1207 |
+
|
1208 |
+
# --- Initialize Model ---
|
1209 |
+
logging.info("Initializing Lightning Module...")
|
1210 |
+
# Pass hparams dictionary directly, PTL handles it via save_hyperparameters
|
1211 |
+
model = SmilesIupacLitModule(
|
1212 |
+
src_vocab_size=ACTUAL_SRC_VOCAB_SIZE,
|
1213 |
+
tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE,
|
1214 |
+
hparams_dict=hparams,
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
# Optional: Log model topology to WandB (do this after model init, before training)
|
1218 |
+
# Note: watch can sometimes slow down training start, especially with large models
|
1219 |
+
# wandb_logger.watch(model, log='all', log_freq=100) # Log gradients and parameters
|
1220 |
+
|
1221 |
+
# --- Define Callbacks ---
|
1222 |
+
checkpoint_callback = ModelCheckpoint(
|
1223 |
+
dirpath=CHECKPOINT_DIR,
|
1224 |
+
filename=BEST_MODEL_FILENAME + "-{epoch:02d}-{val_loss:.4f}",
|
1225 |
+
save_top_k=1, # Save only the best model
|
1226 |
+
verbose=True,
|
1227 |
+
monitor="val_loss", # Monitor validation loss
|
1228 |
+
mode="min", # Save the model with the minimum validation loss
|
1229 |
+
save_last=True, # Optionally save the last checkpoint as well
|
1230 |
+
)
|
1231 |
+
early_stopping_callback = EarlyStopping(
|
1232 |
+
monitor="val_loss",
|
1233 |
+
patience=PATIENCE, # Number of epochs with no improvement after which training will be stopped
|
1234 |
+
verbose=True,
|
1235 |
+
mode="min",
|
1236 |
+
)
|
1237 |
+
|
1238 |
+
# --- Initialize PyTorch Lightning Trainer ---
|
1239 |
+
logging.info(
|
1240 |
+
f"Initializing PyTorch Lightning Trainer (GPUs={DEVICES}, Strategy='{STRATEGY}', Precision='{PRECISION}')..."
|
1241 |
+
)
|
1242 |
+
trainer = pl.Trainer(
|
1243 |
+
accelerator=ACCELERATOR,
|
1244 |
+
devices=DEVICES,
|
1245 |
+
strategy=STRATEGY,
|
1246 |
+
precision=PRECISION,
|
1247 |
+
max_epochs=NUM_EPOCHS,
|
1248 |
+
logger=wandb_logger, # Use WandbLogger
|
1249 |
+
callbacks=[checkpoint_callback, early_stopping_callback],
|
1250 |
+
gradient_clip_val=GRAD_CLIP_NORM, # Gradient clipping
|
1251 |
+
accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES, # Gradient accumulation
|
1252 |
+
log_every_n_steps=50, # How often to log metrics (steps across all GPUs)
|
1253 |
+
# deterministic=True, # Might slow down training, use for debugging reproducibility if needed
|
1254 |
+
# profiler="simple", # Optional: Add profiler ("simple", "advanced", "pytorch") for performance analysis
|
1255 |
+
# Checkpointing behavior is controlled by ModelCheckpoint callback
|
1256 |
+
# enable_checkpointing=True, # Default is True if callbacks has ModelCheckpoint
|
1257 |
+
)
|
1258 |
+
|
1259 |
+
# --- Start Training ---
|
1260 |
+
logging.info(
|
1261 |
+
f"Starting training with Effective Batch Size: {hparams['effective_batch_size']}..."
|
1262 |
+
)
|
1263 |
+
start_time = time.time()
|
1264 |
+
try:
|
1265 |
+
trainer.fit(model, train_dataloader, val_dataloader)
|
1266 |
+
training_duration = time.time() - start_time
|
1267 |
+
logging.info(
|
1268 |
+
f"Training finished in {training_duration / 3600:.2f} hours ({training_duration:.2f} seconds)."
|
1269 |
+
)
|
1270 |
+
|
1271 |
+
# Log best model path and score
|
1272 |
+
best_path = checkpoint_callback.best_model_path
|
1273 |
+
best_score = checkpoint_callback.best_model_score # This is a tensor, get value
|
1274 |
+
if best_score is not None:
|
1275 |
+
logging.info(
|
1276 |
+
f"Best model checkpoint saved at: {best_path} with val_loss: {best_score.item():.4f}"
|
1277 |
+
)
|
1278 |
+
# Log best score to wandb summary
|
1279 |
+
wandb_logger.experiment.summary["best_val_loss"] = best_score.item()
|
1280 |
+
wandb_logger.experiment.summary["best_model_path"] = best_path
|
1281 |
+
else:
|
1282 |
+
logging.warning(
|
1283 |
+
"Could not retrieve best model score from checkpoint callback."
|
1284 |
+
)
|
1285 |
+
|
1286 |
+
except Exception as e:
|
1287 |
+
logging.error(f"Fatal error during training: {e}", exc_info=True)
|
1288 |
+
# Ensure wandb run is finished even on error
|
1289 |
+
if wandb.run is not None:
|
1290 |
+
wandb.finish(exit_code=1) # Mark as failed run
|
1291 |
+
exit(1)
|
1292 |
+
|
1293 |
+
# --- Load Best Model for Final Translation Examples ---
|
1294 |
+
best_model_path_to_load = checkpoint_callback.best_model_path
|
1295 |
+
logging.info(
|
1296 |
+
f"\nLoading best model from {best_model_path_to_load} for translation examples..."
|
1297 |
+
)
|
1298 |
+
final_model = None
|
1299 |
+
if best_model_path_to_load and os.path.exists(best_model_path_to_load):
|
1300 |
+
try:
|
1301 |
+
# Load the model using the Lightning checkpoint loading mechanism
|
1302 |
+
# Pass hparams_dict again in case it's needed and not perfectly saved/loaded
|
1303 |
+
final_model = SmilesIupacLitModule.load_from_checkpoint(
|
1304 |
+
best_model_path_to_load,
|
1305 |
+
# Provide necessary args again if they weren't saved in hparams properly
|
1306 |
+
# (though save_hyperparameters should handle this)
|
1307 |
+
src_vocab_size=ACTUAL_SRC_VOCAB_SIZE,
|
1308 |
+
tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE,
|
1309 |
+
hparams_dict=hparams, # Pass the original hparams
|
1310 |
+
)
|
1311 |
+
# Determine device for inference (use the first GPU if available)
|
1312 |
+
inference_device = torch.device(
|
1313 |
+
f"{ACCELERATOR}:0"
|
1314 |
+
if ACCELERATOR == "gpu" and torch.cuda.is_available()
|
1315 |
+
else "cpu"
|
1316 |
+
)
|
1317 |
+
final_model = final_model.to(inference_device)
|
1318 |
+
final_model.eval() # Set to evaluation mode
|
1319 |
+
final_model.freeze() # Freeze weights for inference
|
1320 |
+
logging.info(
|
1321 |
+
f"Best model loaded successfully to {inference_device} for final translation."
|
1322 |
+
)
|
1323 |
+
except Exception as e:
|
1324 |
+
logging.error(
|
1325 |
+
f"Error loading saved model from {best_model_path_to_load}: {e}",
|
1326 |
+
exc_info=True,
|
1327 |
+
)
|
1328 |
+
final_model = None # Ensure final_model is None if loading fails
|
1329 |
+
else:
|
1330 |
+
logging.error(
|
1331 |
+
f"Error: Best model checkpoint path not found or invalid: '{best_model_path_to_load}'. Cannot perform final translation."
|
1332 |
+
)
|
1333 |
+
|
1334 |
+
# --- Example Translation (using some validation samples) ---
|
1335 |
+
if final_model:
|
1336 |
+
logging.info("\n--- Example Translations (using validation data) ---")
|
1337 |
+
num_examples = 20 # Show more examples
|
1338 |
+
try:
|
1339 |
+
# Load validation samples directly from the files
|
1340 |
+
val_smi_examples = []
|
1341 |
+
val_iupac_examples = []
|
1342 |
+
if os.path.exists(VAL_SMILES_FILE) and os.path.exists(VAL_IUPAC_FILE):
|
1343 |
+
with (
|
1344 |
+
open(VAL_SMILES_FILE, "r", encoding="utf-8") as f_smi,
|
1345 |
+
open(VAL_IUPAC_FILE, "r", encoding="utf-8") as f_iupac,
|
1346 |
+
):
|
1347 |
+
for i, (smi_line, iupac_line) in enumerate(zip(f_smi, f_iupac)):
|
1348 |
+
if i >= num_examples:
|
1349 |
+
break
|
1350 |
+
val_smi_examples.append(smi_line.strip())
|
1351 |
+
val_iupac_examples.append(iupac_line.strip())
|
1352 |
+
else:
|
1353 |
+
logging.warning(
|
1354 |
+
f"Validation files ({VAL_SMILES_FILE}, {VAL_IUPAC_FILE}) not found. Cannot show examples."
|
1355 |
+
)
|
1356 |
+
|
1357 |
+
if len(val_smi_examples) > 0:
|
1358 |
+
print("\n" + "=" * 40)
|
1359 |
+
print(
|
1360 |
+
f"Example Translations (First {len(val_smi_examples)} Validation Samples)"
|
1361 |
+
)
|
1362 |
+
print("=" * 40)
|
1363 |
+
# Use the device the model was loaded onto
|
1364 |
+
inference_device = next(final_model.parameters()).device
|
1365 |
+
translation_examples = [] # For potential logging to wandb
|
1366 |
+
for i in range(len(val_smi_examples)):
|
1367 |
+
smi = val_smi_examples[i]
|
1368 |
+
true_iupac = val_iupac_examples[i]
|
1369 |
+
predicted_iupac = translate(
|
1370 |
+
model=final_model, # Use the loaded best model
|
1371 |
+
src_sentence=smi,
|
1372 |
+
smiles_tokenizer=smiles_tokenizer,
|
1373 |
+
iupac_tokenizer=iupac_tokenizer,
|
1374 |
+
device=inference_device, # Use model's device
|
1375 |
+
max_len=hparams["max_len"],
|
1376 |
+
sos_idx=SOS_IDX,
|
1377 |
+
eos_idx=EOS_IDX,
|
1378 |
+
pad_idx=PAD_IDX,
|
1379 |
+
)
|
1380 |
+
print(f"\nExample {i + 1}:")
|
1381 |
+
print(f" SMILES: {smi}")
|
1382 |
+
print(f" True IUPAC: {true_iupac}")
|
1383 |
+
print(f" Predicted IUPAC: {predicted_iupac}")
|
1384 |
+
print("-" * 30)
|
1385 |
+
# Prepare data for wandb table
|
1386 |
+
translation_examples.append([smi, true_iupac, predicted_iupac])
|
1387 |
+
|
1388 |
+
print("=" * 40 + "\n")
|
1389 |
+
|
1390 |
+
# Log examples to a WandB Table
|
1391 |
+
try:
|
1392 |
+
columns = ["SMILES", "True IUPAC", "Predicted IUPAC"]
|
1393 |
+
wandb_table = wandb.Table(
|
1394 |
+
data=translation_examples, columns=columns
|
1395 |
+
)
|
1396 |
+
wandb_logger.experiment.log(
|
1397 |
+
{"validation_translations": wandb_table}
|
1398 |
+
)
|
1399 |
+
logging.info("Logged translation examples to WandB Table.")
|
1400 |
+
except Exception as wb_err:
|
1401 |
+
logging.error(
|
1402 |
+
f"Failed to log translation examples to WandB: {wb_err}"
|
1403 |
+
)
|
1404 |
+
|
1405 |
+
else:
|
1406 |
+
logging.warning("Could not load validation samples for examples.")
|
1407 |
+
except Exception as e:
|
1408 |
+
logging.error(f"Error during example translation phase: {e}", exc_info=True)
|
1409 |
+
else:
|
1410 |
+
logging.warning(
|
1411 |
+
"Skipping final translation examples as the best model could not be loaded."
|
1412 |
+
)
|
1413 |
+
|
1414 |
+
# --- Finish WandB Run ---
|
1415 |
+
if wandb.run is not None:
|
1416 |
+
wandb.finish()
|
1417 |
+
logging.info("WandB run finished.")
|
1418 |
+
else:
|
1419 |
+
logging.info("No active WandB run to finish.")
|
1420 |
+
|
1421 |
+
logging.info("Script finished.")
|
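Taken together, the __main__ block above covers the full pipeline: split the CSV, train or load the two tokenizers, build the DataLoaders, fit with checkpointing and early stopping, then reload the best checkpoint for example translations. Assuming the constants it references (INPUT_CSV_FILE, CHECKPOINT_DIR, hparams, the WANDB_* settings and the special-token indices) are defined earlier in enhanced_trainer.py, a run is started by executing the module directly, e.g. python enhanced_trainer.py.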
requirements.txt
ADDED
@@ -0,0 +1,9 @@
1 |
+
torch
|
2 |
+
pytorch_lightning
|
3 |
+
tokenizers
|
4 |
+
transformers
|
5 |
+
gradio
|
6 |
+
huggingface_hub
|
7 |
+
pandas
|
8 |
+
scikit-learn
|
9 |
+
wandb
|
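These dependencies are installed in the usual way, e.g. pip install -r requirements.txt; for GPU training a CUDA-enabled PyTorch build is assumed.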