transpolymer commited on
Commit
c54640c
·
verified ·
1 Parent(s): cfc37e3

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +3 -1
prediction.py CHANGED
@@ -20,13 +20,15 @@ torch.backends.cudnn.benchmark = False
20
  # Check if CUDA is available for GPU acceleration
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
 
23
  @st.cache_resource
24
  def load_chemberta():
25
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
26
  model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
27
  model.eval()
28
- model.to(device) # Send model to GPU if available
29
  return tokenizer, model
 
30
 
31
 
32
  # ------------------------ Load Scalers ------------------------
 
20
  # Check if CUDA is available for GPU acceleration
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
23
+ # ------------------------ Load ChemBERTa Model + Tokenizer ------------------------
24
  @st.cache_resource
25
  def load_chemberta():
26
  tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
27
  model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
28
  model.eval()
29
+ model.to(device)
30
  return tokenizer, model
31
+ tokenizer, chemberta = load_chemberta()
32
 
33
 
34
  # ------------------------ Load Scalers ------------------------