Spaces:
Runtime error
Runtime error
bug fix
Browse files- README.md +0 -1
- app.py +30 -26
- requirements.txt +8 -6
README.md
CHANGED
@@ -5,7 +5,6 @@ colorFrom: blue
|
|
5 |
colorTo: pink
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.21.0
|
8 |
-
python_version: 3.10.0
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
license: mit
|
|
|
5 |
colorTo: pink
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.21.0
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
CHANGED
@@ -65,6 +65,32 @@ def predict(model, dataloader):
|
|
65 |
#preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
|
66 |
#preds.drop(columns=['time_idx_y'],inplace=True)
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def main():
|
69 |
## Initiate Data
|
70 |
with open('data/parameters.pkl', 'rb') as f:
|
@@ -72,7 +98,7 @@ def main():
|
|
72 |
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
73 |
|
74 |
df = pd.read_pickle('data/test_data.pkl')
|
75 |
-
df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
|
76 |
|
77 |
rain_mapping = {
|
78 |
"Yes" : 1,
|
@@ -110,31 +136,9 @@ def main():
|
|
110 |
|
111 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
|
117 |
-
axs[0, 0].set_title('Article Group 1')
|
118 |
-
|
119 |
-
axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
|
120 |
-
axs[0, 1].set_title('Article Group 2')
|
121 |
-
|
122 |
-
axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
|
123 |
-
axs[1, 0].set_title('Article Group 3')
|
124 |
-
|
125 |
-
axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
|
126 |
-
axs[1, 1].set_title('Article Group 4')
|
127 |
-
|
128 |
-
# Adjust spacing between subplots
|
129 |
-
plt.tight_layout()
|
130 |
-
|
131 |
-
for ax in axs.flat:
|
132 |
-
ax.set_xlim(df['Date'].min(), df['Date'].max())
|
133 |
-
ax.set_ylim(df['sales'].min(), df['sales'].max())
|
134 |
-
|
135 |
-
st.pyplot(fig)
|
136 |
-
|
137 |
-
st.button("Forecast Sales", type="primary")
|
138 |
|
139 |
if __name__ == '__main__':
|
140 |
main()
|
|
|
65 |
#preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
|
66 |
#preds.drop(columns=['time_idx_y'],inplace=True)
|
67 |
|
68 |
+
def generate_plot(df, predictions):
|
69 |
+
fig, axs = plt.subplots(2, 2, figsize=(8, 6))
|
70 |
+
|
71 |
+
# Plot scatter plots for each group
|
72 |
+
axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
|
73 |
+
axs[0, 0].set_title('Article Group 1')
|
74 |
+
|
75 |
+
axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
|
76 |
+
axs[0, 1].set_title('Article Group 2')
|
77 |
+
|
78 |
+
axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
|
79 |
+
axs[1, 0].set_title('Article Group 3')
|
80 |
+
|
81 |
+
axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
|
82 |
+
axs[1, 1].set_title('Article Group 4')
|
83 |
+
|
84 |
+
# Adjust spacing between subplots
|
85 |
+
plt.tight_layout()
|
86 |
+
|
87 |
+
for ax in axs.flat:
|
88 |
+
ax.set_xlim(df['Date'].min(), df['Date'].max())
|
89 |
+
ax.set_ylim(df['sales'].min(), df['sales'].max())
|
90 |
+
|
91 |
+
st.pyplot(fig)
|
92 |
+
|
93 |
+
|
94 |
def main():
|
95 |
## Initiate Data
|
96 |
with open('data/parameters.pkl', 'rb') as f:
|
|
|
98 |
model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
|
99 |
|
100 |
df = pd.read_pickle('data/test_data.pkl')
|
101 |
+
df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
|
102 |
|
103 |
rain_mapping = {
|
104 |
"Yes" : 1,
|
|
|
136 |
|
137 |
datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
|
138 |
|
139 |
+
if st.button("Forecast Sales", type="primary"):
|
140 |
+
converted_data = prepare_dataset(parameters, df, rain, temperature, datepicker)
|
141 |
+
generate_plot(converted_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
if __name__ == '__main__':
|
144 |
main()
|
requirements.txt
CHANGED
@@ -22,7 +22,7 @@ decorator==5.1.1
|
|
22 |
executing==1.2.0
|
23 |
fonttools==4.38.0
|
24 |
frozenlist==1.3.3
|
25 |
-
fsspec
|
26 |
future==0.18.3
|
27 |
google-auth==2.16.0
|
28 |
google-auth-oauthlib==0.4.6
|
@@ -49,7 +49,7 @@ multidict==6.0.4
|
|
49 |
nest-asyncio==1.5.6
|
50 |
numpy==1.23.5
|
51 |
oauthlib==3.2.2
|
52 |
-
optuna
|
53 |
packaging==23.0
|
54 |
pandas==1.5.2
|
55 |
parso==0.8.3
|
@@ -72,8 +72,8 @@ Pygments==2.14.0
|
|
72 |
pyparsing==3.0.9
|
73 |
pyperclip==1.8.2
|
74 |
python-dateutil==2.8.2
|
75 |
-
pytorch-forecasting
|
76 |
-
pytorch-lightning
|
77 |
pytz==2022.7.1
|
78 |
PyYAML==6.0
|
79 |
pyzmq==25.0.0
|
@@ -81,7 +81,7 @@ requests==2.28.2
|
|
81 |
requests-futures==1.0.0
|
82 |
requests-oauthlib==1.3.1
|
83 |
rsa==4.9
|
84 |
-
scikit-learn==1.
|
85 |
scipy==1.10.0
|
86 |
six==1.16.0
|
87 |
SQLAlchemy==1.4.46
|
@@ -93,8 +93,10 @@ tensorboard-data-server==0.6.1
|
|
93 |
tensorboard-plugin-wit==1.8.1
|
94 |
tensorboardX==2.5.1
|
95 |
threadpoolctl==3.1.0
|
96 |
-
torch
|
|
|
97 |
torchmetrics==0.11.0
|
|
|
98 |
tornado==6.2
|
99 |
tqdm==4.64.1
|
100 |
traitlets==5.9.0
|
|
|
22 |
executing==1.2.0
|
23 |
fonttools==4.38.0
|
24 |
frozenlist==1.3.3
|
25 |
+
fsspec==2022.11.0
|
26 |
future==0.18.3
|
27 |
google-auth==2.16.0
|
28 |
google-auth-oauthlib==0.4.6
|
|
|
49 |
nest-asyncio==1.5.6
|
50 |
numpy==1.23.5
|
51 |
oauthlib==3.2.2
|
52 |
+
optuna==2.10.1
|
53 |
packaging==23.0
|
54 |
pandas==1.5.2
|
55 |
parso==0.8.3
|
|
|
72 |
pyparsing==3.0.9
|
73 |
pyperclip==1.8.2
|
74 |
python-dateutil==2.8.2
|
75 |
+
pytorch-forecasting==0.10.3
|
76 |
+
pytorch-lightning==1.9.0
|
77 |
pytz==2022.7.1
|
78 |
PyYAML==6.0
|
79 |
pyzmq==25.0.0
|
|
|
81 |
requests-futures==1.0.0
|
82 |
requests-oauthlib==1.3.1
|
83 |
rsa==4.9
|
84 |
+
scikit-learn==1.1.3
|
85 |
scipy==1.10.0
|
86 |
six==1.16.0
|
87 |
SQLAlchemy==1.4.46
|
|
|
93 |
tensorboard-plugin-wit==1.8.1
|
94 |
tensorboardX==2.5.1
|
95 |
threadpoolctl==3.1.0
|
96 |
+
torch==1.10.2
|
97 |
+
torchaudio==0.10.2
|
98 |
torchmetrics==0.11.0
|
99 |
+
torchvision==0.11.3
|
100 |
tornado==6.2
|
101 |
tqdm==4.64.1
|
102 |
traitlets==5.9.0
|