transpolymer committed
Commit a715cd6 · verified · 1 parent: 23348f0

Update prediction.py

Files changed (1): prediction.py (+5 -6)
prediction.py CHANGED
@@ -38,9 +38,9 @@ scalers = {
 
 # Transformer model
 class TransformerRegressor(nn.Module):
-    def __init__(self, input_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
+    def __init__(self, feat_dim=2058, embedding_dim=768, ff_dim=1024, num_layers=2, output_dim=6):
         super().__init__()
-        self.feat_proj = nn.Linear(input_dim, embedding_dim)
+        self.feat_proj = nn.Linear(feat_dim, embedding_dim)
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=embedding_dim, nhead=8, dim_feedforward=ff_dim,
             dropout=0.1, batch_first=True
@@ -98,14 +98,13 @@ def get_morgan_fingerprint(smiles, radius=2, n_bits=2048):
     if mol is None:
         raise ValueError("Invalid SMILES string.")
     fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
-    return np.array(fp, dtype=np.float32).reshape(-1,1)
+    return np.array(fp, dtype=np.float32).reshape(1,-1)
 
 # ChemBERTa embedding
-def get_chemberta_embedding(smiles: str, tokenizer, model):
+def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
     inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(device) for k, v in inputs.items()}
     with torch.no_grad():
-        outputs = model(**inputs)
+        outputs = chemberta(**inputs)
     return outputs.last_hidden_state.mean(dim=1).to(device)
 
 # Save to DB
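
Why the reshape flip matters: the old reshape(-1, 1) turned the 2048-bit Morgan fingerprint into a (2048, 1) column, i.e. 2048 one-feature rows, while the new reshape(1, -1) yields the (1, 2048) single-sample row that scalers and the regressor's (batch, features) input layout expect. A pure-NumPy sketch of the difference; reading the remaining 10 features as extra descriptors is an inference from 2058 - 2048, not something this commit shows:

import numpy as np

bits = np.random.randint(0, 2, size=2048)      # stand-in for the RDKit bit vector

col = bits.astype(np.float32).reshape(-1, 1)   # old: shape (2048, 1), 2048 "samples"
row = bits.astype(np.float32).reshape(1, -1)   # new: shape (1, 2048), one sample

assert col.shape == (2048, 1)
assert row.shape == (1, 2048)

# Feature-axis concatenation only works in the row layout, e.g. appending
# 10 extra descriptors (hypothetical) to reach the model's feat_dim=2058:
extra = np.zeros((1, 10), dtype=np.float32)
features = np.concatenate([row, extra], axis=1)
assert features.shape == (1, 2058)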
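
The other two edits both land in get_chemberta_embedding. Renaming the third parameter from model to chemberta stops it shadowing whatever module-level model prediction.py loads elsewhere (presumably the TransformerRegressor), so the function can no longer silently call the wrong network. Dropping the inputs = {k: v.to(device) ...} line leaves the tokenized batch on CPU, which is only safe while chemberta itself stays off the GPU. A minimal sketch of the patched function in context; the checkpoint name is an assumption for illustration, not taken from this commit:

import torch
from transformers import AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed checkpoint; the actual model name is not visible in this diff.
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").eval()

def get_chemberta_embedding(smiles: str, tokenizer, chemberta):
    inputs = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True)
    # Note: inputs are no longer moved to `device`, so chemberta must stay on
    # CPU here; the deleted dict comprehension was what enabled GPU placement.
    with torch.no_grad():
        outputs = chemberta(**inputs)
    # Mean-pool token embeddings into one vector per SMILES, shape (1, 768).
    return outputs.last_hidden_state.mean(dim=1).to(device)

emb = get_chemberta_embedding("CCO", tokenizer, chemberta)   # shape (1, 768)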