Update prediction.py

prediction.py  CHANGED  (+30 -19)
@@ -40,54 +40,66 @@ def compute_descriptors(smiles: str):
     ]
     return np.array(descriptors, dtype=np.float32)
 
-# Transformer regression model definition
+# Transformer regression model definition (must match training)
 class TransformerRegressor(nn.Module):
-    def __init__(self, input_dim,
+    def __init__(self, input_dim, embedding_dim, ff_dim, num_layers, output_dim):
         super().__init__()
-        self.feat_proj = nn.Linear(input_dim,
-        encoder_layer = nn.TransformerEncoderLayer(
+        self.feat_proj = nn.Linear(input_dim, embedding_dim)
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=8,
+            dim_feedforward=ff_dim,
+            dropout=0.1,
+            batch_first=True
+        )
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
         self.regression_head = nn.Sequential(
-            nn.Linear(
+            nn.Linear(embedding_dim, 256),
             nn.ReLU(),
-            nn.Linear(
+            nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(
+            nn.Linear(128, output_dim)
         )
 
     def forward(self, x):
         x = self.feat_proj(x)
         x = self.transformer_encoder(x)
-        x = x.mean(dim=1)
+        x = x.mean(dim=1)
         return self.regression_head(x)
 
 # Model hyperparameters (must match training)
-
-
+embedding_dim = 768
+descriptor_dim = 1290  # Based on earlier errors. If unsure, use 1290
+input_dim = embedding_dim + descriptor_dim  # 768 + 1290 = 2058
+ff_dim = 1024
 num_layers = 2
-output_dim = 6
+output_dim = 6
 
 # Load trained model
 device = torch.device("cpu")
-model = TransformerRegressor(input_dim,
+model = TransformerRegressor(input_dim, embedding_dim, ff_dim, num_layers, output_dim)
 model.load_state_dict(torch.load("transformer_model.pt", map_location=device))
 model.eval()
 
 # Prediction function
 def predict_properties(smiles: str):
     try:
-        #
-
+        # Compute descriptors
+        descriptors = compute_descriptors(smiles)
+        descriptors_tensor = torch.tensor(descriptors, dtype=torch.float32).unsqueeze(0)
 
-        # ChemBERTa embedding (CLS token)
+        # Get ChemBERTa embedding (CLS token)
         inputs = tokenizer(smiles, return_tensors="pt")
         with torch.no_grad():
             outputs = embedding_model(**inputs)
-        embedding = outputs.last_hidden_state[:, 0, :]  #
+        embedding = outputs.last_hidden_state[:, 0, :]  # (1, 768)
+
+        # Combine features
+        combined = torch.cat([embedding, descriptors_tensor], dim=1).unsqueeze(1)  # Shape: (1, 1, 2058)
 
-        # Forward pass
+        # Forward pass
         with torch.no_grad():
-            preds = model(
+            preds = model(combined)
 
         preds_np = preds.numpy()
 
@@ -98,7 +110,6 @@ def predict_properties(smiles: str):
             for i in range(output_dim)
         ], axis=1)
 
-        # Create dictionary of results
         results = {key: round(val, 4) for key, val in zip(keys, preds_rescaled.flatten())}
         return results
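For quick verification of this change, a minimal sanity check along the following lines could be appended to prediction.py (or run from a separate script that imports it). This is a sketch, not part of the commit: it assumes the objects defined earlier in the file (tokenizer, embedding_model, and the scalers behind preds_rescaled) are already loaded, that transformer_model.pt is on disk, and the aspirin SMILES string is only an illustrative input.

if __name__ == "__main__":
    test_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"  # aspirin, example input only

    # The descriptor vector length must equal descriptor_dim (1290); otherwise the
    # concatenated (1, 1, 2058) input will not match the trained feat_proj layer.
    desc = compute_descriptors(test_smiles)
    print("descriptor length:", len(desc))  # expected: 1290

    # End-to-end prediction; returns a dict of the six rescaled property values.
    print(predict_properties(test_smiles))

A printed length other than 1290 would indicate that compute_descriptors no longer produces the feature vector the checkpoint was trained on, which is exactly the mismatch the descriptor_dim comment above guards against.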