Update prediction.py
prediction.py  CHANGED  (+5 -6)
@@ -38,9 +38,9 @@ scalers = {
 
 # Transformer model
 class TransformerRegressor(nn.Module):
-    def __init__(self,
+    def __init__(self, feat_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
         super().__init__()
-        self.feat_proj = nn.Linear(
+        self.feat_proj = nn.Linear(feat_dim, embedding_dim)
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=embedding_dim, nhead=8, dim_feedforward=ff_dim,
             dropout=0.1, batch_first=True
@@ -98,14 +98,13 @@ def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
     if mol is None:
         raise ValueError("Invalid SMILES string.")
     fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
-    return np.array(fp, dtype=np.float32).reshape(
+    return np.array(fp, dtype=np.float32).reshape(1,-1)
 
 # ChemBERTa embedding
-def get_chemberta_embedding(smiles: str, tokenizer,
+def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
     inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(device) for k, v in inputs.items()}
     with torch.no_grad():
-        outputs = chemberta(**inputs)
+        outputs = chemberta(**inputs)
     return outputs.last_hidden_state.mean(dim=1).to(device)
 
 # Save to DB
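For readers skimming the diff, here is a minimal sketch of the class the new signature points to. Only the lines shown in the first hunk come from the commit; the encoder stack, output head, and forward pass are assumptions added so the sketch runs on its own.

import torch
import torch.nn as nn

class TransformerRegressor(nn.Module):
    def __init__(self, feat_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
        super().__init__()
        # Project the raw feature vector into the embedding space the encoder expects.
        self.feat_proj = nn.Linear(feat_dim, embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=8, dim_feedforward=ff_dim,
            dropout=0.1, batch_first=True
        )
        # Everything below is assumed; the diff does not show the rest of __init__.
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        # x: (batch, feat_dim); treat each sample as a length-1 sequence.
        h = self.feat_proj(x).unsqueeze(1)
        h = self.encoder(h)
        return self.head(h.squeeze(1))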
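And a quick usage sketch of the two updated helpers, assuming they can be imported from prediction.py. The import path, checkpoint name, and example SMILES are placeholders, not taken from the repo; note that after this commit the tokenized inputs are no longer moved to device, so the ChemBERTa model is kept on CPU here.

from transformers import AutoTokenizer, AutoModel
# Hypothetical import; adjust to however prediction.py is packaged in the Space.
from prediction import get_morgan_fingerprint, get_chemberta_embedding

# Placeholder checkpoint; the Space may load a different ChemBERTa model.
model_name = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
chemberta = AutoModel.from_pretrained(model_name)  # kept on CPU to match the CPU inputs

smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, as an example
fp = get_morgan_fingerprint(smiles)                          # (1, 2048) float32 array
emb = get_chemberta_embedding(smiles, tokenizer, chemberta)  # (1, 768) tensor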