Spaces:
Running
Running
Update prediction.py
Browse files- prediction.py +6 -9
prediction.py
CHANGED
@@ -70,19 +70,16 @@ class TransformerRegressor(nn.Module):
|
|
70 |
|
71 |
@st.cache_resource
|
72 |
def load_model():
|
73 |
-
# Initialize the model architecture first
|
74 |
model = TransformerRegressor()
|
75 |
-
|
76 |
-
# Load the state_dict (weights) from the saved model file
|
77 |
-
state_dict = torch.load("transformer_model(1).bin", map_location=device) # Ensure loading on the correct device
|
78 |
-
|
79 |
-
# Load the state_dict into the model
|
80 |
model.load_state_dict(state_dict)
|
81 |
-
|
82 |
-
# Set the model to evaluation mode
|
83 |
model.eval()
|
84 |
-
model.to(device)
|
85 |
return model
|
|
|
|
|
|
|
|
|
86 |
# ------------------------ Descriptors ------------------------
|
87 |
def compute_descriptors(smiles: str):
|
88 |
mol = Chem.MolFromSmiles(smiles)
|
|
|
70 |
|
71 |
@st.cache_resource
|
72 |
def load_model():
|
|
|
73 |
model = TransformerRegressor()
|
74 |
+
state_dict = torch.load("transformer_model(1).bin", map_location=device)
|
|
|
|
|
|
|
|
|
75 |
model.load_state_dict(state_dict)
|
|
|
|
|
76 |
model.eval()
|
77 |
+
model.to(device)
|
78 |
return model
|
79 |
+
|
80 |
+
# Call them to load the actual models
|
81 |
+
tokenizer, chemberta = load_chemberta()
|
82 |
+
model = load_model()
|
83 |
# ------------------------ Descriptors ------------------------
|
84 |
def compute_descriptors(smiles: str):
|
85 |
mol = Chem.MolFromSmiles(smiles)
|