Spaces:
Running
Running
Update prediction.py
Browse files- prediction.py +1 -4
prediction.py
CHANGED
@@ -68,12 +68,9 @@ class TransformerRegressor(nn.Module):
|
|
68 |
# ------------------------ Load Model ------------------------
|
69 |
@st.cache_resource
|
70 |
def load_model():
|
71 |
-
model =
|
72 |
-
state_dict = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
|
73 |
-
model.load_state_dict(state_dict)
|
74 |
model.eval()
|
75 |
return model
|
76 |
-
|
77 |
model = load_model()
|
78 |
|
79 |
# ------------------------ Descriptors ------------------------
|
|
|
68 |
# ------------------------ Load Model ------------------------
|
69 |
@st.cache_resource
|
70 |
def load_model():
|
71 |
+
model = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
|
|
|
|
|
72 |
model.eval()
|
73 |
return model
|
|
|
74 |
model = load_model()
|
75 |
|
76 |
# ------------------------ Descriptors ------------------------
|