transpolymer committed
Commit 15f5470 · verified · 1 Parent(s): 9202635

Update prediction.py

Files changed (1):
  prediction.py +32 -29
prediction.py CHANGED
@@ -6,26 +6,24 @@ from rdkit.Chem import Descriptors
 from transformers import AutoTokenizer, AutoModel
 import numpy as np
 
-# Load tokenizer and embedding model
+# Load tokenizer and model for embeddings
 tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 
-# Load individual scalers
-target_keys = [
-    "Tensile_strength(Mpa)",
-    "Ionization_Energy(eV)",
-    "Electron_Affinity(eV)",
-    "LogP",
-    "Refractive_Index",
-    "Molecular_Weight(g/mol)"
-]
-scalers = [joblib.load(f"scaler_{key.replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('__', '_')}.joblib") for key in target_keys]
+# Load saved scalers (for inverse_transform)
+scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")  # Scaler for Tensile Strength
+scaler_ionization_energy = joblib.load("scaler_lonization_Energy_eV_.joblib")  # Scaler for Ionization Energy
+scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")  # Scaler for Electron Affinity
+scaler_logp = joblib.load("scaler_LogP.joblib")  # Scaler for LogP
+scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")  # Scaler for Refractive Index
+scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")  # Scaler for Molecular Weight
 
-# Descriptor function (must match training order)
+# Descriptor function with exact order from training
 def compute_descriptors(smiles):
     mol = Chem.MolFromSmiles(smiles)
     if mol is None:
-        raise ValueError("Invalid SMILES string.")
+        raise ValueError("Invalid SMILES")
+
     return np.array([
         Descriptors.MolWt(mol),
         Descriptors.MolLogP(mol),
@@ -39,7 +37,7 @@ def compute_descriptors(smiles):
         Descriptors.MolMR(mol)
     ], dtype=np.float32)
 
-# Model class must match training
+# Define your model class exactly like in training
 class TransformerRegressor(nn.Module):
     def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
         super().__init__()
@@ -57,9 +55,10 @@ class TransformerRegressor(nn.Module):
         desc_proj = self.descriptor_proj(descriptors).unsqueeze(1)  # (B, 1, d_model)
         stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1)  # (B, 2, d_model)
         encoded = self.transformer(stacked)  # (B, 2, d_model)
-        return self.regressor(encoded)
+        output = self.regressor(encoded)
+        return output
 
-# Load trained model
+# Load model
 model = TransformerRegressor()
 model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
 model.eval()
@@ -67,29 +66,33 @@ model.eval()
 # Main prediction function
 def predict_properties(smiles):
     try:
-        # Compute descriptors
         descriptors = compute_descriptors(smiles)
         descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)  # (1, 10)
 
-        # Get embedding from ChemBERTa
+        # Get embedding
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
-        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)
+        emb = outputs.last_hidden_state[:, 0, :]  # [CLS] token, shape (1, 768)
 
-        # Predict
+        # Forward pass
         with torch.no_grad():
-            preds = model(embedding, descriptors_tensor)
+            preds = model(emb, descriptors_tensor)
 
-        # Inverse transform each prediction
-        preds_np = preds.numpy().flatten()
-        preds_rescaled = [
-            scalers[i].inverse_transform([[preds_np[i]]])[0][0] for i in range(len(scalers))
-        ]
+        # Inverse transform predictions using respective scalers
+        preds_np = preds.numpy()
+        preds_rescaled = np.concatenate([
+            scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
+            scaler_ionization_energy.inverse_transform(preds_np[:, [1]]),
+            scaler_electron_affinity.inverse_transform(preds_np[:, [2]]),
+            scaler_logp.inverse_transform(preds_np[:, [3]]),
+            scaler_refractive_index.inverse_transform(preds_np[:, [4]]),
+            scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
+        ], axis=1)
 
-        # Prepare results
-        readable_keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
-        results = dict(zip(readable_keys, np.round(preds_rescaled, 4)))
+        # Round and format
+        keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
+        results = dict(zip(keys, preds_rescaled.flatten().round(4)))
 
         return results
 
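For context, here is a minimal sketch of how the six per-target scalers loaded in this commit could have been produced on the training side. The StandardScaler choice, the target array, and the column order are assumptions; only the filenames (including the "lonization" spelling, which the load calls must match exactly) come from the diff itself.

import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical training targets; stands in for the real (n_samples, 6) array.
y_train = np.random.rand(128, 6).astype(np.float32)

# Filenames copied verbatim from prediction.py; the column order is assumed
# to match the key order used elsewhere in the file.
scaler_paths = [
    "scaler_Tensile_strength_Mpa_.joblib",
    "scaler_lonization_Energy_eV_.joblib",
    "scaler_Electron_Affinity_eV_.joblib",
    "scaler_LogP.joblib",
    "scaler_Refractive_Index.joblib",
    "scaler_Molecular_Weight_g_mol_.joblib",
]

for i, path in enumerate(scaler_paths):
    scaler = StandardScaler().fit(y_train[:, [i]])  # fit on a single (n, 1) column
    joblib.dump(scaler, path)

Fitting each scaler on an (n, 1) column is what makes the bracketed indexing in predict_properties (preds_np[:, [0]] rather than preds_np[:, 0]) necessary: sklearn's inverse_transform expects the same two-dimensional (n_samples, n_features) shape the scaler was fitted with.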
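And a quick smoke test of the updated entry point. This is a sketch assuming transformer_model.pt and the six scaler files sit in the working directory and that the model's forward pass yields a (1, 6) prediction; the SMILES string is just a sample input.

from prediction import predict_properties

results = predict_properties("CC(=O)OC1=CC=CC=C1C(=O)O")  # aspirin
print(results)
# A dict keyed by the six readable names, e.g.
# {'Tensile Strength': ..., 'Ionization Energy': ..., 'Electron Affinity': ...,
#  'logP': ..., 'Refractive Index': ..., 'Molecular Weight': ...}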