AdrianM0 committed
Commit aaafea4 · verified · Parent: 31e7bf8

Upload folder using huggingface_hub
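For reference, an upload like this is typically produced with the huggingface_hub upload API. A minimal sketch follows; the Space id and folder path are assumptions for illustration, not values taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",  # local folder containing app.py, enhanced_trainer.py, requirements.txt (assumed layout)
    repo_id="AdrianM0/smi2iupac",  # hypothetical Space id; replace with the actual Space
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)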

Files changed (4)
  1. README.md +2 -8
  2. app.py +503 -0
  3. enhanced_trainer.py +1421 -0
  4. requirements.txt +16 -0
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Smi2iupac
-emoji: 👁
-colorFrom: purple
-colorTo: yellow
+title: smi2iupac
+app_file: app.py
 sdk: gradio
 sdk_version: 5.25.2
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,503 @@
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn.functional as F # <--- Added import
5
+ import pytorch_lightning as pl # <--- Added import (needed for type hints, model access)
6
+ import os
7
+ import json
8
+ import logging
9
+ from tokenizers import Tokenizer
10
+ from huggingface_hub import hf_hub_download
11
+ import gc # For garbage collection on potential OOM
12
+ import math # Needed for PositionalEncoding if moved here (or keep in enhanced_trainer)
13
+
14
+ # --- Configuration ---
15
+ MODEL_REPO_ID = "AdrianM0/smiles-to-iupac-translator"
16
+ CHECKPOINT_FILENAME = "last.ckpt"
17
+ SMILES_TOKENIZER_FILENAME = "smiles_bytelevel_bpe_tokenizer_scaled.json"
18
+ IUPAC_TOKENIZER_FILENAME = "iupac_unigram_tokenizer_scaled.json"
19
+ CONFIG_FILENAME = "config.json"
20
+ # --- End Configuration ---
21
+
22
+ # --- Logging ---
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+
25
+ # --- Load Helper Code (Only Model Definition Needed) ---
26
+ try:
27
+ # We only need the LightningModule definition and the mask function now
28
+ from enhanced_trainer import SmilesIupacLitModule, generate_square_subsequent_mask
29
+ logging.info("Successfully imported from enhanced_trainer.py.")
30
+
31
+ # We will define beam_search_decode and translate locally in this file
32
+ # REMOVED: from test_ckpt import beam_search_decode, translate
33
+
34
+ except ImportError as e:
35
+ logging.error(f"Failed to import helper code from enhanced_trainer.py: {e}. Make sure enhanced_trainer.py is in the root of the Hugging Face repo '{MODEL_REPO_ID}'.")
36
+ gr.Error(f"Initialization Error: Could not load necessary Python modules (enhanced_trainer.py). Check Space logs. Error: {e}")
37
+ exit()
38
+ except Exception as e:
39
+ logging.error(f"An unexpected error occurred during helper code import: {e}", exc_info=True)
40
+ gr.Error(f"Initialization Error: An unexpected error occurred loading helper modules. Check Space logs. Error: {e}")
41
+ exit()
42
+
43
+ # --- Global Variables (Load Model Once) ---
44
+ model: pl.LightningModule | None = None # Added type hint
45
+ smiles_tokenizer: Tokenizer | None = None
46
+ iupac_tokenizer: Tokenizer | None = None
47
+ device: torch.device | None = None
48
+ config: dict | None = None
49
+
50
+ # --- Beam Search Decoding Logic (Moved from test_ckpt.py) ---
51
+
52
+ def beam_search_decode(
53
+ model: pl.LightningModule,
54
+ src: torch.Tensor,
55
+ src_padding_mask: torch.Tensor,
56
+ max_len: int,
57
+ sos_idx: int,
58
+ eos_idx: int,
59
+ pad_idx: int, # Needed for padding mask check if src has padding
60
+ device: torch.device,
61
+ beam_width: int = 5,
62
+ n_best: int = 5, # Number of top sequences to return
63
+ length_penalty: float = 0.6 # Alpha for length normalization (0=no penalty, 1=full penalty)
64
+ ) -> list[torch.Tensor]:
65
+ """
66
+ Performs beam search decoding using the LightningModule's model.
67
+ (Code copied and pasted from test_ckpt.py)
68
+ """
69
+ # Ensure model is in eval mode (redundant if called after model.eval(), but safe)
70
+ model.eval()
71
+ transformer_model = model.model # Access the underlying Seq2SeqTransformer
72
+ n_best = min(n_best, beam_width) # Cannot return more than beam_width sequences
73
+
74
+ try:
75
+ with torch.no_grad():
76
+ # --- Encode Source ---
77
+ memory = transformer_model.encode(src, src_padding_mask) # [1, src_len, emb_size]
78
+ memory = memory.to(device)
79
+ # Ensure memory_key_padding_mask is also on the correct device for decode
80
+ memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
81
+
82
+ # --- Initialize Beams ---
83
+ initial_beam_seq = torch.ones(1, 1, dtype=torch.long, device=device).fill_(sos_idx) # [1, 1]
84
+ initial_beam_score = torch.zeros(1, dtype=torch.float, device=device) # [1]
85
+ active_beams = [(initial_beam_seq, initial_beam_score)]
86
+ finished_beams = []
87
+
88
+ # --- Decoding Loop ---
89
+ for step in range(max_len - 1):
90
+ if not active_beams:
91
+ break
92
+
93
+ potential_next_beams = []
94
+ for current_seq, current_score in active_beams:
95
+ if current_seq[0, -1].item() == eos_idx:
96
+ finished_beams.append((current_seq, current_score))
97
+ continue
98
+
99
+ tgt_input = current_seq # [1, current_len]
100
+ tgt_seq_len = tgt_input.shape[1]
101
+ tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(device) # [curr_len, curr_len]
102
+ tgt_padding_mask = torch.zeros(tgt_input.shape, dtype=torch.bool, device=device) # [1, curr_len]
103
+
104
+ decoder_output = transformer_model.decode(
105
+ tgt=tgt_input,
106
+ memory=memory,
107
+ tgt_mask=tgt_mask,
108
+ tgt_padding_mask=tgt_padding_mask,
109
+ memory_key_padding_mask=memory_key_padding_mask
110
+ ) # [1, curr_len, emb_size]
111
+
112
+ next_token_logits = transformer_model.generator(decoder_output[:, -1, :]) # [1, tgt_vocab_size]
113
+ log_probs = F.log_softmax(next_token_logits, dim=-1) # [1, tgt_vocab_size]
114
+
115
+ topk_log_probs, topk_indices = torch.topk(log_probs + current_score, beam_width, dim=-1)
116
+
117
+ for i in range(beam_width):
118
+ next_token_id = topk_indices[0, i].item()
119
+ next_score = topk_log_probs[0, i].reshape(1) # Keep as tensor [1]
120
+ next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=device) # [1, 1]
121
+ new_seq = torch.cat([current_seq, next_token_tensor], dim=1) # [1, current_len + 1]
122
+ potential_next_beams.append((new_seq, next_score))
123
+
124
+ potential_next_beams.sort(key=lambda x: x[1].item(), reverse=True)
125
+
126
+ active_beams = []
127
+ added_count = 0
128
+ for seq, score in potential_next_beams:
129
+ is_finished = seq[0, -1].item() == eos_idx
130
+ if is_finished:
131
+ finished_beams.append((seq, score))
132
+ elif added_count < beam_width:
133
+ active_beams.append((seq, score))
134
+ added_count += 1
135
+ elif added_count >= beam_width:
136
+ break
137
+
138
+ finished_beams.extend(active_beams)
139
+
140
+ # Apply length penalty and sort
141
+ # Handle potential division by zero if sequence length is 1 (or 0?)
142
+ def get_score(beam_tuple):
143
+ seq, score = beam_tuple
144
+ seq_len = seq.shape[1]
145
+ if length_penalty == 0.0 or seq_len <= 1:
146
+ return score.item()
147
+ else:
148
+ # Ensure seq_len is float for pow
149
+ return score.item() / (float(seq_len) ** length_penalty)
150
+
151
+ finished_beams.sort(key=get_score, reverse=True) # Higher score is better
152
+
153
+ top_sequences = [seq[:, 1:] for seq, score in finished_beams[:n_best]] # seq shape [1, len] -> [1, len-1]
154
+ return top_sequences
155
+
156
+ except RuntimeError as e:
157
+ logging.error(f"Runtime error during beam search decode: {e}")
158
+ if "CUDA out of memory" in str(e):
159
+ gc.collect(); torch.cuda.empty_cache()
160
+ return [] # Return empty list on error
161
+ except Exception as e:
162
+ logging.error(f"Unexpected error during beam search decode: {e}", exc_info=True)
163
+ return []
164
+
165
+ # --- Translation Function (Moved from test_ckpt.py) ---
166
+
167
+ def translate(
168
+ model: pl.LightningModule,
169
+ src_sentence: str,
170
+ smiles_tokenizer: Tokenizer,
171
+ iupac_tokenizer: Tokenizer,
172
+ device: torch.device,
173
+ max_len: int,
174
+ sos_idx: int,
175
+ eos_idx: int,
176
+ pad_idx: int,
177
+ beam_width: int = 5,
178
+ n_best: int = 5,
179
+ length_penalty: float = 0.6
180
+ ) -> list[str]:
181
+ """
182
+ Translates a single SMILES string using beam search.
183
+ (Code copied and pasted from test_ckpt.py)
184
+ """
185
+ model.eval() # Ensure model is in eval mode
186
+ translations = []
187
+
188
+ # --- Tokenize Source ---
189
+ try:
190
+ src_encoded = smiles_tokenizer.encode(src_sentence)
191
+ if not src_encoded or not src_encoded.ids:
192
+ logging.warning(f"Encoding failed or empty for SMILES: {src_sentence}")
193
+ return ["[Encoding Error]"] * n_best
194
+ src_ids = src_encoded.ids[:max_len] # Truncate source
195
+ if not src_ids:
196
+ logging.warning(f"Source empty after truncation: {src_sentence}")
197
+ return ["[Encoding Error - Empty Src]"] * n_best
198
+ except Exception as e:
199
+ logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
200
+ return ["[Encoding Error]"] * n_best
201
+
202
+ # --- Prepare Input Tensor and Mask ---
203
+ src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device) # [1, src_len]
204
+ src_padding_mask = (src == pad_idx).to(device) # [1, src_len]
205
+
206
+ # --- Perform Beam Search Decoding ---
207
+ # Calls the beam_search_decode function defined above in this file
208
+ tgt_tokens_list = beam_search_decode(
209
+ model=model,
210
+ src=src,
211
+ src_padding_mask=src_padding_mask,
212
+ max_len=max_len,
213
+ sos_idx=sos_idx,
214
+ eos_idx=eos_idx,
215
+ pad_idx=pad_idx,
216
+ device=device,
217
+ beam_width=beam_width,
218
+ n_best=n_best,
219
+ length_penalty=length_penalty
220
+ ) # Returns list of tensors
221
+
222
+ # --- Decode Generated Tokens ---
223
+ if not tgt_tokens_list:
224
+ logging.warning(f"Beam search returned empty list for SMILES: {src_sentence}")
225
+ return ["[Decoding Error - Empty Output]"] * n_best
226
+
227
+ for tgt_tokens_tensor in tgt_tokens_list:
228
+ if tgt_tokens_tensor.numel() > 0:
229
+ tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
230
+ try:
231
+ translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
232
+ translations.append(translation)
233
+ except Exception as e:
234
+ logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
235
+ translations.append("[Decoding Error]")
236
+ else:
237
+ translations.append("[Decoding Error - Empty Tensor]")
238
+
239
+ # Pad with error messages if fewer than n_best results were generated
240
+ while len(translations) < n_best:
241
+ translations.append("[Decoding Error - Fewer Results]")
242
+
243
+ return translations
244
+
245
+
246
+ # --- Model/Tokenizer Loading Function (Unchanged) ---
247
+ def load_model_and_tokenizers():
248
+ """Loads tokenizers, config, and model from Hugging Face Hub."""
249
+ global model, smiles_tokenizer, iupac_tokenizer, device, config
250
+ if model is not None: # Already loaded
251
+ logging.info("Model and tokenizers already loaded.")
252
+ return
253
+
254
+ logging.info(f"Starting model and tokenizer loading from {MODEL_REPO_ID}...")
255
+ try:
256
+ device = torch.device("cpu")
257
+ logging.info(f"Using device: {device}")
258
+
259
+ # Download files from HF Hub
260
+ logging.info("Downloading files from Hugging Face Hub...")
261
+ try:
262
+ checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CHECKPOINT_FILENAME)
263
+ smiles_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=SMILES_TOKENIZER_FILENAME)
264
+ iupac_tokenizer_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=IUPAC_TOKENIZER_FILENAME)
265
+ config_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=CONFIG_FILENAME)
266
+ logging.info("Files downloaded successfully.")
267
+ except Exception as e:
268
+ logging.error(f"Failed to download files from {MODEL_REPO_ID}. Check filenames and repo status. Error: {e}", exc_info=True)
269
+ raise gr.Error(f"Download Error: Could not download required files from {MODEL_REPO_ID}. Check Space logs. Error: {e}")
270
+
271
+ # Load config
272
+ logging.info("Loading configuration...")
273
+ try:
274
+ with open(config_path, 'r') as f:
275
+ config = json.load(f)
276
+ logging.info("Configuration loaded.")
277
+ # --- Validate essential config keys ---
278
+ required_keys = [
279
+ 'src_vocab_size', 'tgt_vocab_size', 'emb_size', 'nhead',
280
+ 'ffn_hid_dim', 'num_encoder_layers', 'num_decoder_layers',
281
+ 'dropout', 'max_len', 'bos_token_id', 'eos_token_id', 'pad_token_id'
282
+ ]
283
+ missing_keys = [key for key in required_keys if key not in config]
284
+ if missing_keys:
285
+ raise ValueError(f"Config file '{CONFIG_FILENAME}' is missing required keys: {missing_keys}")
286
+ # --- End Validation ---
287
+ except FileNotFoundError:
288
+ logging.error(f"Config file not found locally after download attempt: {config_path}")
289
+ raise gr.Error(f"Config Error: Config file '{CONFIG_FILENAME}' not found. Check file exists in repo.")
290
+ except json.JSONDecodeError as e:
291
+ logging.error(f"Error decoding JSON from config file {config_path}: {e}")
292
+ raise gr.Error(f"Config Error: Could not parse '{CONFIG_FILENAME}'. Check its format. Error: {e}")
293
+ except ValueError as e:
294
+ logging.error(f"Config validation error: {e}")
295
+ raise gr.Error(f"Config Error: {e}")
296
+
297
+
298
+ # Load tokenizers
299
+ logging.info("Loading tokenizers...")
300
+ try:
301
+ smiles_tokenizer = Tokenizer.from_file(smiles_tokenizer_path)
302
+ iupac_tokenizer = Tokenizer.from_file(iupac_tokenizer_path)
303
+ logging.info("Tokenizers loaded.")
304
+ # --- Validate Tokenizer Special Tokens ---
305
+ # Add more robust checks if necessary
306
+ if smiles_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
307
+ smiles_tokenizer.token_to_id("<unk>") is None:
308
+ logging.warning("SMILES tokenizer special tokens might not match config or are missing.")
309
+ if iupac_tokenizer.token_to_id("<pad>") != config['pad_token_id'] or \
310
+ iupac_tokenizer.token_to_id("<sos>") != config['bos_token_id'] or \
311
+ iupac_tokenizer.token_to_id("<eos>") != config['eos_token_id'] or \
312
+ iupac_tokenizer.token_to_id("<unk>") is None:
313
+ logging.warning("IUPAC tokenizer special tokens might not match config or are missing.")
314
+ # --- End Validation ---
315
+ except Exception as e:
316
+ logging.error(f"Failed to load tokenizers from {smiles_tokenizer_path} or {iupac_tokenizer_path}: {e}", exc_info=True)
317
+ raise gr.Error(f"Tokenizer Error: Could not load tokenizer files. Check Space logs. Error: {e}")
318
+
319
+ # Load model
320
+ logging.info("Loading model from checkpoint...")
321
+ try:
322
+ model = SmilesIupacLitModule.load_from_checkpoint(
323
+ checkpoint_path,
324
+ src_vocab_size=config['src_vocab_size'],
325
+ tgt_vocab_size=config['tgt_vocab_size'],
326
+ map_location=device,
327
+ hparams_dict=config,
328
+ strict=False,
329
+ device="cpu"
330
+ )
331
+ model.to(device)
332
+ model.eval()
333
+ model.freeze()
334
+ logging.info("Model loaded successfully, set to eval mode, frozen, and moved to device.")
335
+
336
+ except FileNotFoundError:
337
+ logging.error(f"Checkpoint file not found locally after download attempt: {checkpoint_path}")
338
+ raise gr.Error(f"Model Error: Checkpoint file '{CHECKPOINT_FILENAME}' not found.")
339
+ except Exception as e:
340
+ logging.error(f"Error loading model from checkpoint {checkpoint_path}: {e}", exc_info=True)
341
+ if "memory" in str(e).lower():
342
+ gc.collect()
343
+ if device == torch.device("cuda"):
344
+ torch.cuda.empty_cache()
345
+ raise gr.Error(f"Model Error: Failed to load model checkpoint. Check Space logs. Error: {e}")
346
+
347
+ except gr.Error:
348
+ raise
349
+ except Exception as e:
350
+ logging.error(f"Unexpected error during model/tokenizer loading: {e}", exc_info=True)
351
+ raise gr.Error(f"Initialization Error: An unexpected error occurred. Check Space logs. Error: {e}")
352
+
353
+
354
+ # --- Inference Function for Gradio (Unchanged, calls local translate) ---
355
+ def predict_iupac(smiles_string, beam_width, n_best, length_penalty):
356
+ """
357
+ Performs SMILES to IUPAC translation using the loaded model and beam search.
358
+ """
359
+ global model, smiles_tokenizer, iupac_tokenizer, device, config
360
+
361
+ if not all([model, smiles_tokenizer, iupac_tokenizer, device, config]):
362
+ error_msg = "Error: Model or tokenizers not loaded properly. Check Space logs."
363
+ # Ensure n_best is int for range, default to 1 if conversion fails early
364
+ try: n_best_int = int(n_best)
365
+ except: n_best_int = 1
366
+ return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])
367
+
368
+ if not smiles_string or not smiles_string.strip():
369
+ error_msg = "Error: Please enter a valid SMILES string."
370
+ try: n_best_int = int(n_best)
371
+ except: n_best_int = 1
372
+ return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best_int)])
373
+
374
+ smiles_input = smiles_string.strip()
375
+ try:
376
+ beam_width = int(beam_width)
377
+ n_best = int(n_best)
378
+ length_penalty = float(length_penalty)
379
+ except ValueError as e:
380
+ error_msg = f"Error: Invalid input parameter type ({e})."
381
+ return f"1. {error_msg}" # Cannot determine n_best here
382
+
383
+ logging.info(f"Translating SMILES: '{smiles_input}' (Beam={beam_width}, N={n_best}, Penalty={length_penalty})")
384
+
385
+ try:
386
+ # Calls the translate function defined *above in this file*
387
+ predicted_names = translate(
388
+ model=model,
389
+ src_sentence=smiles_input,
390
+ smiles_tokenizer=smiles_tokenizer,
391
+ iupac_tokenizer=iupac_tokenizer,
392
+ device=device,
393
+ max_len=config['max_len'],
394
+ sos_idx=config['bos_token_id'],
395
+ eos_idx=config['eos_token_id'],
396
+ pad_idx=config['pad_token_id'],
397
+ beam_width=beam_width,
398
+ n_best=n_best,
399
+ length_penalty=length_penalty
400
+ )
401
+ logging.info(f"Predictions returned: {predicted_names}")
402
+
403
+ if not predicted_names:
404
+ output_text = f"Input SMILES: {smiles_input}\n\nNo predictions generated."
405
+ else:
406
+ output_text = f"Input SMILES: {smiles_input}\n\nTop {len(predicted_names)} Predictions (Beam Width={beam_width}, Length Penalty={length_penalty:.2f}):\n"
407
+ output_text += "\n".join([f"{i+1}. {name}" for i, name in enumerate(predicted_names)])
408
+
409
+ return output_text
410
+
411
+ except RuntimeError as e:
412
+ logging.error(f"Runtime error during translation: {e}", exc_info=True)
413
+ error_msg = f"Runtime Error during translation: {e}"
414
+ if "memory" in str(e).lower():
415
+ gc.collect()
416
+ if device == torch.device("cuda"):
417
+ torch.cuda.empty_cache()
418
+ error_msg += " (Potential OOM)"
419
+ return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
420
+
421
+ except Exception as e:
422
+ logging.error(f"Unexpected error during translation: {e}", exc_info=True)
423
+ error_msg = f"Unexpected Error during translation: {e}"
424
+ return "\n".join([f"{i+1}. {error_msg}" for i in range(n_best)])
425
+
426
+
427
+ # --- Load Model on App Start (Unchanged) ---
428
+ try:
429
+ load_model_and_tokenizers()
430
+ except gr.Error:
431
+ pass # Error already raised for Gradio UI
432
+ except Exception as e:
433
+ logging.error(f"Critical error during initial model loading sequence: {e}", exc_info=True)
434
+ gr.Error(f"Fatal Initialization Error: {e}. Check Space logs.")
435
+
436
+
437
+ # --- Create Gradio Interface (Unchanged) ---
438
+ title = "SMILES to IUPAC Name Translator"
439
+ description = f"""
440
+ Enter a SMILES string to translate it into its IUPAC chemical name using a Transformer model and beam search decoding.
441
+ Model repository: <a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>{MODEL_REPO_ID}</a>.
442
+ Adjust beam search parameters below. Higher beam width explores more possibilities but is slower. Length penalty influences the preference for shorter/longer names.
443
+ """
444
+
445
+ examples = [
446
+ ["CCO", 5, 3, 0.6],
447
+ ["C1=CC=CC=C1", 5, 3, 0.6],
448
+ ["CC(=O)Oc1ccccc1C(=O)O", 5, 3, 0.6], # Aspirin
449
+ ["CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", 5, 3, 0.6], # Ibuprofen
450
+ ["CC(=O)O[C@@H]1C[C@@H]2[C@]3(CCCC([C@@H]3CC[C@]2([C@H]4[C@]1([C@H]5[C@@H](OC(=O)C5=CC4)OC)C)C)(C)C)C", 5, 1, 0.6], # Complex example
451
+ ["INVALID_SMILES", 5, 1, 0.6],
452
+ ]
453
+
454
+ smiles_input = gr.Textbox(
455
+ label="SMILES String",
456
+ placeholder="Enter SMILES string here (e.g., CCO for Ethanol)",
457
+ lines=1
458
+ )
459
+ beam_width_input = gr.Slider(
460
+ minimum=1,
461
+ maximum=10,
462
+ value=5,
463
+ step=1,
464
+ label="Beam Width (k)",
465
+ info="Number of sequences to keep at each decoding step (higher = more exploration, slower)."
466
+ )
467
+ n_best_input = gr.Slider(
468
+ minimum=1,
469
+ maximum=10,
470
+ value=3,
471
+ step=1,
472
+ label="Number of Results (n_best)",
473
+ info="How many top-scoring sequences to return (must be <= Beam Width)."
474
+ )
475
+ length_penalty_input = gr.Slider(
476
+ minimum=0.0,
477
+ maximum=2.0,
478
+ value=0.6,
479
+ step=0.1,
480
+ label="Length Penalty (alpha)",
481
+ info="Controls preference for sequence length. >1 prefers longer, <1 prefers shorter, 0 no penalty."
482
+ )
483
+ output_text = gr.Textbox(
484
+ label="Predicted IUPAC Name(s)",
485
+ lines=5,
486
+ show_copy_button=True
487
+ )
488
+
489
+ iface = gr.Interface(
490
+ fn=predict_iupac,
491
+ inputs=[smiles_input, beam_width_input, n_best_input, length_penalty_input],
492
+ outputs=output_text,
493
+ title=title,
494
+ description=description,
495
+ examples=examples,
496
+ allow_flagging="never",
497
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan"),
498
+ article="Note: Translation quality depends on the training data and model size. Complex molecules might yield less accurate results."
499
+ )
500
+
501
+ # --- Launch the App (Unchanged) ---
502
+ if __name__ == "__main__":
503
+ iface.launch(share=True)
enhanced_trainer.py ADDED
@@ -0,0 +1,1421 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import Transformer
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ import pytorch_lightning as pl # Import PyTorch Lightning
8
+ from pytorch_lightning.loggers import WandbLogger # Import WandbLogger
9
+ from pytorch_lightning.callbacks import (
10
+ ModelCheckpoint,
11
+ EarlyStopping,
12
+ ) # Import Callbacks
13
+ import math
14
+ import os
15
+ import pandas as pd
16
+ from sklearn.model_selection import train_test_split
17
+ import time
18
+ import wandb # Import wandb
19
+
20
+
21
+ from tokenizers import (
22
+ Tokenizer,
23
+ models,
24
+ pre_tokenizers,
25
+ decoders,
26
+ trainers,
27
+ )
28
+
29
+ import logging
30
+ import gc
31
+
32
+ # --- Basic Logging Setup ---
33
+ logging.basicConfig(
34
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
35
+ )
36
+
37
+ # --- 1. Configuration & Hyperparameters ---
38
+
39
+ # Model Hyperparameters (Scaled up for H100s - ADJUST AS NEEDED based on memory)
40
+ # Note: BPE might benefit from a slightly larger vocab size than the regex approach
41
+ SRC_VOCAB_SIZE_ESTIMATE = 10000 # Increased estimate for SMILES BPE
42
+ TGT_VOCAB_SIZE_ESTIMATE = 14938 # Increased estimate for IUPAC
43
+ EMB_SIZE = 2048 # Embedding dimension (d_model) - Increased significantly
44
+ NHEAD = 8 # Number of attention heads (must divide EMB_SIZE) - Increased
45
+ FFN_HID_DIM = (
46
+ 4096 # Feedforward network hidden dimension (e.g., 4 * EMB_SIZE) - Increased
47
+ )
48
+ NUM_ENCODER_LAYERS = 12 # Number of layers in Encoder - Increased
49
+ NUM_DECODER_LAYERS = 12 # Number of layers in Decoder - Increased
50
+ DROPOUT = 0.1 # Dropout rate (can sometimes be reduced slightly for larger models)
51
+ MAX_LEN = 384 # Maximum sequence length (consider increasing if needed/possible)
52
+
53
+ # Training Hyperparameters
54
+ ACCELERATOR = "gpu"
55
+ DEVICES = 6 # Number of H100 GPUs to use
56
+ STRATEGY = "ddp" # Distributed Data Parallel Strategy
57
+ PRECISION = "16-mixed" # Use mixed precision for speed and memory saving on H100s
58
+ BATCH_SIZE_PER_GPU = 48 # Adjust based on H100 GPU memory (e.g., 32, 48, 64) - Effective BS = BATCH_SIZE_PER_GPU * DEVICES
59
+ ACCUMULATE_GRAD_BATCHES = (
60
+ 1 # Increase if BATCH_SIZE_PER_GPU needs to be smaller due to memory
61
+ )
62
+ NUM_EPOCHS = 50 # Increase for potentially longer training needed for larger models
63
+ LEARNING_RATE = 5e-5 # Might need adjustment for larger models/batch sizes
64
+ WEIGHT_DECAY = 1e-2
65
+ GRAD_CLIP_NORM = 1.0
66
+ VALIDATION_SPLIT = 0.05 # Use a smaller validation split if the dataset is huge
67
+ RANDOM_SEED = 42
68
+ PATIENCE = 5 # Early stopping patience
69
+ NUM_WORKERS = 8 # Adjust based on CPU cores and system capabilities
70
+
71
+ # Special Token Indices
72
+ PAD_IDX = 0
73
+ SOS_IDX = 1
74
+ EOS_IDX = 2
75
+ UNK_IDX = 3
76
+
77
+ # File Paths
78
+ # *** CHANGED SMILES TOKENIZER FILENAME ***
79
+ SMILES_TOKENIZER_FILE = "smiles_bytelevel_bpe_tokenizer_scaled.json"
80
+ IUPAC_TOKENIZER_FILE = "iupac_unigram_tokenizer_scaled.json"
81
+ INPUT_CSV_FILE = "data_clean.csv" # <--- Your input CSV file path
82
+
83
+ # Output files for data splits
84
+ TRAIN_SMILES_FILE = "train.smi"
85
+ TRAIN_IUPAC_FILE = "train.iupac"
86
+ VAL_SMILES_FILE = "val.smi"
87
+ VAL_IUPAC_FILE = "val.iupac"
88
+ CHECKPOINT_DIR = "checkpoints" # Directory to save model checkpoints
89
+ BEST_MODEL_FILENAME = (
90
+ "smiles-to-iupac-transformer-best" # Filename format for checkpoints
91
+ )
92
+
93
+ # WandB Configuration
94
+ WANDB_PROJECT = "SMILES-to-IUPAC-Large-BPE" # Updated project name slightly
95
+ WANDB_ENTITY = (
96
+ "adrianmirza" # Replace with your WandB entity (username or team name) if desired
97
+ )
98
+ WANDB_RUN_NAME = f"transformer_BPE_E{EMB_SIZE}_H{NHEAD}_L{NUM_ENCODER_LAYERS}_BS{BATCH_SIZE_PER_GPU * DEVICES}_LR{LEARNING_RATE}"
99
+
100
+ # Store hparams for logging
101
+ hparams = {
102
+ "src_tokenizer_type": "ByteLevelBPE", # Added tokenizer type info
103
+ "tgt_tokenizer_type": "Unigram",
104
+ "src_vocab_size_estimate": SRC_VOCAB_SIZE_ESTIMATE,
105
+ "tgt_vocab_size_estimate": TGT_VOCAB_SIZE_ESTIMATE,
106
+ "emb_size": EMB_SIZE,
107
+ "nhead": NHEAD,
108
+ "ffn_hid_dim": FFN_HID_DIM,
109
+ "num_encoder_layers": NUM_ENCODER_LAYERS,
110
+ "num_decoder_layers": NUM_DECODER_LAYERS,
111
+ "dropout": DROPOUT,
112
+ "max_len": MAX_LEN,
113
+ "batch_size_per_gpu": BATCH_SIZE_PER_GPU,
114
+ "effective_batch_size": BATCH_SIZE_PER_GPU * DEVICES * ACCUMULATE_GRAD_BATCHES,
115
+ "num_epochs": NUM_EPOCHS,
116
+ "learning_rate": LEARNING_RATE,
117
+ "weight_decay": WEIGHT_DECAY,
118
+ "grad_clip_norm": GRAD_CLIP_NORM,
119
+ "validation_split": VALIDATION_SPLIT,
120
+ "random_seed": RANDOM_SEED,
121
+ "patience": PATIENCE,
122
+ "precision": PRECISION,
123
+ "gpus": DEVICES,
124
+ "strategy": STRATEGY,
125
+ "num_workers": NUM_WORKERS,
126
+ }
127
+
128
+ # --- 2. Tokenizers (Modified SMILES Tokenizer) ---
129
+
130
+
131
+ # --- 2.a SMILES ByteLevel BPE Tokenizer (Replaced WordLevel Regex) ---
132
+ def get_smiles_tokenizer(
133
+ train_files=None,
134
+ vocab_size=30000,
135
+ min_frequency=2,
136
+ tokenizer_path=SMILES_TOKENIZER_FILE,
137
+ ):
138
+ """Creates or loads a Byte-Level BPE tokenizer for SMILES."""
139
+ if os.path.exists(tokenizer_path):
140
+ logging.info(f"Loading existing SMILES tokenizer from {tokenizer_path}")
141
+ try:
142
+ tokenizer = Tokenizer.from_file(tokenizer_path)
143
+ # Verify special tokens after loading
144
+ if (
145
+ tokenizer.token_to_id("<pad>") != PAD_IDX
146
+ or tokenizer.token_to_id("<sos>") != SOS_IDX
147
+ or tokenizer.token_to_id("<eos>") != EOS_IDX
148
+ or tokenizer.token_to_id("<unk>") != UNK_IDX
149
+ ):
150
+ logging.warning(
151
+ "Special token ID mismatch after loading SMILES tokenizer. Re-check config."
152
+ )
153
+ # Check if it's actually a BPE model (basic check)
154
+ if not isinstance(tokenizer.model, models.BPE):
155
+ logging.warning(
156
+ f"Loaded tokenizer from {tokenizer_path} is not a BPE model. Retraining."
157
+ )
158
+ raise TypeError("Incorrect tokenizer model type loaded.")
159
+ return tokenizer
160
+ except Exception as e:
161
+ logging.error(f"Failed to load SMILES tokenizer: {e}. Retraining...")
162
+
163
+ logging.info("Creating and training SMILES Byte-Level BPE tokenizer...")
164
+ # Use BPE model
165
+ tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
166
+
167
+ # Use ByteLevel pre-tokenizer - this handles any character sequence
168
+ # add_prefix_space=False is generally suitable for SMILES as it doesn't rely on spaces
169
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
170
+ # Use ByteLevel decoder
171
+ tokenizer.decoder = decoders.ByteLevel()
172
+
173
+ special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
174
+ # Use BpeTrainer
175
+ trainer = trainers.BpeTrainer(
176
+ vocab_size=vocab_size,
177
+ min_frequency=min_frequency,
178
+ special_tokens=special_tokens,
179
+ # BPE specific options can be added here if needed, e.g.:
180
+ # initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), # Usually inferred
181
+ # show_progress=True,
182
+ )
183
+
184
+ if train_files and all(os.path.exists(f) for f in train_files):
185
+ logging.info(f"Training SMILES BPE tokenizer on: {train_files}")
186
+ tokenizer.train(files=train_files, trainer=trainer)
187
+ logging.info(
188
+ f"SMILES BPE tokenizer trained. Final Vocab size: {tokenizer.get_vocab_size()}"
189
+ )
190
+ # Verify special token IDs after training
191
+ if (
192
+ tokenizer.token_to_id("<pad>") != PAD_IDX
193
+ or tokenizer.token_to_id("<sos>") != SOS_IDX
194
+ or tokenizer.token_to_id("<eos>") != EOS_IDX
195
+ or tokenizer.token_to_id("<unk>") != UNK_IDX
196
+ ):
197
+ logging.warning(
198
+ "Special token ID mismatch after training SMILES BPE tokenizer. Check trainer setup."
199
+ )
200
+ try:
201
+ tokenizer.save(tokenizer_path)
202
+ logging.info(f"SMILES BPE tokenizer saved to {tokenizer_path}")
203
+ except Exception as e:
204
+ logging.error(f"Failed to save SMILES BPE tokenizer: {e}")
205
+ else:
206
+ logging.error(
207
+ "Training files not provided or not found for SMILES tokenizer. Cannot train."
208
+ )
209
+ # Manually add special tokens if training fails, so basic encoding/decoding might work
210
+ tokenizer.add_special_tokens(special_tokens)
211
+
212
+ return tokenizer
213
+
214
+
215
+ # --- 2.b IUPAC Unigram Tokenizer (No changes needed here) ---
216
+ def get_iupac_tokenizer(
217
+ train_files=None,
218
+ vocab_size=30000,
219
+ min_frequency=2,
220
+ tokenizer_path=IUPAC_TOKENIZER_FILE,
221
+ ):
222
+ """Creates or loads a Unigram tokenizer for IUPAC names."""
223
+ if os.path.exists(tokenizer_path):
224
+ logging.info(f"Loading existing IUPAC tokenizer from {tokenizer_path}")
225
+ try:
226
+ tokenizer = Tokenizer.from_file(tokenizer_path)
227
+ if (
228
+ tokenizer.token_to_id("<pad>") != PAD_IDX
229
+ or tokenizer.token_to_id("<sos>") != SOS_IDX
230
+ or tokenizer.token_to_id("<eos>") != EOS_IDX
231
+ or tokenizer.token_to_id("<unk>") != UNK_IDX
232
+ ):
233
+ logging.warning(
234
+ "Special token ID mismatch after loading IUPAC tokenizer. Re-check config."
235
+ )
236
+ return tokenizer
237
+ except Exception as e:
238
+ logging.error(f"Failed to load IUPAC tokenizer: {e}. Retraining...")
239
+
240
+ logging.info("Creating and training IUPAC Unigram tokenizer...")
241
+ tokenizer = Tokenizer(models.Unigram())
242
+ # Using Sequence of pre-tokenizers for IUPAC is reasonable
243
+ pre_tokenizer_list = [
244
+ pre_tokenizers.WhitespaceSplit(), # Split by whitespace first
245
+ pre_tokenizers.Punctuation(), # Split punctuation
246
+ pre_tokenizers.Digits(individual_digits=True), # Split digits
247
+ ]
248
+ # Consider adding Metaspace if Unigram struggles with word boundaries after splits
249
+ # tokenizer.pre_tokenizer = pre_tokenizers.Metaspace() # Alternative
250
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(pre_tokenizer_list)
251
+ tokenizer.decoder = (
252
+ decoders.Metaspace()
253
+ ) # Metaspace decoder often works well with Unigram/BPE
254
+ special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
255
+ trainer = trainers.UnigramTrainer(
256
+ vocab_size=vocab_size,
257
+ special_tokens=special_tokens,
258
+ unk_token="<unk>",
259
+ # Unigram specific options can be added here
260
+ # shrinking_factor=0.75,
261
+ # n_sub_iterations=2,
262
+ )
263
+
264
+ if train_files and all(os.path.exists(f) for f in train_files):
265
+ logging.info(f"Training IUPAC tokenizer on: {train_files}")
266
+ tokenizer.train(files=train_files, trainer=trainer)
267
+ logging.info(
268
+ f"IUPAC tokenizer trained. Final Vocab size: {tokenizer.get_vocab_size()}"
269
+ )
270
+ # Verify special token IDs after training
271
+ if (
272
+ tokenizer.token_to_id("<pad>") != PAD_IDX
273
+ or tokenizer.token_to_id("<sos>") != SOS_IDX
274
+ or tokenizer.token_to_id("<eos>") != EOS_IDX
275
+ or tokenizer.token_to_id("<unk>") != UNK_IDX
276
+ ):
277
+ logging.warning(
278
+ "Special token ID mismatch after training IUPAC tokenizer. Check trainer setup."
279
+ )
280
+ try:
281
+ tokenizer.save(tokenizer_path)
282
+ logging.info(f"IUPAC tokenizer saved to {tokenizer_path}")
283
+ except Exception as e:
284
+ logging.error(f"Failed to save IUPAC tokenizer: {e}")
285
+ else:
286
+ logging.error(
287
+ "Training files not provided or not found for IUPAC tokenizer. Cannot train."
288
+ )
289
+ tokenizer.add_special_tokens(special_tokens)
290
+
291
+ return tokenizer
292
+
293
+
294
+ # --- 3. Model Definition (No changes needed) ---
295
+ class PositionalEncoding(nn.Module):
296
+ """Injects positional information into the input embeddings."""
297
+
298
+ def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
299
+ super().__init__()
300
+ den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
301
+ pos = torch.arange(0, maxlen).reshape(maxlen, 1)
302
+ pos_embedding = torch.zeros((maxlen, emb_size))
303
+ pos_embedding[:, 0::2] = torch.sin(pos * den)
304
+ pos_embedding[:, 1::2] = torch.cos(pos * den)
305
+ pos_embedding = pos_embedding.unsqueeze(
306
+ 0
307
+ ) # Add batch dimension for broadcasting
308
+ self.dropout = nn.Dropout(dropout)
309
+ self.register_buffer(
310
+ "pos_embedding", pos_embedding
311
+ ) # Shape [1, maxlen, emb_size]
312
+
313
+ def forward(self, token_embedding: torch.Tensor):
314
+ # token_embedding: Expected shape [batch_size, seq_len, emb_size]
315
+ seq_len = token_embedding.size(1)
316
+ # Slicing pos_embedding: [1, seq_len, emb_size]
317
+ # Handle cases where seq_len might exceed buffer's maxlen during inference/edge cases
318
+ if seq_len > self.pos_embedding.size(1):
319
+ logging.warning(
320
+ f"Input sequence length ({seq_len}) exceeds PositionalEncoding maxlen ({self.pos_embedding.size(1)}). Truncating positional encoding."
321
+ )
322
+ pos_to_add = self.pos_embedding[:, : self.pos_embedding.size(1), :]
323
+ # Pad token_embedding if needed? Or error out? For now, just use available encoding.
324
+ # This scenario shouldn't happen if MAX_LEN config is respected.
325
+ output = token_embedding[:, : self.pos_embedding.size(1), :] + pos_to_add
326
+ else:
327
+ pos_to_add = self.pos_embedding[:, :seq_len, :]
328
+ output = token_embedding + pos_to_add
329
+
330
+ return self.dropout(output)
331
+
332
+
333
+ class TokenEmbedding(nn.Module):
334
+ """Converts token indices to embeddings."""
335
+
336
+ def __init__(self, vocab_size: int, emb_size):
337
+ super().__init__()
338
+ self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=PAD_IDX)
339
+ self.emb_size = emb_size
340
+
341
+ def forward(self, tokens: torch.Tensor):
342
+ return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
343
+
344
+
345
+ class Seq2SeqTransformer(nn.Module):
346
+ """The main Encoder-Decoder Transformer model."""
347
+
348
+ def __init__(
349
+ self,
350
+ num_encoder_layers: int,
351
+ num_decoder_layers: int,
352
+ emb_size: int,
353
+ nhead: int,
354
+ src_vocab_size: int,
355
+ tgt_vocab_size: int,
356
+ dim_feedforward: int,
357
+ dropout: float = 0.1,
358
+ max_len: int = MAX_LEN,
359
+ ): # Use MAX_LEN from config
360
+ super().__init__()
361
+
362
+ if emb_size % nhead != 0:
363
+ raise ValueError(
364
+ f"Embedding size ({emb_size}) must be divisible by the number of heads ({nhead})"
365
+ )
366
+
367
+ self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
368
+ self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
369
+
370
+ # Ensure PositionalEncoding maxlen is sufficient
371
+ pe_maxlen = max(
372
+ max_len, 5000
373
+ ) # Use config MAX_LEN or default 5000, whichever is larger
374
+ self.positional_encoding = PositionalEncoding(
375
+ emb_size, dropout=dropout, maxlen=pe_maxlen
376
+ )
377
+
378
+ self.transformer = Transformer(
379
+ d_model=emb_size,
380
+ nhead=nhead,
381
+ num_encoder_layers=num_encoder_layers,
382
+ num_decoder_layers=num_decoder_layers,
383
+ dim_feedforward=dim_feedforward,
384
+ dropout=dropout,
385
+ batch_first=True,
386
+ ) # Use batch_first=True
387
+
388
+ self.generator = nn.Linear(emb_size, tgt_vocab_size)
389
+ self._init_weights()
390
+
391
+ def _init_weights(self):
392
+ for p in self.parameters():
393
+ if p.dim() > 1:
394
+ nn.init.xavier_uniform_(p)
395
+
396
+ def forward(
397
+ self,
398
+ src: torch.Tensor, # Input sequence (batch_size, src_len)
399
+ trg: torch.Tensor, # Target sequence (batch_size, tgt_len)
400
+ tgt_mask: torch.Tensor, # Target causal mask (tgt_len, tgt_len)
401
+ src_padding_mask: torch.Tensor, # Source padding mask (batch_size, src_len)
402
+ tgt_padding_mask: torch.Tensor, # Target padding mask (batch_size, tgt_len)
403
+ memory_key_padding_mask: torch.Tensor,
404
+ ): # Memory padding mask (batch_size, src_len)
405
+ # --- Ensure masks have correct dtype and device ---
406
+ # Pytorch Transformer expects boolean masks where True indicates masking
407
+ src_padding_mask = src_padding_mask.to(src.device)
408
+ tgt_padding_mask = tgt_padding_mask.to(trg.device)
409
+ memory_key_padding_mask = memory_key_padding_mask.to(src.device)
410
+ # tgt_mask needs to be float for '-inf' filling, keep on target device
411
+ tgt_mask = tgt_mask.to(trg.device)
412
+
413
+ src_emb = self.positional_encoding(
414
+ self.src_tok_emb(src)
415
+ ) # [batch, src_len, dim]
416
+ tgt_emb = self.positional_encoding(
417
+ self.tgt_tok_emb(trg)
418
+ ) # [batch, tgt_len, dim]
419
+
420
+ outs = self.transformer(
421
+ src=src_emb,
422
+ tgt=tgt_emb,
423
+ src_mask=None, # Not typically needed for encoder unless custom masking
424
+ tgt_mask=tgt_mask, # Causal mask for decoder self-attn
425
+ memory_mask=None, # Not typically needed unless masking specific memory parts
426
+ src_key_padding_mask=src_padding_mask, # Mask padding in src K,V
427
+ tgt_key_padding_mask=tgt_padding_mask, # Mask padding in tgt Q
428
+ memory_key_padding_mask=memory_key_padding_mask,
429
+ ) # Mask padding in memory K,V for cross-attn
430
+ # outs: [batch_size, tgt_len, emb_size]
431
+ return self.generator(outs) # [batch_size, tgt_len, tgt_vocab_size]
432
+
433
+ def encode(self, src: torch.Tensor, src_padding_mask: torch.Tensor):
434
+ src_padding_mask = src_padding_mask.to(
435
+ src.device
436
+ ) # Ensure mask is on correct device
437
+ src_emb = self.positional_encoding(
438
+ self.src_tok_emb(src)
439
+ ) # [batch, src_len, dim]
440
+ memory = self.transformer.encoder(
441
+ src_emb, mask=None, src_key_padding_mask=src_padding_mask
442
+ )
443
+ return memory # Returns memory: [batch_size, src_len, emb_size]
444
+
445
+ def decode(
446
+ self,
447
+ tgt: torch.Tensor,
448
+ memory: torch.Tensor,
449
+ tgt_mask: torch.Tensor,
450
+ tgt_padding_mask: torch.Tensor,
451
+ memory_key_padding_mask: torch.Tensor,
452
+ ):
453
+ # Ensure masks are on correct device
454
+ tgt_mask = tgt_mask.to(tgt.device)
455
+ tgt_padding_mask = tgt_padding_mask.to(tgt.device)
456
+ memory_key_padding_mask = memory_key_padding_mask.to(memory.device)
457
+
458
+ tgt_emb = self.positional_encoding(
459
+ self.tgt_tok_emb(tgt)
460
+ ) # [batch, tgt_len, dim]
461
+ output = self.transformer.decoder(
462
+ tgt=tgt_emb,
463
+ memory=memory,
464
+ tgt_mask=tgt_mask,
465
+ memory_mask=None,
466
+ tgt_key_padding_mask=tgt_padding_mask,
467
+ memory_key_padding_mask=memory_key_padding_mask,
468
+ )
469
+ return output # Returns decoder output: [batch_size, tgt_len, emb_size]
470
+
471
+
472
+ # --- Helper function for mask creation (No changes needed) ---
473
+ def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
474
+ """Generates an upper-triangular matrix for causal masking."""
475
+ mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
476
+ mask = (
477
+ mask.float()
478
+ .masked_fill(mask == 0, float("-inf"))
479
+ .masked_fill(mask == 1, float(0.0))
480
+ )
481
+ return mask # Shape [sz, sz]
482
+
483
+
484
+ def create_masks(
485
+ src: torch.Tensor, tgt: torch.Tensor, pad_idx: int, device: torch.device
486
+ ):
487
+ """
488
+ Creates all necessary masks for the Transformer model.
489
+ Assumes src and tgt are inputs to the forward pass (tgt includes SOS, excludes EOS).
490
+ Returns boolean masks where True indicates the position should be masked (ignored).
491
+ """
492
+ src_seq_len = src.shape[1]
493
+ tgt_seq_len = tgt.shape[1]
494
+
495
+ # Causal mask for decoder self-attention (float mask for PyTorch Transformer)
496
+ tgt_mask = generate_square_subsequent_mask(
497
+ tgt_seq_len, device
498
+ ) # [tgt_len, tgt_len]
499
+
500
+ # Padding masks (boolean, True where padded)
501
+ src_padding_mask = src == pad_idx # [batch_size, src_len]
502
+ tgt_padding_mask = tgt == pad_idx # [batch_size, tgt_len]
503
+ memory_key_padding_mask = (
504
+ src_padding_mask # Used in decoder cross-attention [batch_size, src_len]
505
+ )
506
+
507
+ return tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask
508
+
509
+
510
+ # --- 4. Data Handling (Dataset and Collate Function - No changes needed) ---
511
+ class SmilesIupacDataset(Dataset):
512
+ """Dataset class for SMILES-IUPAC pairs, reading from pre-split files."""
513
+
514
+ def __init__(self, smiles_file: str, iupac_file: str):
515
+ logging.info(f"Loading data from {smiles_file} and {iupac_file}")
516
+ try:
517
+ with open(smiles_file, "r", encoding="utf-8") as f_smi:
518
+ self.smiles = [line.strip() for line in f_smi if line.strip()]
519
+ with open(iupac_file, "r", encoding="utf-8") as f_iupac:
520
+ self.iupac = [line.strip() for line in f_iupac if line.strip()]
521
+
522
+ if len(self.smiles) != len(self.iupac):
523
+ logging.warning(
524
+ f"Mismatch in number of lines: {smiles_file} ({len(self.smiles)}) vs {iupac_file} ({len(self.iupac)}). Trimming."
525
+ )
526
+ min_len = min(len(self.smiles), len(self.iupac))
527
+ self.smiles = self.smiles[:min_len]
528
+ self.iupac = self.iupac[:min_len]
529
+
530
+ logging.info(
531
+ f"Loaded {len(self.smiles)} pairs from {smiles_file}/{iupac_file}."
532
+ )
533
+ if len(self.smiles) == 0:
534
+ logging.warning(f"Loaded 0 data pairs. Check files.")
535
+
536
+ except FileNotFoundError:
537
+ logging.error(
538
+ f"Error: One or both files not found: {smiles_file}, {iupac_file}"
539
+ )
540
+ raise
541
+ except Exception as e:
542
+ logging.error(f"Error loading data: {e}")
543
+ raise
544
+
545
+ def __len__(self):
546
+ return len(self.smiles)
547
+
548
+ def __getitem__(self, idx):
549
+ return self.smiles[idx], self.iupac[idx]
550
+
551
+
552
+ def collate_fn(
553
+ batch, smiles_tokenizer, iupac_tokenizer, pad_idx, sos_idx, eos_idx, max_len
554
+ ):
555
+ """Collates data samples into batches."""
556
+ src_batch, tgt_batch = [], []
557
+ skipped_count = 0
558
+ for src_sample, tgt_sample in batch:
559
+ try:
560
+ # Encode source (SMILES)
561
+ src_encoded = smiles_tokenizer.encode(src_sample)
562
+ # Truncate source if needed (including potential special tokens if added by encode)
563
+ src_ids = src_encoded.ids[:max_len]
564
+ if not src_ids: # Skip if encoding results in empty sequence
565
+ skipped_count += 1
566
+ continue
567
+ src_tensor = torch.tensor(src_ids, dtype=torch.long)
568
+
569
+ # Encode target (IUPAC)
570
+ tgt_encoded = iupac_tokenizer.encode(tgt_sample)
571
+ # Truncate target allowing space for SOS and EOS
572
+ tgt_ids = tgt_encoded.ids[: max_len - 2]
573
+ if (
574
+ not tgt_ids
575
+ ): # Skip if encoding results in empty sequence (after truncation)
576
+ skipped_count += 1
577
+ continue
578
+ # Add SOS and EOS tokens
579
+ tgt_tensor = torch.tensor([sos_idx] + tgt_ids + [eos_idx], dtype=torch.long)
580
+
581
+ src_batch.append(src_tensor)
582
+ tgt_batch.append(tgt_tensor)
583
+ except Exception as e:
584
+ # Log infrequent warnings for skipping
585
+ # if skipped_count < 5: # Log only the first few skips per batch
586
+ # logging.warning(f"Skipping sample due to error during tokenization/tensor creation: {e}. SMILES: '{src_sample[:50]}...', IUPAC: '{tgt_sample[:50]}...'")
587
+ skipped_count += 1
588
+ continue
589
+
590
+ # if skipped_count > 0:
591
+ # logging.debug(f"Skipped {skipped_count} samples in this batch during collation.")
592
+
593
+ if not src_batch or not tgt_batch:
594
+ # Return empty tensors if the whole batch was skipped
595
+ return torch.tensor([]), torch.tensor([])
596
+
597
+ try:
598
+ # Pad sequences
599
+ src_batch_padded = pad_sequence(
600
+ src_batch, batch_first=True, padding_value=pad_idx
601
+ )
602
+ tgt_batch_padded = pad_sequence(
603
+ tgt_batch, batch_first=True, padding_value=pad_idx
604
+ )
605
+ except Exception as e:
606
+ logging.error(
607
+ f"Error during padding: {e}. Src lengths: {[len(s) for s in src_batch]}, Tgt lengths: {[len(t) for t in tgt_batch]}"
608
+ )
609
+ # Return empty tensors on padding error
610
+ return torch.tensor([]), torch.tensor([])
611
+
612
+ return src_batch_padded, tgt_batch_padded
613
+
614
+
615
+ # --- 5. PyTorch Lightning Module (No changes needed) ---
616
+ class SmilesIupacLitModule(pl.LightningModule):
617
+ def __init__(
618
+ self, src_vocab_size: int, tgt_vocab_size: int, hparams_dict: dict
619
+ ): # Pass hparams dictionary
620
+ super().__init__()
621
+ # Use save_hyperparameters() to automatically save args to self.hparams
622
+ # and make them accessible in checkpoints and loggers
623
+ self.save_hyperparameters(hparams_dict)
624
+
625
+ self.model = Seq2SeqTransformer(
626
+ num_encoder_layers=self.hparams.num_encoder_layers,
627
+ num_decoder_layers=self.hparams.num_decoder_layers,
628
+ emb_size=self.hparams.emb_size,
629
+ nhead=self.hparams.nhead,
630
+ src_vocab_size=src_vocab_size, # Pass actual vocab size
631
+ tgt_vocab_size=tgt_vocab_size, # Pass actual vocab size
632
+ dim_feedforward=self.hparams.ffn_hid_dim,
633
+ dropout=self.hparams.dropout,
634
+ max_len=self.hparams.max_len, # Pass max_len here
635
+ )
636
+
637
+ self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
638
+
639
+ # --- Count Parameters --- (Done once at initialization)
640
+ total_params = sum(p.numel() for p in self.model.parameters())
641
+ trainable_params = sum(
642
+ p.numel() for p in self.model.parameters() if p.requires_grad
643
+ )
644
+ logging.info(f"Model Initialized:")
645
+ logging.info(f" Total Parameters: {total_params / 1_000_000:.2f} M")
646
+ logging.info(f" Trainable Parameters: {trainable_params / 1_000_000:.2f} M")
647
+ # Log params to wandb hparams if logger is available
648
+ # self.hparams are automatically logged by WandbLogger if passed to Trainer
649
+ # We can add them explicitly if needed, but save_hyperparameters usually handles it.
650
+ self.hparams.total_params_M = round(total_params / 1_000_000, 2)
651
+ self.hparams.trainable_params_M = round(trainable_params / 1_000_000, 2)
652
+
653
+ def forward(self, src, tgt):
654
+ # This is the main forward pass used for inference/prediction if needed
655
+ # For training/validation, we call the model directly in step methods
656
+ # to handle mask creation explicitly.
657
+ tgt_input = tgt[:, :-1] # Prepare target input (remove EOS)
658
+ tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
659
+ create_masks(
660
+ src,
661
+ tgt_input,
662
+ PAD_IDX,
663
+ self.device, # Use self.device provided by Lightning
664
+ )
665
+ )
666
+ logits = self.model(
667
+ src,
668
+ tgt_input,
669
+ tgt_mask,
670
+ src_padding_mask,
671
+ tgt_padding_mask,
672
+ memory_key_padding_mask,
673
+ )
674
+ return logits
675
+
676
+ def training_step(self, batch, batch_idx):
677
+ src, tgt = batch
678
+ if src.numel() == 0 or tgt.numel() == 0:
679
+ # logging.debug(f"Skipping empty batch {batch_idx} in training.")
680
+ return None # Skip empty batches
681
+
682
+ tgt_input = tgt[:, :-1] # Exclude EOS for input
683
+ tgt_out = tgt[:, 1:] # Exclude SOS for target labels
684
+
685
+ # Create masks on the current device
686
+ tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
687
+ create_masks(src, tgt_input, PAD_IDX, self.device)
688
+ )
689
+
690
+ try:
691
+ logits = self.model(
692
+ src=src,
693
+ trg=tgt_input,
694
+ tgt_mask=tgt_mask,
695
+ src_padding_mask=src_padding_mask,
696
+ tgt_padding_mask=tgt_padding_mask,
697
+ memory_key_padding_mask=memory_key_padding_mask,
698
+ )
699
+ # logits: [batch_size, tgt_len-1, tgt_vocab_size]
700
+
701
+ # Calculate loss
702
+ # Reshape logits to [batch_size * (tgt_len-1), tgt_vocab_size]
703
+ # Reshape tgt_out to [batch_size * (tgt_len-1)]
704
+ loss = self.criterion(
705
+ logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)
706
+ )
707
+
708
+ # Check for NaN/Inf loss (important with mixed precision)
709
+ if not torch.isfinite(loss):
710
+ logging.warning(
711
+ f"Non-finite loss encountered in training step {batch_idx}: {loss.item()}. Skipping update."
712
+ )
713
+ # Manually skip optimizer step if using manual optimization,
714
+ # otherwise returning None might be sufficient for automatic opt.
715
+ return None # Returning None should prevent optimizer step
716
+
717
+ # Log training loss
718
+ # sync_dist=True is important for DDP to average loss across GPUs
719
+ self.log(
720
+ "train_loss",
721
+ loss,
722
+ on_step=True,
723
+ on_epoch=True,
724
+ prog_bar=True,
725
+ logger=True,
726
+ sync_dist=True,
727
+ batch_size=src.size(0),
728
+ )
729
+
730
+ return loss
731
+
732
+ except RuntimeError as e:
733
+ if "CUDA out of memory" in str(e):
734
+ logging.warning(
735
+ f"CUDA OOM error during training step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch."
736
+ )
737
+ gc.collect()
738
+ torch.cuda.empty_cache()
739
+ return None # Skip update
740
+ else:
741
+ logging.error(f"Runtime error during training step {batch_idx}: {e}")
742
+ # Optionally log shapes for debugging other runtime errors
743
+ logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}")
744
+ return None # Skip update
745
+
746
+ def validation_step(self, batch, batch_idx):
747
+ src, tgt = batch
748
+ if src.numel() == 0 or tgt.numel() == 0:
749
+ # logging.debug(f"Skipping empty batch {batch_idx} in validation.")
750
+ return None
751
+
752
+ tgt_input = tgt[:, :-1]
753
+ tgt_out = tgt[:, 1:]
754
+
755
+ tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
756
+ create_masks(src, tgt_input, PAD_IDX, self.device)
757
+ )
758
+
759
+ try:
760
+ logits = self.model(
761
+ src,
762
+ tgt_input,
763
+ tgt_mask,
764
+ src_padding_mask,
765
+ tgt_padding_mask,
766
+ memory_key_padding_mask,
767
+ )
768
+ loss = self.criterion(
769
+ logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)
770
+ )
771
+
772
+ if torch.isfinite(loss):
773
+ # Log validation loss (accumulated across batches and synced across GPUs at epoch end)
774
+ # sync_dist=True ensures correct aggregation in DDP
775
+ self.log(
776
+ "val_loss",
777
+ loss,
778
+ on_step=False,
779
+ on_epoch=True,
780
+ prog_bar=True,
781
+ logger=True,
782
+ sync_dist=True,
783
+ batch_size=src.size(0),
784
+ )
785
+ else:
786
+ logging.warning(
787
+ f"Non-finite loss encountered during validation step {batch_idx}: {loss.item()}."
788
+ )
789
+ # PTL aggregates logged values automatically for the epoch
790
+ # Returning the loss value itself isn't strictly necessary when using self.log
791
+ # return loss
792
+
793
+ except RuntimeError as e:
794
+ # Don't crash validation if one batch fails (e.g., OOM on a particularly long sequence)
795
+ logging.error(f"Runtime error during validation step {batch_idx}: {e}")
796
+ if "CUDA out of memory" in str(e):
797
+ logging.warning(
798
+ f"CUDA OOM error during validation step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch."
799
+ )
800
+ gc.collect()
801
+ torch.cuda.empty_cache()
802
+ else:
803
+ logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}")
804
+ # Return None or a placeholder if needed by some aggregation logic,
805
+ # but self.log should handle the metric correctly even if some steps fail.
806
+ return None
807
+
808
+ def configure_optimizers(self):
809
+ optimizer = torch.optim.AdamW(
810
+ self.parameters(), # self.parameters() includes all model parameters
811
+ lr=self.hparams.learning_rate,
812
+ weight_decay=self.hparams.weight_decay,
813
+ )
814
+
815
+ # --- Add Learning Rate Scheduler ---
816
+ # Use linear warmup followed by linear decay (common for transformers)
817
+ # Requires the 'transformers' library: pip install transformers
818
+ try:
819
+ from transformers import get_linear_schedule_with_warmup
820
+
821
+ # Estimate the total number of optimizer steps so the scheduler can decay over training.
822
+ # trainer.estimated_stepping_batches accounts for epochs, gradient accumulation and the
823
+ # number of devices, so it approximates how many times the optimizer will step.
824
+ # However, self.trainer may not be fully set up when configure_optimizers runs,
825
+ # in which case accessing it raises AttributeError.
826
+ # Alternatives are to derive the count from the dataset size, effective batch size and
827
+ # epoch count, or to fall back to a large fixed number if only eventual decay matters.
828
+ # Try the trainer estimate first; fall back below if it is unavailable.
829
+ try:
830
+ # This attribute is available after trainer setup, might work here.
831
+ num_training_steps = self.trainer.estimated_stepping_batches
832
+ logging.info(
833
+ f"Estimated stepping batches for LR schedule: {num_training_steps}"
834
+ )
835
+ if num_training_steps is None or num_training_steps <= 0:
836
+ logging.warning(
837
+ "Could not estimate stepping batches, using fallback for LR schedule."
838
+ )
839
+ # Fallback: the exact count would be ceil(dataset_size / effective_batch_size) * num_epochs,
840
+ # but the dataset size is not available inside this module and is not stored in hparams.
841
+ # Default to a large fixed step count instead: warmup still behaves as intended and the
842
+ # learning rate simply decays more slowly than it would with the true total.
843
+ # (See the sketch after configure_optimizers for computing the exact count externally.)
844
+ num_training_steps = 1_000_000 # Adjust this large number if needed
845
+ except AttributeError:
846
+ logging.warning(
847
+ "self.trainer not available yet in configure_optimizers. Using fallback step count for LR schedule."
848
+ )
849
+ num_training_steps = 1_000_000 # Adjust this large number if needed
850
+
851
+ # Set warmup steps (e.g., 5% of total steps)
852
+ num_warmup_steps = int(0.05 * num_training_steps)
853
+ logging.info(
854
+ f"LR Scheduler: Total steps ~{num_training_steps}, Warmup steps: {num_warmup_steps}"
855
+ )
856
+
857
+ scheduler = get_linear_schedule_with_warmup(
858
+ optimizer,
859
+ num_warmup_steps=num_warmup_steps,
860
+ num_training_steps=num_training_steps,
861
+ )
862
+
863
+ lr_scheduler_config = {
864
+ "scheduler": scheduler,
865
+ "interval": "step", # Call scheduler after each training step
866
+ "frequency": 1,
867
+ "name": "linear_warmup_decay_lr", # Optional: Name for logging
868
+ }
869
+ logging.info("Using Linear Warmup/Decay LR Scheduler.")
870
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
871
+
872
+ except ImportError:
873
+ logging.warning(
874
+ "'transformers' library not found. Cannot create linear warmup scheduler. Using constant LR."
875
+ )
876
+ return optimizer
877
+ except Exception as e:
878
+ logging.error(
879
+ f"Error setting up LR scheduler: {e}. Using constant LR.", exc_info=True
880
+ )
881
+ return optimizer
882
+
883
+
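For reference, the exact step count described in the fallback comments above could be computed outside the module, where the dataset size is known. A minimal sketch, assuming the number of training samples, the effective batch size, and the epoch count are available at the call site (the names below are illustrative and not defined in this file):

import math

def estimate_total_steps(num_train_samples: int, effective_batch_size: int, num_epochs: int) -> int:
    # One optimizer step per effective batch; partial batches round up.
    steps_per_epoch = math.ceil(num_train_samples / effective_batch_size)
    return steps_per_epoch * num_epochs

# Example: 1_000_000 samples, effective batch size 256, 10 epochs -> 39_070 steps.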
884
+ # --- 6. Inference (Translation) ---
885
+ # These functions take the LightningModule instance and use its underlying transformer.
886
+
887
+
888
+ def greedy_decode(
889
+ model: pl.LightningModule, # Takes the LightningModule
890
+ src: torch.Tensor,
891
+ src_padding_mask: torch.Tensor,
892
+ max_len: int,
893
+ sos_idx: int,
894
+ eos_idx: int,
895
+ device: torch.device,
896
+ ) -> torch.Tensor:
897
+ """Performs greedy decoding using the LightningModule's model."""
898
+ # model.eval() # Lightning handles eval mode during inference/testing
899
+ transformer_model = model.model # Access the underlying Seq2SeqTransformer
900
+
901
+ try:
902
+ with torch.no_grad():
903
+ # Use the model's encode/decode methods
904
+ memory = transformer_model.encode(
905
+ src, src_padding_mask
906
+ ) # [1, src_len, emb_size]
907
+ memory = memory.to(device)
908
+ # Ensure memory_key_padding_mask is also on the correct device for decode
909
+ memory_key_padding_mask = src_padding_mask.to(memory.device) # [1, src_len]
910
+
911
+ ys = (
912
+ torch.ones(1, 1).fill_(sos_idx).type(torch.long).to(device)
913
+ ) # [1, 1] (Batch size 1)
914
+
915
+ for i in range(max_len - 1):
916
+ tgt_seq_len = ys.shape[1]
917
+ # Create masks for the current decoded sequence length
918
+ tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
919
+ device
920
+ ) # [curr_len, curr_len]
921
+ # No padding in target during greedy decode yet
922
+ tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool).to(
923
+ device
924
+ ) # [1, curr_len]
925
+
926
+ # Use the model's decode method
927
+ out = transformer_model.decode(
928
+ ys, memory, tgt_mask, tgt_padding_mask, memory_key_padding_mask
929
+ )
930
+ # out: [1, curr_len, emb_size]
931
+
932
+ # Get the logits for the last token generated
933
+ last_token_logits = transformer_model.generator(
934
+ out[:, -1, :]
935
+ ) # [1, tgt_vocab_size]
936
+ prob = last_token_logits # Use logits directly for argmax
937
+ _, next_word = torch.max(prob, dim=1)
938
+ next_word = next_word.item()
939
+
940
+ # Append the predicted token ID
941
+ ys = torch.cat(
942
+ [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1
943
+ )
944
+
945
+ # Stop if EOS token is generated
946
+ if next_word == eos_idx:
947
+ break
948
+ # Return the generated sequence, excluding the initial SOS token
949
+ return ys[:, 1:]
950
+
951
+ except RuntimeError as e:
952
+ logging.error(f"Runtime error during greedy decode: {e}")
953
+ if "CUDA out of memory" in str(e):
954
+ gc.collect()
955
+ torch.cuda.empty_cache()
956
+ # Return an empty tensor on error
957
+ return torch.tensor([[]], dtype=torch.long, device=device)
958
+
959
+
960
+ def translate(
961
+ model: pl.LightningModule, # Takes the LightningModule
962
+ src_sentence: str,
963
+ smiles_tokenizer,
964
+ iupac_tokenizer,
965
+ device: torch.device,
966
+ max_len: int,
967
+ sos_idx: int,
968
+ eos_idx: int,
969
+ pad_idx: int,
970
+ ) -> str:
971
+ """Translates a single SMILES string using the LightningModule."""
972
+ model.eval() # Ensure model is in eval mode for inference
973
+
974
+ try:
975
+ src_encoded = smiles_tokenizer.encode(src_sentence)
976
+ if not src_encoded or len(src_encoded.ids) == 0:
977
+ logging.warning(f"Encoding failed for SMILES: {src_sentence}")
978
+ return "[Encoding Error]"
979
+ # Truncate source sequence if needed before creating tensor
980
+ src_ids = src_encoded.ids[:max_len]
981
+ if not src_ids:
982
+ logging.warning(
983
+ f"Source sequence empty after truncation for SMILES: {src_sentence}"
984
+ )
985
+ return "[Encoding Error - Empty Src]"
986
+
987
+ except Exception as e:
988
+ logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
989
+ return "[Encoding Error]"
990
+
991
+ # Create tensor and move to device
992
+ src = (
993
+ torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
994
+ ) # Add batch dimension
995
+ # Create padding mask (boolean, True where padded)
996
+ # For single-sentence inference the source tensor is built without padding, so this
997
+ # mask is normally all False; it is created anyway because the model's encode/decode
998
+ # methods expect a key-padding mask of shape [1, src_len].
999
+ src_padding_mask = src == pad_idx # [1, src_len]
1000
+
1001
+ # Perform greedy decoding
1002
+ tgt_tokens_tensor = greedy_decode(
1003
+ model=model, # Pass the LightningModule
1004
+ src=src,
1005
+ src_padding_mask=src_padding_mask,
1006
+ max_len=max_len, # Use the configured max_len for generation limit
1007
+ sos_idx=sos_idx,
1008
+ eos_idx=eos_idx,
1009
+ device=device,
1010
+ )
1011
+
1012
+ # Decode the generated token IDs
1013
+ if tgt_tokens_tensor.numel() > 0:
1014
+ tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
1015
+ try:
1016
+ # Decode using the target tokenizer, skipping special tokens like <pad>, <sos>, <eos>
1017
+ translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
1018
+ return translation
1019
+ except Exception as e:
1020
+ logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
1021
+ return "[Decoding Error]"
1022
+ else:
1023
+ # Log if decoding returned an empty tensor (might happen on error in greedy_decode)
1024
+ # logging.warning(f"Greedy decode returned empty tensor for SMILES: {src_sentence}")
1025
+ return "[Decoding Error - Empty Output]"
1026
+
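A minimal usage sketch for translate, assuming a trained SmilesIupacLitModule and the two trained tokenizers have already been loaded (the variable names and the max_len value below are placeholders; the script in section 7 wires this up with its own objects):

# Illustrative only; lit_model, smiles_tok and iupac_tok are assumed to exist.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
iupac_name = translate(
    model=lit_model,
    src_sentence="CCO",  # ethanol
    smiles_tokenizer=smiles_tok,
    iupac_tokenizer=iupac_tok,
    device=device,
    max_len=256,  # placeholder; the script uses hparams["max_len"]
    sos_idx=SOS_IDX,
    eos_idx=EOS_IDX,
    pad_idx=PAD_IDX,
)
print(iupac_name)  # a well-trained model is expected to print something like "ethanol"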
1027
+
1028
+ # --- 7. Main Execution Script (Minor updates for clarity) ---
1029
+ if __name__ == "__main__":
1030
+ pl.seed_everything(RANDOM_SEED, workers=True) # Seed everything for reproducibility
1031
+
1032
+ # --- Create Checkpoint Directory ---
1033
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
1034
+
1035
+ # --- Load Data from CSV and Split ---
1036
+ # (Keep this data preparation step outside the Lightning Module)
1037
+ logging.info(f"Loading and splitting data from {INPUT_CSV_FILE}...")
1038
+ # (Re-using the data loading and splitting logic from the original script)
1039
+ try:
1040
+ # Load with dtype specification for potentially large files
1041
+ df = pd.read_csv(INPUT_CSV_FILE, dtype={"SMILES": str, "Systematic": str})
1042
+ logging.info(f"Initial rows loaded: {len(df)}")
1043
+ if "SMILES" not in df.columns:
1044
+ raise ValueError("CSV must contain 'SMILES' column.")
1045
+ if "Systematic" not in df.columns:
1046
+ raise ValueError("CSV must contain 'Systematic' (IUPAC name) column.")
1047
+ df.rename(columns={"Systematic": "IUPAC"}, inplace=True)
1048
+
1049
+ initial_rows = len(df)
1050
+ df.dropna(subset=["SMILES", "IUPAC"], inplace=True)
1051
+ rows_after_na = len(df)
1052
+ if initial_rows > rows_after_na:
1053
+ logging.info(
1054
+ f"Dropped {initial_rows - rows_after_na} rows with missing values."
1055
+ )
1056
+ # Strip whitespace and filter empty strings more efficiently
1057
+ df = df[df["SMILES"].str.strip().astype(bool)]
1058
+ df = df[df["IUPAC"].str.strip().astype(bool)]
1059
+ df["SMILES"] = df["SMILES"].str.strip()
1060
+ df["IUPAC"] = df["IUPAC"].str.strip()
1061
+ rows_after_empty = len(df)
1062
+ if rows_after_na > rows_after_empty:
1063
+ logging.info(
1064
+ f"Dropped {rows_after_na - rows_after_empty} rows with empty strings after stripping."
1065
+ )
1066
+
1067
+ smiles_data = df["SMILES"].tolist()
1068
+ iupac_data = df["IUPAC"].tolist()
1069
+ logging.info(f"Loaded {len(smiles_data)} valid pairs from CSV.")
1070
+ del df
1071
+ gc.collect() # Free memory
1072
+
1073
+ if len(smiles_data) < 10:
1074
+ raise ValueError(
1075
+ f"Not enough valid data ({len(smiles_data)}) for split. Need at least 10."
1076
+ )
1077
+
1078
+ train_smi, val_smi, train_iupac, val_iupac = train_test_split(
1079
+ smiles_data,
1080
+ iupac_data,
1081
+ test_size=VALIDATION_SPLIT,
1082
+ random_state=RANDOM_SEED,
1083
+ )
1084
+ logging.info(f"Split: {len(train_smi)} train, {len(val_smi)} validation.")
1085
+ del smiles_data, iupac_data
1086
+ gc.collect() # Free memory
1087
+
1088
+ logging.info("Writing split data to files...")
1089
+ with open(TRAIN_SMILES_FILE, "w", encoding="utf-8") as f:
1090
+ f.write("\n".join(train_smi))
1091
+ with open(TRAIN_IUPAC_FILE, "w", encoding="utf-8") as f:
1092
+ f.write("\n".join(train_iupac))
1093
+ with open(VAL_SMILES_FILE, "w", encoding="utf-8") as f:
1094
+ f.write("\n".join(val_smi))
1095
+ with open(VAL_IUPAC_FILE, "w", encoding="utf-8") as f:
1096
+ f.write("\n".join(val_iupac))
1097
+ logging.info(
1098
+ f"Split files written: {TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}, {VAL_SMILES_FILE}, {VAL_IUPAC_FILE}"
1099
+ )
1100
+ del train_smi, val_smi, train_iupac, val_iupac
1101
+ gc.collect() # Free memory
1102
+
1103
+ except FileNotFoundError:
1104
+ logging.error(f"Fatal error: Input CSV file not found at {INPUT_CSV_FILE}")
1105
+ exit(1)
1106
+ except ValueError as ve:
1107
+ logging.error(f"Fatal error during data preparation: {ve}")
1108
+ exit(1)
1109
+ except Exception as e:
1110
+ logging.error(f"Fatal error during data preparation: {e}", exc_info=True)
1111
+ exit(1)
1112
+ # --- End Data Preparation ---
1113
+
1114
+ # --- Initialize Tokenizers ---
1115
+ logging.info("Initializing Tokenizers...")
1116
+ # Ensure training files exist before attempting to train tokenizers
1117
+ if not os.path.exists(TRAIN_SMILES_FILE) or not os.path.exists(TRAIN_IUPAC_FILE):
1118
+ logging.error(
1119
+ f"Training files ({TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}) not found. Cannot train tokenizers."
1120
+ )
1121
+ exit(1)
1122
+
1123
+ smiles_tokenizer = get_smiles_tokenizer(
1124
+ train_files=[TRAIN_SMILES_FILE],
1125
+ vocab_size=SRC_VOCAB_SIZE_ESTIMATE,
1126
+ tokenizer_path=SMILES_TOKENIZER_FILE,
1127
+ )
1128
+ iupac_tokenizer = get_iupac_tokenizer(
1129
+ train_files=[TRAIN_IUPAC_FILE],
1130
+ vocab_size=TGT_VOCAB_SIZE_ESTIMATE,
1131
+ tokenizer_path=IUPAC_TOKENIZER_FILE,
1132
+ )
1133
+
1134
+ ACTUAL_SRC_VOCAB_SIZE = smiles_tokenizer.get_vocab_size()
1135
+ ACTUAL_TGT_VOCAB_SIZE = iupac_tokenizer.get_vocab_size()
1136
+ logging.info(f"Actual SMILES Vocab Size: {ACTUAL_SRC_VOCAB_SIZE}")
1137
+ logging.info(f"Actual IUPAC Vocab Size: {ACTUAL_TGT_VOCAB_SIZE}")
1138
+ # Update hparams with actual sizes (will be logged by WandbLogger)
1139
+ hparams["actual_src_vocab_size"] = ACTUAL_SRC_VOCAB_SIZE
1140
+ hparams["actual_tgt_vocab_size"] = ACTUAL_TGT_VOCAB_SIZE
1141
+
1142
+ # --- Setup WandB Logger ---
1143
+ # Ensure WANDB_ENTITY is set if required, otherwise it uses default
1144
+ if WANDB_ENTITY is None:
1145
+ logging.warning(
1146
+ "WANDB_ENTITY not set. WandB will log to your default entity. Set WANDB_ENTITY='your_username_or_team' to specify."
1147
+ )
1148
+
1149
+ wandb_logger = WandbLogger(
1150
+ project=WANDB_PROJECT,
1151
+ entity=WANDB_ENTITY, # Set your entity here or leave as None
1152
+ name=WANDB_RUN_NAME,
1153
+ config=hparams, # Log hyperparameters defined above
1154
+ # log_model='all' # Log model checkpoints to WandB (can consume significant storage)
1155
+ # log_model=True # Log best model checkpoint based on monitor
1156
+ )
1157
+
1158
+ # --- Initialize Datasets and DataLoaders ---
1159
+ logging.info("Creating Datasets and DataLoaders...")
1160
+ try:
1161
+ train_dataset = SmilesIupacDataset(TRAIN_SMILES_FILE, TRAIN_IUPAC_FILE)
1162
+ val_dataset = SmilesIupacDataset(VAL_SMILES_FILE, VAL_IUPAC_FILE)
1163
+ if len(train_dataset) == 0 or len(val_dataset) == 0:
1164
+ logging.error(
1165
+ "Training or validation dataset is empty. Check data splitting and file content."
1166
+ )
1167
+ exit(1)
1168
+ except Exception as e:
1169
+ logging.error(f"Error creating Datasets: {e}", exc_info=True)
1170
+ exit(1)
1171
+
1172
+ # Create partial function for collate_fn to pass tokenizers and params
1173
+ def collate_fn_partial(batch):
1174
+ return collate_fn(
1175
+ batch,
1176
+ smiles_tokenizer,
1177
+ iupac_tokenizer,
1178
+ PAD_IDX,
1179
+ SOS_IDX,
1180
+ EOS_IDX,
1181
+ hparams["max_len"],
1182
+ )
1183
+
1184
+ # Use persistent_workers=True if num_workers > 0 for efficiency, especially with DDP
1185
+ persistent_workers = NUM_WORKERS > 0 and STRATEGY == "ddp" # Recommended for DDP
1186
+
1187
+ train_dataloader = DataLoader(
1188
+ train_dataset,
1189
+ batch_size=BATCH_SIZE_PER_GPU,
1190
+ shuffle=True,
1191
+ collate_fn=collate_fn_partial,
1192
+ num_workers=NUM_WORKERS,
1193
+ pin_memory=True,
1194
+ persistent_workers=persistent_workers,
1195
+ drop_last=True,
1196
+ ) # Drop last incomplete batch in training for DDP consistency
1197
+ val_dataloader = DataLoader(
1198
+ val_dataset,
1199
+ batch_size=BATCH_SIZE_PER_GPU, # Use same batch size for validation
1200
+ shuffle=False,
1201
+ collate_fn=collate_fn_partial,
1202
+ num_workers=NUM_WORKERS,
1203
+ pin_memory=True,
1204
+ persistent_workers=persistent_workers,
1205
+ drop_last=False,
1206
+ ) # Keep all validation batches
1207
+
1208
+ # --- Initialize Model ---
1209
+ logging.info("Initializing Lightning Module...")
1210
+ # Pass hparams dictionary directly, PTL handles it via save_hyperparameters
1211
+ model = SmilesIupacLitModule(
1212
+ src_vocab_size=ACTUAL_SRC_VOCAB_SIZE,
1213
+ tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE,
1214
+ hparams_dict=hparams,
1215
+ )
1216
+
1217
+ # Optional: Log model topology to WandB (do this after model init, before training)
1218
+ # Note: watch can sometimes slow down training start, especially with large models
1219
+ # wandb_logger.watch(model, log='all', log_freq=100) # Log gradients and parameters
1220
+
1221
+ # --- Define Callbacks ---
1222
+ checkpoint_callback = ModelCheckpoint(
1223
+ dirpath=CHECKPOINT_DIR,
1224
+ filename=BEST_MODEL_FILENAME + "-{epoch:02d}-{val_loss:.4f}",
1225
+ save_top_k=1, # Save only the best model
1226
+ verbose=True,
1227
+ monitor="val_loss", # Monitor validation loss
1228
+ mode="min", # Save the model with the minimum validation loss
1229
+ save_last=True, # Optionally save the last checkpoint as well
1230
+ )
1231
+ early_stopping_callback = EarlyStopping(
1232
+ monitor="val_loss",
1233
+ patience=PATIENCE, # Number of epochs with no improvement after which training will be stopped
1234
+ verbose=True,
1235
+ mode="min",
1236
+ )
1237
+
1238
+ # --- Initialize PyTorch Lightning Trainer ---
1239
+ logging.info(
1240
+ f"Initializing PyTorch Lightning Trainer (GPUs={DEVICES}, Strategy='{STRATEGY}', Precision='{PRECISION}')..."
1241
+ )
1242
+ trainer = pl.Trainer(
1243
+ accelerator=ACCELERATOR,
1244
+ devices=DEVICES,
1245
+ strategy=STRATEGY,
1246
+ precision=PRECISION,
1247
+ max_epochs=NUM_EPOCHS,
1248
+ logger=wandb_logger, # Use WandbLogger
1249
+ callbacks=[checkpoint_callback, early_stopping_callback],
1250
+ gradient_clip_val=GRAD_CLIP_NORM, # Gradient clipping
1251
+ accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES, # Gradient accumulation
1252
+ log_every_n_steps=50, # How often to log metrics (steps across all GPUs)
1253
+ # deterministic=True, # Might slow down training, use for debugging reproducibility if needed
1254
+ # profiler="simple", # Optional: Add profiler ("simple", "advanced", "pytorch") for performance analysis
1255
+ # Checkpointing behavior is controlled by ModelCheckpoint callback
1256
+ # enable_checkpointing=True, # Default is True if callbacks has ModelCheckpoint
1257
+ )
1258
+
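One note on the number reported in the next log line: with gradient accumulation and data parallelism, the effective batch size is presumably the per-GPU batch size times the accumulation factor times the number of devices. Illustrative arithmetic only, since the hparams definition lives earlier in the file:

# e.g. BATCH_SIZE_PER_GPU = 32, ACCUMULATE_GRAD_BATCHES = 4, 2 GPUs
#   -> 32 * 4 * 2 = 256 samples contribute to each optimizer step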
1259
+ # --- Start Training ---
1260
+ logging.info(
1261
+ f"Starting training with Effective Batch Size: {hparams['effective_batch_size']}..."
1262
+ )
1263
+ start_time = time.time()
1264
+ try:
1265
+ trainer.fit(model, train_dataloader, val_dataloader)
1266
+ training_duration = time.time() - start_time
1267
+ logging.info(
1268
+ f"Training finished in {training_duration / 3600:.2f} hours ({training_duration:.2f} seconds)."
1269
+ )
1270
+
1271
+ # Log best model path and score
1272
+ best_path = checkpoint_callback.best_model_path
1273
+ best_score = checkpoint_callback.best_model_score # This is a tensor, get value
1274
+ if best_score is not None:
1275
+ logging.info(
1276
+ f"Best model checkpoint saved at: {best_path} with val_loss: {best_score.item():.4f}"
1277
+ )
1278
+ # Log best score to wandb summary
1279
+ wandb_logger.experiment.summary["best_val_loss"] = best_score.item()
1280
+ wandb_logger.experiment.summary["best_model_path"] = best_path
1281
+ else:
1282
+ logging.warning(
1283
+ "Could not retrieve best model score from checkpoint callback."
1284
+ )
1285
+
1286
+ except Exception as e:
1287
+ logging.error(f"Fatal error during training: {e}", exc_info=True)
1288
+ # Ensure wandb run is finished even on error
1289
+ if wandb.run is not None:
1290
+ wandb.finish(exit_code=1) # Mark as failed run
1291
+ exit(1)
1292
+
1293
+ # --- Load Best Model for Final Translation Examples ---
1294
+ best_model_path_to_load = checkpoint_callback.best_model_path
1295
+ logging.info(
1296
+ f"\nLoading best model from {best_model_path_to_load} for translation examples..."
1297
+ )
1298
+ final_model = None
1299
+ if best_model_path_to_load and os.path.exists(best_model_path_to_load):
1300
+ try:
1301
+ # Load the model using the Lightning checkpoint loading mechanism
1302
+ # Pass hparams_dict again in case it's needed and not perfectly saved/loaded
1303
+ final_model = SmilesIupacLitModule.load_from_checkpoint(
1304
+ best_model_path_to_load,
1305
+ # Provide necessary args again if they weren't saved in hparams properly
1306
+ # (though save_hyperparameters should handle this)
1307
+ src_vocab_size=ACTUAL_SRC_VOCAB_SIZE,
1308
+ tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE,
1309
+ hparams_dict=hparams, # Pass the original hparams
1310
+ )
1311
+ # Determine device for inference (use the first GPU if available)
1312
+ inference_device = torch.device(
1313
+ f"{ACCELERATOR}:0"
1314
+ if ACCELERATOR == "gpu" and torch.cuda.is_available()
1315
+ else "cpu"
1316
+ )
1317
+ final_model = final_model.to(inference_device)
1318
+ final_model.eval() # Set to evaluation mode
1319
+ final_model.freeze() # Freeze weights for inference
1320
+ logging.info(
1321
+ f"Best model loaded successfully to {inference_device} for final translation."
1322
+ )
1323
+ except Exception as e:
1324
+ logging.error(
1325
+ f"Error loading saved model from {best_model_path_to_load}: {e}",
1326
+ exc_info=True,
1327
+ )
1328
+ final_model = None # Ensure final_model is None if loading fails
1329
+ else:
1330
+ logging.error(
1331
+ f"Error: Best model checkpoint path not found or invalid: '{best_model_path_to_load}'. Cannot perform final translation."
1332
+ )
1333
+
1334
+ # --- Example Translation (using some validation samples) ---
1335
+ if final_model:
1336
+ logging.info("\n--- Example Translations (using validation data) ---")
1337
+ num_examples = 20 # Show more examples
1338
+ try:
1339
+ # Load validation samples directly from the files
1340
+ val_smi_examples = []
1341
+ val_iupac_examples = []
1342
+ if os.path.exists(VAL_SMILES_FILE) and os.path.exists(VAL_IUPAC_FILE):
1343
+ with (
1344
+ open(VAL_SMILES_FILE, "r", encoding="utf-8") as f_smi,
1345
+ open(VAL_IUPAC_FILE, "r", encoding="utf-8") as f_iupac,
1346
+ ):
1347
+ for i, (smi_line, iupac_line) in enumerate(zip(f_smi, f_iupac)):
1348
+ if i >= num_examples:
1349
+ break
1350
+ val_smi_examples.append(smi_line.strip())
1351
+ val_iupac_examples.append(iupac_line.strip())
1352
+ else:
1353
+ logging.warning(
1354
+ f"Validation files ({VAL_SMILES_FILE}, {VAL_IUPAC_FILE}) not found. Cannot show examples."
1355
+ )
1356
+
1357
+ if len(val_smi_examples) > 0:
1358
+ print("\n" + "=" * 40)
1359
+ print(
1360
+ f"Example Translations (First {len(val_smi_examples)} Validation Samples)"
1361
+ )
1362
+ print("=" * 40)
1363
+ # Use the device the model was loaded onto
1364
+ inference_device = next(final_model.parameters()).device
1365
+ translation_examples = [] # For potential logging to wandb
1366
+ for i in range(len(val_smi_examples)):
1367
+ smi = val_smi_examples[i]
1368
+ true_iupac = val_iupac_examples[i]
1369
+ predicted_iupac = translate(
1370
+ model=final_model, # Use the loaded best model
1371
+ src_sentence=smi,
1372
+ smiles_tokenizer=smiles_tokenizer,
1373
+ iupac_tokenizer=iupac_tokenizer,
1374
+ device=inference_device, # Use model's device
1375
+ max_len=hparams["max_len"],
1376
+ sos_idx=SOS_IDX,
1377
+ eos_idx=EOS_IDX,
1378
+ pad_idx=PAD_IDX,
1379
+ )
1380
+ print(f"\nExample {i + 1}:")
1381
+ print(f" SMILES: {smi}")
1382
+ print(f" True IUPAC: {true_iupac}")
1383
+ print(f" Predicted IUPAC: {predicted_iupac}")
1384
+ print("-" * 30)
1385
+ # Prepare data for wandb table
1386
+ translation_examples.append([smi, true_iupac, predicted_iupac])
1387
+
1388
+ print("=" * 40 + "\n")
1389
+
1390
+ # Log examples to a WandB Table
1391
+ try:
1392
+ columns = ["SMILES", "True IUPAC", "Predicted IUPAC"]
1393
+ wandb_table = wandb.Table(
1394
+ data=translation_examples, columns=columns
1395
+ )
1396
+ wandb_logger.experiment.log(
1397
+ {"validation_translations": wandb_table}
1398
+ )
1399
+ logging.info("Logged translation examples to WandB Table.")
1400
+ except Exception as wb_err:
1401
+ logging.error(
1402
+ f"Failed to log translation examples to WandB: {wb_err}"
1403
+ )
1404
+
1405
+ else:
1406
+ logging.warning("Could not load validation samples for examples.")
1407
+ except Exception as e:
1408
+ logging.error(f"Error during example translation phase: {e}", exc_info=True)
1409
+ else:
1410
+ logging.warning(
1411
+ "Skipping final translation examples as the best model could not be loaded."
1412
+ )
1413
+
1414
+ # --- Finish WandB Run ---
1415
+ if wandb.run is not None:
1416
+ wandb.finish()
1417
+ logging.info("WandB run finished.")
1418
+ else:
1419
+ logging.info("No active WandB run to finish.")
1420
+
1421
+ logging.info("Script finished.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch
2
+ pytorch_lightning
3
+ tokenizers
4
+ transformers
5
+ gradio
6
+ huggingface_hub
7
+ pandas
8
+ scikit-learn
9
+ wandb