Spaces:
Running
Running
Update prediction.py
Browse files- prediction.py +2 -1
prediction.py
CHANGED
@@ -69,7 +69,8 @@ class TransformerRegressor(nn.Module):
|
|
69 |
@st.cache_resource
|
70 |
def load_model():
|
71 |
model = TransformerRegressor()
|
72 |
-
|
|
|
73 |
model.eval()
|
74 |
return model
|
75 |
|
|
|
69 |
@st.cache_resource
|
70 |
def load_model():
|
71 |
model = TransformerRegressor()
|
72 |
+
state_dict = torch.load("transformer_model.bin", map_location=torch.device("cpu"))
|
73 |
+
model.load_state_dict(state_dict)
|
74 |
model.eval()
|
75 |
return model
|
76 |
|