transpolymer committed
Commit 3de6f45 · verified · 1 Parent(s): 8ef3b45

Update prediction.py

Files changed (1)
  1. prediction.py +30 -19
prediction.py CHANGED
@@ -40,54 +40,66 @@ def compute_descriptors(smiles: str):
     ]
     return np.array(descriptors, dtype=np.float32)
 
-# Transformer regression model definition
+# Transformer regression model definition (must match training)
 class TransformerRegressor(nn.Module):
-    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
+    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
         super().__init__()
-        self.feat_proj = nn.Linear(input_dim, hidden_dim)
-        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
+        self.feat_proj = nn.Linear(input_dim, embedding_dim)
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=8,
+            dim_feedforward=ff_dim,
+            dropout=0.1,
+            batch_first=True
+        )
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
         self.regression_head = nn.Sequential(
-            nn.Linear(hidden_dim, 128),
+            nn.Linear(embedding_dim, 256),
             nn.ReLU(),
-            nn.Linear(128, 64),
+            nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(64, output_dim)
+            nn.Linear(128, output_dim)
         )
 
     def forward(self, x):
         x = self.feat_proj(x)
         x = self.transformer_encoder(x)
-        x = x.mean(dim=1)  # Global average pooling
+        x = x.mean(dim=1)
         return self.regression_head(x)
 
 # Model hyperparameters (must match training)
-input_dim = 768  # Output size of ChemBERTa model
-hidden_dim = 256
+embedding_dim = 768
+descriptor_dim = 1290  # Length of the compute_descriptors vector (inferred from earlier shape-mismatch errors)
+input_dim = embedding_dim + descriptor_dim  # 768 + 1290 = 2058
+ff_dim = 1024
 num_layers = 2
-output_dim = 6  # Number of properties
+output_dim = 6
 
 # Load trained model
 device = torch.device("cpu")
-model = TransformerRegressor(input_dim, hidden_dim, num_layers, output_dim)
+model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
 model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
 model.eval()
 
 # Prediction function
 def predict_properties(smiles: str):
     try:
-        # Validate SMILES and compute descriptors
-        _ = compute_descriptors(smiles)
+        # Compute descriptors
+        descriptors = compute_descriptors(smiles)
+        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
 
-        # ChemBERTa embedding (CLS token)
+        # Get ChemBERTa embedding (CLS token)
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
-        embedding = outputs.last_hidden_state[:, 0, :]  # Extracting the [CLS] token (shape: (1, 768))
+        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)
+
+        # Combine features
+        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # Shape: (1, 1, 2058)
 
-        # Forward pass through model
+        # Forward pass
         with torch.no_grad():
-            preds = model(embedding)
+            preds = model(combined)
 
         preds_np = preds.numpy()
 
@@ -98,7 +110,6 @@ def predict_properties(smiles: str):
             for i in range(output_dim)
         ], axis=1)
 
-        # Create dictionary of results
         results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
         return results
 
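For anyone who wants to sanity-check the new architecture without the trained weights, scalers, or ChemBERTa checkpoint, here is a minimal standalone sketch (not part of the commit). It copies the class and hyperparameters from the updated prediction.py and pushes a dummy tensor with the same (1, 1, 2058) shape that predict_properties now builds from the ChemBERTa [CLS] embedding (768) concatenated with the descriptor vector (1290); the random input is only a placeholder.

```python
# Standalone shape check for the updated TransformerRegressor.
# Only torch is required; the class and hyperparameters mirror the new
# prediction.py, and the random tensor stands in for the real (1, 1, 2058)
# feature tensor assembled in predict_properties.
import torch
import torch.nn as nn

class TransformerRegressor(nn.Module):
    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
        super().__init__()
        self.feat_proj = nn.Linear(input_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x):
        x = self.feat_proj(x)            # (batch, seq, embedding_dim)
        x = self.transformer_encoder(x)  # (batch, seq, embedding_dim)
        x = x.mean(dim=1)                # pool over the sequence dimension
        return self.regression_head(x)   # (batch, output_dim)

if __name__ == "__main__":
    model = TransformerRegressor(input_dim=2058, embedding_dim=768, ff_dim=1024,
                                 num_layers=2, output_dim=6)
    model.eval()
    dummy = torch.randn(1, 1, 2058)  # one molecule: CLS embedding (768) + descriptors (1290)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 6])
```

If these hyperparameters drift from the ones used during training, the failure typically surfaces as size-mismatch errors in model.load_state_dict (for example on feat_proj.weight), which is what the descriptor_dim value in this commit is meant to pin down.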