transpolymer committed
Commit b73e3e2 · verified · 1 Parent(s): eea9e94

Update prediction.py

Files changed (1): prediction.py (+41 -39)
prediction.py CHANGED
@@ -1,30 +1,32 @@
 import torch
 import torch.nn as nn
 import joblib
+import numpy as np
 from rdkit import Chem
 from rdkit.Chem import Descriptors
 from transformers import AutoTokenizer, AutoModel
-import numpy as np
 
-# Load tokenizer and model for embeddings
+# Load ChemBERTa tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 
-# Load saved scalers (for inverse_transform)
-scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
-scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib")
-scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
-scaler_logp = joblib.load("scaler_LogP.joblib")
-scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
-scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
+# Load saved scalers for inverse transformations
+scalers = {
+    "Tensile Strength": joblib.load("scaler_Tensile_strength_Mpa_.joblib"),
+    "Ionization Energy": joblib.load("scaler_Ionization_Energy_eV_.joblib"),
+    "Electron Affinity": joblib.load("scaler_Electron_Affinity_eV_.joblib"),
+    "logP": joblib.load("scaler_LogP.joblib"),
+    "Refractive Index": joblib.load("scaler_Refractive_Index.joblib"),
+    "Molecular Weight": joblib.load("scaler_Molecular_Weight_g_mol_.joblib")
+}
 
-# Descriptor function with exact order from training
-def compute_descriptors(smiles):
+# Descriptor calculation
+def compute_descriptors(smiles: str):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
-        raise ValueError("Invalid SMILES")
-
-    return np.array([
+        raise ValueError("Invalid SMILES string.")
+
+    descriptors = [
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),
         Descriptors.TPSA(mol),
@@ -35,9 +37,10 @@ def compute_descriptors(smiles):
         Descriptors.HeavyAtomCount(mol),
         Descriptors.RingCount(mol),
         Descriptors.MolMR(mol)
-    ], dtype=np.float32)
+    ]
+    return np.array(descriptors, dtype=np.float32)
 
-# Define your model class exactly like in training
+# Transformer regression model definition
 class TransformerRegressor(nn.Module):
     def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
         super().__init__()
@@ -55,49 +58,48 @@ class TransformerRegressor(nn.Module):
     def forward(self, x):
         x = self.feat_proj(x)
         x = self.transformer_encoder(x)
-        x = x.mean(dim=1)
+        x = x.mean(dim=1)  # Global average pooling
         return self.regression_head(x)
 
-# Set model hyperparameters (must match training config)
-input_dim = 768  # ChemBERTa embedding size
+# Model hyperparameters (must match training)
+input_dim = 768
 hidden_dim = 256
 num_layers = 2
-output_dim = 6  # Number of properties predicted
+output_dim = 6
 
-# Load model
+# Load trained model
+device = torch.device("cpu")
 model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
-model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
+model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
 model.eval()
 
-# Main prediction function
-def predict_properties(smiles):
+# Prediction function
+def predict_properties(smiles: str):
     try:
-        descriptors = compute_descriptors(smiles)
-        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)
+        # Validate SMILES and compute descriptors
+        _ = compute_descriptors(smiles)
 
-        # Get embedding
+        # ChemBERTa embedding (CLS token)
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
            outputs = embedding_model(**inputs)
-        emb = outputs.last_hidden_state[:, 0, :]  # CLS token output (1, 768)
+        embedding = outputs.last_hidden_state[:, 0, :]  # Shape: (1, 768)
 
-        # Forward pass
+        # Forward pass through model
        with torch.no_grad():
-            preds = model(emb)
+            preds = model(embedding)
 
        preds_np = preds.numpy()
+
+        # Inverse transform each property
+        keys = list(scalers.keys())
        preds_rescaled = np.concatenate([
-            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
-            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
-            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
-            scaler_logp.inverse_transform(preds_np[:, [3]]),
-            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
-            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
+            scalers[keys[i]].inverse_transform(preds_np[:, [i]])
+            for i in range(output_dim)
        ], axis=1)
 
-        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
-        results = dict(zip(keys, preds_rescaled.flatten().round(4)))
-
+        # Create dictionary of results
+        results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
        return results
 
    except Exception as e:
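
A minimal usage sketch of the updated predict_properties (not part of the commit): it assumes prediction.py, the six .joblib scalers, and transformer_model.pt are all present in the working directory, and the SMILES string is only an illustrative example.

    from prediction import predict_properties

    # "CCO" (ethanol) is an arbitrary test SMILES, not taken from this repo
    properties = predict_properties("CCO")
    print(properties)
    # On success this returns a dict keyed by the six scaler names defined above:
    # "Tensile Strength", "Ionization Energy", "Electron Affinity",
    # "logP", "Refractive Index", "Molecular Weight"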