transpolymer commited on
Commit
e05cbe5
·
verified ·
1 Parent(s): 799aa56

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +17 -9
prediction.py CHANGED
@@ -20,15 +20,15 @@ torch.backends.cudnn.benchmark = False
20
  # Check if CUDA is available for GPU acceleration
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
23
- # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
24
- @st.cache_resource
25
- def load_chemberta():
26
- tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
27
- model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
28
- model.eval()
29
- model.to(device)
30
- return tokenizer, model
31
- tokenizer, chemberta = load_chemberta()
32
 
33
 
34
  # ------------------------ Load Scalers ------------------------
@@ -67,6 +67,13 @@ class TransformerRegressor(nn.Module):
67
  x = self.transformer_encoder(x)
68
  x = x.mean(dim=1)
69
  return self.regression_head(x)
 
 
 
 
 
 
 
70
 
71
  @st.cache_resource
72
  def load_model():
@@ -80,6 +87,7 @@ def load_model():
80
  # Call them to load the actual models
81
  tokenizer, chemberta = load_chemberta()
82
  model = load_model()
 
83
  # ------------------------ Descriptors ------------------------
84
  def compute_descriptors(smiles: str):
85
  mol = Chem.MolFromSmiles(smiles)
 
20
  # Check if CUDA is available for GPU acceleration
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
23
+ # # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
24
+ # @st.cache_resource
25
+ # def load_chemberta():
26
+ # tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
27
+ # model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
28
+ # model.eval()
29
+ # model.to(device)
30
+ # return tokenizer, model
31
+ # tokenizer, chemberta = load_chemberta()
32
 
33
 
34
  # ------------------------ Load Scalers ------------------------
 
67
  x = self.transformer_encoder(x)
68
  x = x.mean(dim=1)
69
  return self.regression_head(x)
70
+ @st.cache_resource
71
+ def load_chemberta():
72
+ tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
73
+ model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
74
+ model.eval()
75
+ model.to(device)
76
+ return tokenizer, model
77
 
78
  @st.cache_resource
79
  def load_model():
 
87
  # Call them to load the actual models
88
  tokenizer, chemberta = load_chemberta()
89
  model = load_model()
90
+
91
  # ------------------------ Descriptors ------------------------
92
  def compute_descriptors(smiles: str):
93
  mol = Chem.MolFromSmiles(smiles)