transpolymer committed
Commit eea9e94 · verified · 1 Parent(s): dd5919f

Update prediction.py

Files changed (1)
  1. prediction.py +30 -26
prediction.py CHANGED
@@ -11,12 +11,12 @@ tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
 embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

 # Load saved scalers (for inverse_transform)
-scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib") # Scaler for Tensile Strength
-scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib") # Scaler for Ionization Energy
-scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib") # Scaler for Electron Affinity
-scaler_logp = joblib.load("scaler_LogP.joblib") # Scaler for LogP
-scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib") # Scaler for Refractive Index
-scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib") # Scaler for Molecular Weight
+scaler_tensile_strength = joblib.load("scaler_Tensile_strength_Mpa_.joblib")
+scaler_ionization_energy = joblib.load("scaler_Ionization_Energy_eV_.joblib")
+scaler_electron_affinity = joblib.load("scaler_Electron_Affinity_eV_.joblib")
+scaler_logp = joblib.load("scaler_LogP.joblib")
+scaler_refractive_index = joblib.load("scaler_Refractive_Index.joblib")
+scaler_molecular_weight = joblib.load("scaler_Molecular_Weight_g_mol_.joblib")

 # Descriptor function with exact order from training
 def compute_descriptors(smiles):
@@ -39,27 +39,33 @@ def compute_descriptors(smiles):

 # Define your model class exactly like in training
 class TransformerRegressor(nn.Module):
-    def __init__(self, input_dim=768, descriptor_dim=10, d_model=768, nhead=4, num_layers=2, num_targets=6):
+    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
         super().__init__()
-        self.descriptor_proj = nn.Linear(descriptor_dim, d_model)
-        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
-        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-        self.regressor = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(2 * d_model, 256),
+        self.feat_proj = nn.Linear(input_dim, hidden_dim)
+        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+        self.regression_head = nn.Sequential(
+            nn.Linear(hidden_dim, 128),
             nn.ReLU(),
-            nn.Linear(256, num_targets)
+            nn.Linear(128, 64),
+            nn.ReLU(),
+            nn.Linear(64, output_dim)
         )

-    def forward(self, embedding, descriptors):
-        desc_proj = self.descriptor_proj(descriptors).unsqueeze(1) # (B, 1, d_model)
-        stacked = torch.cat([embedding.unsqueeze(1), desc_proj], dim=1) # (B, 2, d_model)
-        encoded = self.transformer(stacked) # (B, 2, d_model)
-        output = self.regressor(encoded)
-        return output
+    def forward(self, x):
+        x = self.feat_proj(x)
+        x = self.transformer_encoder(x)
+        x = x.mean(dim=1)
+        return self.regression_head(x)
+
+# Set model hyperparameters (must match training config)
+input_dim = 768 # ChemBERTa embedding size
+hidden_dim = 256
+num_layers = 2
+output_dim = 6 # Number of properties predicted

 # Load model
-model = TransformerRegressor()
+model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
 model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device("cpu")))
 model.eval()

@@ -67,19 +73,18 @@ model.eval()
 def predict_properties(smiles):
     try:
         descriptors = compute_descriptors(smiles)
-        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0) # (1, 10)
+        descriptors_tensor = torch.tensor(descriptors).unsqueeze(0)

         # Get embedding
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
-            emb = outputs.last_hidden_state[:, 0, :] # [CLS] token, shape (1, 768)
+            emb = outputs.last_hidden_state[:, 0, :] # CLS token output (1, 768)

         # Forward pass
         with torch.no_grad():
-            preds = model(emb, descriptors_tensor)
+            preds = model(emb)

-        # Inverse transform predictions using respective scalers
         preds_np = preds.numpy()
         preds_rescaled = np.concatenate([
             scaler_tensile_strength.inverse_transform(preds_np[:, [0]]),
@@ -90,7 +95,6 @@ def predict_properties(smiles):
             scaler_molecular_weight.inverse_transform(preds_np[:, [5]])
         ], axis=1)

-        # Round and format
         keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity", "logP", "Refractive Index", "Molecular Weight"]
         results = dict(zip(keys, preds_rescaled.flatten().round(4)))

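Note on the input that the new forward(x) consumes: predict_properties passes the first-token ([CLS]) vector of ChemBERTa's last hidden state, a (1, 768) tensor per SMILES string. A minimal standalone sketch of that step follows; it assumes only that torch and transformers are installed, and the example SMILES is purely illustrative.

# Standalone illustration of the embedding step used in predict_properties.
# Downloads the same ChemBERTa checkpoint that prediction.py loads.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
embedding_model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"  # illustrative SMILES (aspirin)
inputs = tokenizer(smiles, return_tensors="pt")
with torch.no_grad():
    outputs = embedding_model(**inputs)
emb = outputs.last_hidden_state[:, 0, :]  # first-token ([CLS]) vector
print(emb.shape)  # torch.Size([1, 768]), matching input_dim = 768 above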
 
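Each target is predicted in the scaled space used during training, and predict_properties maps every output column back to physical units with its own saved scaler before building the result dict. The sketch below shows that per-column inverse_transform pattern in isolation; the StandardScaler objects and the numbers are throwaway stand-ins for the project's .joblib scalers, not the real ones.

# Minimal sketch of the per-column rescaling done in predict_properties.
# The scalers here are dummies fitted on random data, standing in for the saved .joblib files.
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# One scaler per target, each fitted on its own (n_samples, 1) column, as during training.
scalers = [StandardScaler().fit(rng.normal(loc=10 * i, scale=3, size=(100, 1)))
           for i in range(6)]

preds_np = rng.normal(size=(1, 6))  # stand-in for the model's scaled predictions
preds_rescaled = np.concatenate(
    [scalers[i].inverse_transform(preds_np[:, [i]]) for i in range(6)],
    axis=1,
)

keys = ["Tensile Strength", "Ionization Energy", "Electron Affinity",
        "logP", "Refractive Index", "Molecular Weight"]
results = dict(zip(keys, preds_rescaled.flatten().round(4)))
print(results)  # one rescaled value per property

Keeping one scaler per property (rather than a single six-column scaler) is what allows each output column to be inverse-transformed independently, exactly as the concatenation in predict_properties does.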
 
 
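Assuming transformer_model.pt and the six scaler .joblib files referenced at the top of prediction.py are present in the working directory, the intended entry point is predict_properties. A hypothetical call might look like the following; the SMILES string is illustrative and the concrete values depend on the trained weights, so only the shape of the result is sketched here.

# Hypothetical usage; requires the model checkpoint and the scaler .joblib files locally.
from prediction import predict_properties

result = predict_properties("CC(=O)OC1=CC=CC=C1C(=O)O")  # illustrative SMILES (aspirin)
print(result)
# Expected form of the result: a dict keyed by
# "Tensile Strength", "Ionization Energy", "Electron Affinity",
# "logP", "Refractive Index", "Molecular Weight",
# with each value rounded to 4 decimal places.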