leuschnm commited on
Commit
89c3cca
·
1 Parent(s): 1adf0d3
Files changed (3) hide show
  1. README.md +0 -1
  2. app.py +30 -26
  3. requirements.txt +8 -6
README.md CHANGED
@@ -5,7 +5,6 @@ colorFrom: blue
5
  colorTo: pink
6
  sdk: streamlit
7
  sdk_version: 1.21.0
8
- python_version: 3.10.0
9
  app_file: app.py
10
  pinned: false
11
  license: mit
 
5
  colorTo: pink
6
  sdk: streamlit
7
  sdk_version: 1.21.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -65,6 +65,32 @@ def predict(model, dataloader):
65
  #preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
66
  #preds.drop(columns=['time_idx_y'],inplace=True)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def main():
69
  ## Initiate Data
70
  with open('data/parameters.pkl', 'rb') as f:
@@ -72,7 +98,7 @@ def main():
72
  model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
73
 
74
  df = pd.read_pickle('data/test_data.pkl')
75
- df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
76
 
77
  rain_mapping = {
78
  "Yes" : 1,
@@ -110,31 +136,9 @@ def main():
110
 
111
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
112
 
113
- fig, axs = plt.subplots(2, 2, figsize=(8, 6))
114
-
115
- # Plot scatter plots for each group
116
- axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
117
- axs[0, 0].set_title('Article Group 1')
118
-
119
- axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
120
- axs[0, 1].set_title('Article Group 2')
121
-
122
- axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
123
- axs[1, 0].set_title('Article Group 3')
124
-
125
- axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
126
- axs[1, 1].set_title('Article Group 4')
127
-
128
- # Adjust spacing between subplots
129
- plt.tight_layout()
130
-
131
- for ax in axs.flat:
132
- ax.set_xlim(df['Date'].min(), df['Date'].max())
133
- ax.set_ylim(df['sales'].min(), df['sales'].max())
134
-
135
- st.pyplot(fig)
136
-
137
- st.button("Forecast Sales", type="primary")
138
 
139
  if __name__ == '__main__':
140
  main()
 
65
  #preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
66
  #preds.drop(columns=['time_idx_y'],inplace=True)
67
 
68
+ def generate_plot(df, predictions):
69
+ fig, axs = plt.subplots(2, 2, figsize=(8, 6))
70
+
71
+ # Plot scatter plots for each group
72
+ axs[0, 0].scatter(df.loc[df['Group'] == '4', 'Date'], df.loc[df['Group'] == '4', 'sales'], color='red', marker='o')
73
+ axs[0, 0].set_title('Article Group 1')
74
+
75
+ axs[0, 1].scatter(df.loc[df['Group'] == '7', 'Date'], df.loc[df['Group'] == '7', 'sales'], color='blue', marker='o')
76
+ axs[0, 1].set_title('Article Group 2')
77
+
78
+ axs[1, 0].scatter(df.loc[df['Group'] == '1', 'Date'], df.loc[df['Group'] == '1', 'sales'], color='green', marker='o')
79
+ axs[1, 0].set_title('Article Group 3')
80
+
81
+ axs[1, 1].scatter(df.loc[df['Group'] == '6', 'Date'], df.loc[df['Group'] == '6', 'sales'], color='yellow', marker='o')
82
+ axs[1, 1].set_title('Article Group 4')
83
+
84
+ # Adjust spacing between subplots
85
+ plt.tight_layout()
86
+
87
+ for ax in axs.flat:
88
+ ax.set_xlim(df['Date'].min(), df['Date'].max())
89
+ ax.set_ylim(df['sales'].min(), df['sales'].max())
90
+
91
+ st.pyplot(fig)
92
+
93
+
94
  def main():
95
  ## Initiate Data
96
  with open('data/parameters.pkl', 'rb') as f:
 
98
  model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=torch.device('cpu'))
99
 
100
  df = pd.read_pickle('data/test_data.pkl')
101
+ df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(["6","7","4","1"]))]
102
 
103
  rain_mapping = {
104
  "Yes" : 1,
 
136
 
137
  datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
138
 
139
+ if st.button("Forecast Sales", type="primary"):
140
+ converted_data = prepare_dataset(parameters, df, rain, temperature, datepicker)
141
+ generate_plot(converted_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  if __name__ == '__main__':
144
  main()
requirements.txt CHANGED
@@ -22,7 +22,7 @@ decorator==5.1.1
22
  executing==1.2.0
23
  fonttools==4.38.0
24
  frozenlist==1.3.3
25
- fsspec
26
  future==0.18.3
27
  google-auth==2.16.0
28
  google-auth-oauthlib==0.4.6
@@ -49,7 +49,7 @@ multidict==6.0.4
49
  nest-asyncio==1.5.6
50
  numpy==1.23.5
51
  oauthlib==3.2.2
52
- optuna
53
  packaging==23.0
54
  pandas==1.5.2
55
  parso==0.8.3
@@ -72,8 +72,8 @@ Pygments==2.14.0
72
  pyparsing==3.0.9
73
  pyperclip==1.8.2
74
  python-dateutil==2.8.2
75
- pytorch-forecasting
76
- pytorch-lightning
77
  pytz==2022.7.1
78
  PyYAML==6.0
79
  pyzmq==25.0.0
@@ -81,7 +81,7 @@ requests==2.28.2
81
  requests-futures==1.0.0
82
  requests-oauthlib==1.3.1
83
  rsa==4.9
84
- scikit-learn==1.2.2
85
  scipy==1.10.0
86
  six==1.16.0
87
  SQLAlchemy==1.4.46
@@ -93,8 +93,10 @@ tensorboard-data-server==0.6.1
93
  tensorboard-plugin-wit==1.8.1
94
  tensorboardX==2.5.1
95
  threadpoolctl==3.1.0
96
- torch
 
97
  torchmetrics==0.11.0
 
98
  tornado==6.2
99
  tqdm==4.64.1
100
  traitlets==5.9.0
 
22
  executing==1.2.0
23
  fonttools==4.38.0
24
  frozenlist==1.3.3
25
+ fsspec==2022.11.0
26
  future==0.18.3
27
  google-auth==2.16.0
28
  google-auth-oauthlib==0.4.6
 
49
  nest-asyncio==1.5.6
50
  numpy==1.23.5
51
  oauthlib==3.2.2
52
+ optuna==2.10.1
53
  packaging==23.0
54
  pandas==1.5.2
55
  parso==0.8.3
 
72
  pyparsing==3.0.9
73
  pyperclip==1.8.2
74
  python-dateutil==2.8.2
75
+ pytorch-forecasting==0.10.3
76
+ pytorch-lightning==1.9.0
77
  pytz==2022.7.1
78
  PyYAML==6.0
79
  pyzmq==25.0.0
 
81
  requests-futures==1.0.0
82
  requests-oauthlib==1.3.1
83
  rsa==4.9
84
+ scikit-learn==1.1.3
85
  scipy==1.10.0
86
  six==1.16.0
87
  SQLAlchemy==1.4.46
 
93
  tensorboard-plugin-wit==1.8.1
94
  tensorboardX==2.5.1
95
  threadpoolctl==3.1.0
96
+ torch==1.10.2
97
+ torchaudio==0.10.2
98
  torchmetrics==0.11.0
99
+ torchvision==0.11.3
100
  tornado==6.2
101
  tqdm==4.64.1
102
  traitlets==5.9.0