leuschnm commited on
Commit
c850717
·
1 Parent(s): f51ee49
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/parameters.pkl filter=lfs diff=lfs merge=lfs -text
37
+ data/test_data.pkl filter=lfs diff=lfs merge=lfs -text
38
+ model/tft_check.ckpt filter=lfs diff=lfs merge=lfs -text
Untitled.ipynb ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 18,
6
+ "id": "40d0c64c-1de1-47ed-aa66-7b28d9e8fd1f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pickle \n",
11
+ "import pandas as pd \n",
12
+ "import datetime"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 13,
18
+ "id": "5043d237-0287-4705-bfd5-73b880b36def",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "df = pd.read_pickle('data/test_data.pkl')\n",
23
+ "df = df.loc[(df[\"Branch\"] == \"15\") & (df[\"Group\"].isin([\"6\",\"7\",\"4\",\"1\"]))]"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 14,
29
+ "id": "dce8096f-23d4-4075-8654-6693632c45bc",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "data": {
34
+ "text/html": [
35
+ "<div>\n",
36
+ "<style scoped>\n",
37
+ " .dataframe tbody tr th:only-of-type {\n",
38
+ " vertical-align: middle;\n",
39
+ " }\n",
40
+ "\n",
41
+ " .dataframe tbody tr th {\n",
42
+ " vertical-align: top;\n",
43
+ " }\n",
44
+ "\n",
45
+ " .dataframe thead th {\n",
46
+ " text-align: right;\n",
47
+ " }\n",
48
+ "</style>\n",
49
+ "<table border=\"1\" class=\"dataframe\">\n",
50
+ " <thead>\n",
51
+ " <tr style=\"text-align: right;\">\n",
52
+ " <th></th>\n",
53
+ " <th>sales</th>\n",
54
+ " <th>DayInYear</th>\n",
55
+ " <th>time_idx</th>\n",
56
+ " <th>Wahl</th>\n",
57
+ " <th>Baustelle</th>\n",
58
+ " <th>MontagLangesWE</th>\n",
59
+ " <th>FreitagLangesWE</th>\n",
60
+ " <th>nosale</th>\n",
61
+ " <th>holiday</th>\n",
62
+ " <th>AufSommerzeit</th>\n",
63
+ " <th>...</th>\n",
64
+ " <th>Branch</th>\n",
65
+ " <th>Weekday</th>\n",
66
+ " <th>Date</th>\n",
67
+ " <th>MTXWTH_Day_precip</th>\n",
68
+ " <th>MTXWTH_Temp_max</th>\n",
69
+ " <th>MTXWTH_Temp_min</th>\n",
70
+ " <th>Start</th>\n",
71
+ " <th>End</th>\n",
72
+ " <th>ShiftLength</th>\n",
73
+ " <th>weight</th>\n",
74
+ " </tr>\n",
75
+ " </thead>\n",
76
+ " <tbody>\n",
77
+ " <tr>\n",
78
+ " <th>270300</th>\n",
79
+ " <td>1600.9030</td>\n",
80
+ " <td>177</td>\n",
81
+ " <td>2369</td>\n",
82
+ " <td>0.0</td>\n",
83
+ " <td>0.0</td>\n",
84
+ " <td>0.0</td>\n",
85
+ " <td>0.0</td>\n",
86
+ " <td>0</td>\n",
87
+ " <td>none</td>\n",
88
+ " <td>0.0</td>\n",
89
+ " <td>...</td>\n",
90
+ " <td>15</td>\n",
91
+ " <td>6</td>\n",
92
+ " <td>2022-06-26</td>\n",
93
+ " <td>0.0</td>\n",
94
+ " <td>28.52</td>\n",
95
+ " <td>17.47</td>\n",
96
+ " <td>7.0</td>\n",
97
+ " <td>10.983333</td>\n",
98
+ " <td>240.0</td>\n",
99
+ " <td>1</td>\n",
100
+ " </tr>\n",
101
+ " <tr>\n",
102
+ " <th>270301</th>\n",
103
+ " <td>1811.1958</td>\n",
104
+ " <td>178</td>\n",
105
+ " <td>2370</td>\n",
106
+ " <td>0.0</td>\n",
107
+ " <td>0.0</td>\n",
108
+ " <td>0.0</td>\n",
109
+ " <td>0.0</td>\n",
110
+ " <td>0</td>\n",
111
+ " <td>none</td>\n",
112
+ " <td>0.0</td>\n",
113
+ " <td>...</td>\n",
114
+ " <td>15</td>\n",
115
+ " <td>0</td>\n",
116
+ " <td>2022-06-27</td>\n",
117
+ " <td>0.0</td>\n",
118
+ " <td>25.75</td>\n",
119
+ " <td>16.70</td>\n",
120
+ " <td>6.0</td>\n",
121
+ " <td>13.983333</td>\n",
122
+ " <td>480.0</td>\n",
123
+ " <td>1</td>\n",
124
+ " </tr>\n",
125
+ " <tr>\n",
126
+ " <th>270302</th>\n",
127
+ " <td>1784.2916</td>\n",
128
+ " <td>179</td>\n",
129
+ " <td>2371</td>\n",
130
+ " <td>0.0</td>\n",
131
+ " <td>0.0</td>\n",
132
+ " <td>0.0</td>\n",
133
+ " <td>0.0</td>\n",
134
+ " <td>0</td>\n",
135
+ " <td>none</td>\n",
136
+ " <td>0.0</td>\n",
137
+ " <td>...</td>\n",
138
+ " <td>15</td>\n",
139
+ " <td>1</td>\n",
140
+ " <td>2022-06-28</td>\n",
141
+ " <td>0.0</td>\n",
142
+ " <td>23.57</td>\n",
143
+ " <td>14.17</td>\n",
144
+ " <td>6.0</td>\n",
145
+ " <td>13.983333</td>\n",
146
+ " <td>480.0</td>\n",
147
+ " <td>1</td>\n",
148
+ " </tr>\n",
149
+ " <tr>\n",
150
+ " <th>270303</th>\n",
151
+ " <td>1757.3488</td>\n",
152
+ " <td>180</td>\n",
153
+ " <td>2372</td>\n",
154
+ " <td>0.0</td>\n",
155
+ " <td>0.0</td>\n",
156
+ " <td>0.0</td>\n",
157
+ " <td>0.0</td>\n",
158
+ " <td>0</td>\n",
159
+ " <td>none</td>\n",
160
+ " <td>0.0</td>\n",
161
+ " <td>...</td>\n",
162
+ " <td>15</td>\n",
163
+ " <td>2</td>\n",
164
+ " <td>2022-06-29</td>\n",
165
+ " <td>0.0</td>\n",
166
+ " <td>26.81</td>\n",
167
+ " <td>13.09</td>\n",
168
+ " <td>6.0</td>\n",
169
+ " <td>13.983333</td>\n",
170
+ " <td>480.0</td>\n",
171
+ " <td>1</td>\n",
172
+ " </tr>\n",
173
+ " <tr>\n",
174
+ " <th>270304</th>\n",
175
+ " <td>1741.0982</td>\n",
176
+ " <td>181</td>\n",
177
+ " <td>2373</td>\n",
178
+ " <td>0.0</td>\n",
179
+ " <td>0.0</td>\n",
180
+ " <td>0.0</td>\n",
181
+ " <td>0.0</td>\n",
182
+ " <td>0</td>\n",
183
+ " <td>none</td>\n",
184
+ " <td>0.0</td>\n",
185
+ " <td>...</td>\n",
186
+ " <td>15</td>\n",
187
+ " <td>3</td>\n",
188
+ " <td>2022-06-30</td>\n",
189
+ " <td>0.0</td>\n",
190
+ " <td>27.26</td>\n",
191
+ " <td>15.00</td>\n",
192
+ " <td>6.0</td>\n",
193
+ " <td>13.983333</td>\n",
194
+ " <td>480.0</td>\n",
195
+ " <td>1</td>\n",
196
+ " </tr>\n",
197
+ " <tr>\n",
198
+ " <th>...</th>\n",
199
+ " <td>...</td>\n",
200
+ " <td>...</td>\n",
201
+ " <td>...</td>\n",
202
+ " <td>...</td>\n",
203
+ " <td>...</td>\n",
204
+ " <td>...</td>\n",
205
+ " <td>...</td>\n",
206
+ " <td>...</td>\n",
207
+ " <td>...</td>\n",
208
+ " <td>...</td>\n",
209
+ " <td>...</td>\n",
210
+ " <td>...</td>\n",
211
+ " <td>...</td>\n",
212
+ " <td>...</td>\n",
213
+ " <td>...</td>\n",
214
+ " <td>...</td>\n",
215
+ " <td>...</td>\n",
216
+ " <td>...</td>\n",
217
+ " <td>...</td>\n",
218
+ " <td>...</td>\n",
219
+ " <td>...</td>\n",
220
+ " </tr>\n",
221
+ " <tr>\n",
222
+ " <th>287065</th>\n",
223
+ " <td>1643.1700</td>\n",
224
+ " <td>173</td>\n",
225
+ " <td>2730</td>\n",
226
+ " <td>0.0</td>\n",
227
+ " <td>0.0</td>\n",
228
+ " <td>0.0</td>\n",
229
+ " <td>0.0</td>\n",
230
+ " <td>0</td>\n",
231
+ " <td>none</td>\n",
232
+ " <td>0.0</td>\n",
233
+ " <td>...</td>\n",
234
+ " <td>15</td>\n",
235
+ " <td>3</td>\n",
236
+ " <td>2023-06-22</td>\n",
237
+ " <td>0.0</td>\n",
238
+ " <td>26.93</td>\n",
239
+ " <td>13.06</td>\n",
240
+ " <td>6.0</td>\n",
241
+ " <td>16.983333</td>\n",
242
+ " <td>660.0</td>\n",
243
+ " <td>1</td>\n",
244
+ " </tr>\n",
245
+ " <tr>\n",
246
+ " <th>287066</th>\n",
247
+ " <td>1597.3518</td>\n",
248
+ " <td>174</td>\n",
249
+ " <td>2731</td>\n",
250
+ " <td>0.0</td>\n",
251
+ " <td>0.0</td>\n",
252
+ " <td>0.0</td>\n",
253
+ " <td>0.0</td>\n",
254
+ " <td>0</td>\n",
255
+ " <td>none</td>\n",
256
+ " <td>0.0</td>\n",
257
+ " <td>...</td>\n",
258
+ " <td>15</td>\n",
259
+ " <td>4</td>\n",
260
+ " <td>2023-06-23</td>\n",
261
+ " <td>1.0</td>\n",
262
+ " <td>23.99</td>\n",
263
+ " <td>15.98</td>\n",
264
+ " <td>6.0</td>\n",
265
+ " <td>16.983333</td>\n",
266
+ " <td>660.0</td>\n",
267
+ " <td>1</td>\n",
268
+ " </tr>\n",
269
+ " <tr>\n",
270
+ " <th>287067</th>\n",
271
+ " <td>1683.6228</td>\n",
272
+ " <td>175</td>\n",
273
+ " <td>2732</td>\n",
274
+ " <td>0.0</td>\n",
275
+ " <td>0.0</td>\n",
276
+ " <td>0.0</td>\n",
277
+ " <td>0.0</td>\n",
278
+ " <td>0</td>\n",
279
+ " <td>none</td>\n",
280
+ " <td>0.0</td>\n",
281
+ " <td>...</td>\n",
282
+ " <td>15</td>\n",
283
+ " <td>5</td>\n",
284
+ " <td>2023-06-24</td>\n",
285
+ " <td>0.0</td>\n",
286
+ " <td>25.99</td>\n",
287
+ " <td>12.04</td>\n",
288
+ " <td>6.0</td>\n",
289
+ " <td>15.983333</td>\n",
290
+ " <td>600.0</td>\n",
291
+ " <td>1</td>\n",
292
+ " </tr>\n",
293
+ " <tr>\n",
294
+ " <th>287068</th>\n",
295
+ " <td>1785.2180</td>\n",
296
+ " <td>176</td>\n",
297
+ " <td>2733</td>\n",
298
+ " <td>0.0</td>\n",
299
+ " <td>0.0</td>\n",
300
+ " <td>0.0</td>\n",
301
+ " <td>0.0</td>\n",
302
+ " <td>0</td>\n",
303
+ " <td>none</td>\n",
304
+ " <td>0.0</td>\n",
305
+ " <td>...</td>\n",
306
+ " <td>15</td>\n",
307
+ " <td>6</td>\n",
308
+ " <td>2023-06-25</td>\n",
309
+ " <td>0.0</td>\n",
310
+ " <td>28.99</td>\n",
311
+ " <td>15.02</td>\n",
312
+ " <td>7.0</td>\n",
313
+ " <td>15.983333</td>\n",
314
+ " <td>540.0</td>\n",
315
+ " <td>1</td>\n",
316
+ " </tr>\n",
317
+ " <tr>\n",
318
+ " <th>287069</th>\n",
319
+ " <td>1589.9020</td>\n",
320
+ " <td>177</td>\n",
321
+ " <td>2734</td>\n",
322
+ " <td>0.0</td>\n",
323
+ " <td>0.0</td>\n",
324
+ " <td>0.0</td>\n",
325
+ " <td>0.0</td>\n",
326
+ " <td>0</td>\n",
327
+ " <td>none</td>\n",
328
+ " <td>0.0</td>\n",
329
+ " <td>...</td>\n",
330
+ " <td>15</td>\n",
331
+ " <td>0</td>\n",
332
+ " <td>2023-06-26</td>\n",
333
+ " <td>0.0</td>\n",
334
+ " <td>27.96</td>\n",
335
+ " <td>17.01</td>\n",
336
+ " <td>6.0</td>\n",
337
+ " <td>16.983333</td>\n",
338
+ " <td>660.0</td>\n",
339
+ " <td>1</td>\n",
340
+ " </tr>\n",
341
+ " </tbody>\n",
342
+ "</table>\n",
343
+ "<p>1464 rows × 22 columns</p>\n",
344
+ "</div>"
345
+ ],
346
+ "text/plain": [
347
+ " sales DayInYear time_idx Wahl Baustelle MontagLangesWE \\\n",
348
+ "270300 1600.9030 177 2369 0.0 0.0 0.0 \n",
349
+ "270301 1811.1958 178 2370 0.0 0.0 0.0 \n",
350
+ "270302 1784.2916 179 2371 0.0 0.0 0.0 \n",
351
+ "270303 1757.3488 180 2372 0.0 0.0 0.0 \n",
352
+ "270304 1741.0982 181 2373 0.0 0.0 0.0 \n",
353
+ "... ... ... ... ... ... ... \n",
354
+ "287065 1643.1700 173 2730 0.0 0.0 0.0 \n",
355
+ "287066 1597.3518 174 2731 0.0 0.0 0.0 \n",
356
+ "287067 1683.6228 175 2732 0.0 0.0 0.0 \n",
357
+ "287068 1785.2180 176 2733 0.0 0.0 0.0 \n",
358
+ "287069 1589.9020 177 2734 0.0 0.0 0.0 \n",
359
+ "\n",
360
+ " FreitagLangesWE nosale holiday AufSommerzeit ... Branch Weekday \\\n",
361
+ "270300 0.0 0 none 0.0 ... 15 6 \n",
362
+ "270301 0.0 0 none 0.0 ... 15 0 \n",
363
+ "270302 0.0 0 none 0.0 ... 15 1 \n",
364
+ "270303 0.0 0 none 0.0 ... 15 2 \n",
365
+ "270304 0.0 0 none 0.0 ... 15 3 \n",
366
+ "... ... ... ... ... ... ... ... \n",
367
+ "287065 0.0 0 none 0.0 ... 15 3 \n",
368
+ "287066 0.0 0 none 0.0 ... 15 4 \n",
369
+ "287067 0.0 0 none 0.0 ... 15 5 \n",
370
+ "287068 0.0 0 none 0.0 ... 15 6 \n",
371
+ "287069 0.0 0 none 0.0 ... 15 0 \n",
372
+ "\n",
373
+ " Date MTXWTH_Day_precip MTXWTH_Temp_max MTXWTH_Temp_min Start \\\n",
374
+ "270300 2022-06-26 0.0 28.52 17.47 7.0 \n",
375
+ "270301 2022-06-27 0.0 25.75 16.70 6.0 \n",
376
+ "270302 2022-06-28 0.0 23.57 14.17 6.0 \n",
377
+ "270303 2022-06-29 0.0 26.81 13.09 6.0 \n",
378
+ "270304 2022-06-30 0.0 27.26 15.00 6.0 \n",
379
+ "... ... ... ... ... ... \n",
380
+ "287065 2023-06-22 0.0 26.93 13.06 6.0 \n",
381
+ "287066 2023-06-23 1.0 23.99 15.98 6.0 \n",
382
+ "287067 2023-06-24 0.0 25.99 12.04 6.0 \n",
383
+ "287068 2023-06-25 0.0 28.99 15.02 7.0 \n",
384
+ "287069 2023-06-26 0.0 27.96 17.01 6.0 \n",
385
+ "\n",
386
+ " End ShiftLength weight \n",
387
+ "270300 10.983333 240.0 1 \n",
388
+ "270301 13.983333 480.0 1 \n",
389
+ "270302 13.983333 480.0 1 \n",
390
+ "270303 13.983333 480.0 1 \n",
391
+ "270304 13.983333 480.0 1 \n",
392
+ "... ... ... ... \n",
393
+ "287065 16.983333 660.0 1 \n",
394
+ "287066 16.983333 660.0 1 \n",
395
+ "287067 15.983333 600.0 1 \n",
396
+ "287068 15.983333 540.0 1 \n",
397
+ "287069 16.983333 660.0 1 \n",
398
+ "\n",
399
+ "[1464 rows x 22 columns]"
400
+ ]
401
+ },
402
+ "execution_count": 14,
403
+ "metadata": {},
404
+ "output_type": "execute_result"
405
+ }
406
+ ],
407
+ "source": [
408
+ "df"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": 17,
414
+ "id": "5ea34c2b-fc24-4cfa-ad46-f369bea42364",
415
+ "metadata": {},
416
+ "outputs": [
417
+ {
418
+ "data": {
419
+ "text/plain": [
420
+ "Timestamp('2023-06-26 00:00:00')"
421
+ ]
422
+ },
423
+ "execution_count": 17,
424
+ "metadata": {},
425
+ "output_type": "execute_result"
426
+ }
427
+ ],
428
+ "source": [
429
+ "max(df[\"Date\"])"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": 20,
435
+ "id": "a024d9e3-e018-43fe-9bb0-20b9d9e91b53",
436
+ "metadata": {},
437
+ "outputs": [
438
+ {
439
+ "data": {
440
+ "text/plain": [
441
+ "datetime.date(2023, 5, 27)"
442
+ ]
443
+ },
444
+ "execution_count": 20,
445
+ "metadata": {},
446
+ "output_type": "execute_result"
447
+ }
448
+ ],
449
+ "source": [
450
+ "datetime.date(2023, 6, 26) - datetime.timedelta(days = 30)"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": 21,
456
+ "id": "52e0a2c8-5b53-42f4-91c9-3ca7d6a8d356",
457
+ "metadata": {},
458
+ "outputs": [
459
+ {
460
+ "data": {
461
+ "text/plain": [
462
+ "Index(['sales', 'DayInYear', 'time_idx', 'Wahl', 'Baustelle', 'MontagLangesWE',\n",
463
+ " 'FreitagLangesWE', 'nosale', 'holiday', 'AufSommerzeit',\n",
464
+ " 'AufWinterzeit', 'Group', 'Branch', 'Weekday', 'Date',\n",
465
+ " 'MTXWTH_Day_precip', 'MTXWTH_Temp_max', 'MTXWTH_Temp_min', 'Start',\n",
466
+ " 'End', 'ShiftLength', 'weight'],\n",
467
+ " dtype='object')"
468
+ ]
469
+ },
470
+ "execution_count": 21,
471
+ "metadata": {},
472
+ "output_type": "execute_result"
473
+ }
474
+ ],
475
+ "source": [
476
+ "df.columns"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": null,
482
+ "id": "75415b44-41ec-4893-8452-939766ebaabc",
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": []
486
+ }
487
+ ],
488
+ "metadata": {
489
+ "kernelspec": {
490
+ "display_name": "Python 3 (ipykernel)",
491
+ "language": "python",
492
+ "name": "python3"
493
+ },
494
+ "language_info": {
495
+ "codemirror_mode": {
496
+ "name": "ipython",
497
+ "version": 3
498
+ },
499
+ "file_extension": ".py",
500
+ "mimetype": "text/x-python",
501
+ "name": "python",
502
+ "nbconvert_exporter": "python",
503
+ "pygments_lexer": "ipython3",
504
+ "version": "3.10.11"
505
+ }
506
+ },
507
+ "nbformat": 4,
508
+ "nbformat_minor": 5
509
+ }
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Imports
2
+ import pickle
3
+ import warnings
4
+ import streamlit as st
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import datetime
11
+
12
+ # import torch
13
+ from torch.distributions import Normal
14
+ from pytorch_forecasting import (
15
+ TimeSeriesDataSet,
16
+ TemporalFusionTransformer,
17
+ )
18
+
19
+ ## Functions
20
+ def raw_preds_to_df(raw,quantiles = None):
21
+ """
22
+ raw is output of model.predict with return_index=True
23
+ quantiles can be provided like [0.1,0.5,0.9] to get interpretable quantiles
24
+ in the output, time_idx is the first prediction time index (one step after knowledge cutoff)
25
+ pred_idx the index of the predicted date i.e. time_idx + h - 1
26
+ """
27
+ index = raw[2]
28
+ preds = raw[0].prediction
29
+ dec_len = preds.shape[1]
30
+ n_quantiles = preds.shape[-1]
31
+ preds_df = pd.DataFrame(index.values.repeat(dec_len * n_quantiles, axis=0),columns=index.columns)
32
+ preds_df = preds_df.assign(h=np.tile(np.repeat(np.arange(1,1+dec_len),n_quantiles),len(preds_df)//(dec_len*n_quantiles)))
33
+ preds_df = preds_df.assign(q=np.tile(np.arange(n_quantiles),len(preds_df)//n_quantiles))
34
+ preds_df = preds_df.assign(pred=preds.flatten().cpu().numpy())
35
+ if quantiles is not None:
36
+ preds_df['q'] = preds_df['q'].map({i:q for i,q in enumerate(quantiles)})
37
+
38
+ preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1
39
+ return preds_df
40
+
41
+ def prepare_dataset(parameters, df, rain, temperature, datepicker):
42
+ if rain != "Default":
43
+ df["MTXWTH_Day_precip"] = rain_mapping[rain]
44
+
45
+ df["MTXWTH_Temp_min"] = df["MTXWTH_Temp_min"] + temperature
46
+ df["MTXWTH_Temp_max"] = df["MTXWTH_Temp_max"] + temperature
47
+
48
+ lowerbound = datepicker - datetime.timedelta(days = 35)
49
+ upperbound = datepicker + datetime.timedelta(days = 30)
50
+
51
+ df = df.loc[(df["Date"]>lowerbound) & (df["Date"]<=upperbound)]
52
+
53
+ df = TimeSeriesDataSet.from_parameters(parameters, df)
54
+ return df.to_dataloader(train=False, batch_size=256,num_workers = 0)
55
+
56
+ def predict(model, dataloader):
57
+ return model.predict(dataloader, mode="raw", return_x=True, return_index=True)
58
+
59
+ ## Initiate Data
60
+ with open('data/parameters.pkl', 'rb') as f:
61
+ parameters = pickle.load(f)
62
+ model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt')
63
+
64
+ df = pd.read_pickle('data/test_data.pkl')
65
+ df = df.loc[(df["Branch"] == 15) & (df["Group"].isin(["6","7","4","1"]))]
66
+
67
+ rain_mapping = {
68
+ "Yes" : 1,
69
+ "No" : , 0
70
+ }
71
+
72
+ # Start App
73
+ st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")
74
+
75
+ st.markdown(body = """
76
+ ### Abstract
77
+ Multi-horizon forecasting often contains a complex mix of inputs – including
78
+ static (i.e. time-invariant) covariates, known future inputs, and other exogenous
79
+ time series that are only observed in the past – without any prior information
80
+ on how they interact with the target. Several deep learning methods have been
81
+ proposed, but they are typically ‘black-box’ models which do not shed light on
82
+ how they use the full range of inputs present in practical scenarios. In this pa-
83
+ per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
84
+ based architecture which combines high-performance multi-horizon forecasting
85
+ with interpretable insights into temporal dynamics. To learn temporal rela-
86
+ tionships at different scales, TFT uses recurrent layers for local processing and
87
+ interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
88
+ cialized components to select relevant features and a series of gating layers to
89
+ suppress unnecessary components, enabling high performance in a wide range of
90
+ scenarios. On a variety of real-world datasets, we demonstrate significant per-
91
+ formance improvements over existing benchmarks, and showcase three practical
92
+ interpretability use cases of TFT.
93
+ """)
94
+
95
+ rain = st.radio("Rain Indicator", ('Default', 'Yes', 'No'))
96
+
97
+ temperature = st.slider('Change in Temperature', min_value=-10, max_value=+10, value=0, step=0.25)
98
+
99
+ datepicker = st.date_input("Start of Forecast", datetime.date(2022, 12, 24), min_value=datetime.date(2022, 6, 26) + datetime.timedelta(days = 35), max_value=datetime.date(2023, 6, 26) - datetime.timedelta(days = 30))
100
+
101
+ arr = np.random.normal(1, 1, size=100)
102
+ fig, ax = plt.subplots()
103
+ ax.hist(arr, bins=20)
104
+
105
+ st.pyplot(fig)
106
+
107
+ st.button("Forecast Sales", type="primary") #on_click=None,
108
+
109
+ # %%
110
+ preds = raw_preds_to_df(out, quantiles = None)
111
+
112
+ preds = preds.merge(data_selected[['time_idx','Group','Branch','sales','weight','Date','MTXWTH_Day_precip','MTXWTH_Temp_max','MTXWTH_Temp_min']],how='left',left_on=['pred_idx','Group','Branch'],right_on=['time_idx','Group','Branch'])
113
+ preds.rename(columns={'time_idx_x':'time_idx'},inplace=True)
114
+ preds.drop(columns=['time_idx_y'],inplace=True)
115
+
116
+
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/parameters.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06caaede2baeaa36c308e46ee74a2898141161193d0426c577e3f7029104db10
3
+ size 17761
data/test_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6d1cf5ab9ad31de916030c795598d13ba388f95bbde2a3b295088666fb65ac7
3
+ size 31347323
model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/tft_check.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6694f37ecd5da5795eb1b0320fa96dda374fe331b05d9d5e2d0a49001fc2f9ed
3
+ size 5176944
requirements.txt ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ alembic==1.9.2
5
+ asttokens==2.2.1
6
+ async-timeout==4.0.2
7
+ attrs==22.2.0
8
+ autopage==0.5.1
9
+ backcall==0.2.0
10
+ cachetools==5.2.1
11
+ certifi==2022.12.7
12
+ charset-normalizer==2.1.1
13
+ cliff==4.1.0
14
+ cmaes==0.9.1
15
+ cmd2==2.4.2
16
+ colorlog==6.7.0
17
+ comm==0.1.2
18
+ contourpy==1.0.7
19
+ cycler==0.11.0
20
+ debugpy==1.6.6
21
+ decorator==5.1.1
22
+ executing==1.2.0
23
+ fonttools==4.38.0
24
+ frozenlist==1.3.3
25
+ fsspec==2022.11.0
26
+ future==0.18.3
27
+ google-auth==2.16.0
28
+ google-auth-oauthlib==0.4.6
29
+ greenlet==2.0.1
30
+ #grpcio==1.51.1
31
+ idna==3.4
32
+ importlib-metadata==6.0.0
33
+ importlib-resources==5.10.2
34
+ ipykernel==6.21.2
35
+ ipython==8.10.0
36
+ jedi==0.18.2
37
+ joblib==1.2.0
38
+ jupyter_client==8.0.3
39
+ jupyter_core==5.2.0
40
+ kiwisolver==1.4.4
41
+ lightning-utilities==0.5.0
42
+ lxml==4.9.2
43
+ Mako==1.2.4
44
+ Markdown==3.4.1
45
+ MarkupSafe==2.1.1
46
+ matplotlib==3.6.3
47
+ matplotlib-inline==0.1.6
48
+ multidict==6.0.4
49
+ nest-asyncio==1.5.6
50
+ numpy==1.23.5
51
+ oauthlib==3.2.2
52
+ optuna==2.10.1
53
+ packaging==23.0
54
+ pandas==1.5.2
55
+ parso==0.8.3
56
+ patsy==0.5.3
57
+ pbr==5.11.1
58
+ pexpect==4.8.0
59
+ pickleshare==0.7.5
60
+ Pillow==9.4.0
61
+ platformdirs==3.0.0
62
+ prettytable==3.6.0
63
+ prompt-toolkit==3.0.37
64
+ protobuf==3.20.1
65
+ psutil==5.9.4
66
+ ptyprocess==0.7.0
67
+ pure-eval==0.2.2
68
+ pyasn1==0.4.8
69
+ pyasn1-modules==0.2.8
70
+ pyDeprecate==0.3.1
71
+ Pygments==2.14.0
72
+ pyparsing==3.0.9
73
+ pyperclip==1.8.2
74
+ python-dateutil==2.8.2
75
+ pytorch-forecasting==0.10.3
76
+ pytorch-lightning==1.9.0
77
+ pytz==2022.7.1
78
+ PyYAML==6.0
79
+ pyzmq==25.0.0
80
+ requests==2.28.2
81
+ requests-futures==1.0.0
82
+ requests-oauthlib==1.3.1
83
+ rsa==4.9
84
+ scikit-learn==1.1.3
85
+ scipy==1.10.0
86
+ six==1.16.0
87
+ SQLAlchemy==1.4.46
88
+ stack-data==0.6.2
89
+ statsmodels==0.13.5
90
+ stevedore==4.1.1
91
+ tensorboard==2.11.2
92
+ tensorboard-data-server==0.6.1
93
+ tensorboard-plugin-wit==1.8.1
94
+ tensorboardX==2.5.1
95
+ threadpoolctl==3.1.0
96
+ torch==1.10.2
97
+ torchaudio==0.10.2
98
+ torchmetrics==0.11.0
99
+ torchvision==0.11.3
100
+ tornado==6.2
101
+ tqdm==4.64.1
102
+ traitlets==5.9.0
103
+ typing_extensions==4.4.0
104
+ urllib3==1.26.14
105
+ wcwidth==0.2.6
106
+ Werkzeug==2.2.2
107
+ yahooquery==2.3.1
108
+ yarl==1.8.2
109
+ zipp==3.11.0