diff --git "a/Untitled.ipynb" "b/Untitled.ipynb"
deleted file mode 100644--- "a/Untitled.ipynb"
+++ /dev/null
@@ -1,1369 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 26,
- "id": "702c8313-4ccd-4478-a648-1087978f8af5",
- "metadata": {},
- "outputs": [],
- "source": [
- "import pickle\n",
- "import pandas as pd\n",
- "import matplotlib.pyplot as plt"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "06393522-6a98-46d1-98ec-adc8a889be25",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load the test set, then keep only branch 15 and article groups 6, 7, 4 and 1\n",
- "df = pd.read_pickle('data/test_data.pkl')\n",
- "is_branch_15 = df[\"Branch\"] == \"15\"\n",
- "in_selected_groups = df[\"Group\"].isin([\"6\", \"7\", \"4\", \"1\"])\n",
- "df = df.loc[is_branch_15 & in_selected_groups]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "c67e2a79-d310-4832-a8e5-f16fecb6b9be",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " sales | \n",
- " DayInYear | \n",
- " time_idx | \n",
- " Wahl | \n",
- " Baustelle | \n",
- " MontagLangesWE | \n",
- " FreitagLangesWE | \n",
- " nosale | \n",
- " holiday | \n",
- " AufSommerzeit | \n",
- " ... | \n",
- " Branch | \n",
- " Weekday | \n",
- " Date | \n",
- " MTXWTH_Day_precip | \n",
- " MTXWTH_Temp_max | \n",
- " MTXWTH_Temp_min | \n",
- " Start | \n",
- " End | \n",
- " ShiftLength | \n",
- " weight | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 270300 | \n",
- " 1600.9030 | \n",
- " 177 | \n",
- " 2369 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 6 | \n",
- " 2022-06-26 | \n",
- " 0.0 | \n",
- " 28.52 | \n",
- " 17.47 | \n",
- " 7.0 | \n",
- " 10.983333 | \n",
- " 240.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270301 | \n",
- " 1811.1958 | \n",
- " 178 | \n",
- " 2370 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 0 | \n",
- " 2022-06-27 | \n",
- " 0.0 | \n",
- " 25.75 | \n",
- " 16.70 | \n",
- " 6.0 | \n",
- " 13.983333 | \n",
- " 480.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270302 | \n",
- " 1784.2916 | \n",
- " 179 | \n",
- " 2371 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 1 | \n",
- " 2022-06-28 | \n",
- " 0.0 | \n",
- " 23.57 | \n",
- " 14.17 | \n",
- " 6.0 | \n",
- " 13.983333 | \n",
- " 480.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270303 | \n",
- " 1757.3488 | \n",
- " 180 | \n",
- " 2372 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 2 | \n",
- " 2022-06-29 | \n",
- " 0.0 | \n",
- " 26.81 | \n",
- " 13.09 | \n",
- " 6.0 | \n",
- " 13.983333 | \n",
- " 480.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270304 | \n",
- " 1741.0982 | \n",
- " 181 | \n",
- " 2373 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 3 | \n",
- " 2022-06-30 | \n",
- " 0.0 | \n",
- " 27.26 | \n",
- " 15.00 | \n",
- " 6.0 | \n",
- " 13.983333 | \n",
- " 480.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " 287065 | \n",
- " 1643.1700 | \n",
- " 173 | \n",
- " 2730 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 3 | \n",
- " 2023-06-22 | \n",
- " 0.0 | \n",
- " 26.93 | \n",
- " 13.06 | \n",
- " 6.0 | \n",
- " 16.983333 | \n",
- " 660.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 287066 | \n",
- " 1597.3518 | \n",
- " 174 | \n",
- " 2731 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 4 | \n",
- " 2023-06-23 | \n",
- " 1.0 | \n",
- " 23.99 | \n",
- " 15.98 | \n",
- " 6.0 | \n",
- " 16.983333 | \n",
- " 660.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 287067 | \n",
- " 1683.6228 | \n",
- " 175 | \n",
- " 2732 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 5 | \n",
- " 2023-06-24 | \n",
- " 0.0 | \n",
- " 25.99 | \n",
- " 12.04 | \n",
- " 6.0 | \n",
- " 15.983333 | \n",
- " 600.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 287068 | \n",
- " 1785.2180 | \n",
- " 176 | \n",
- " 2733 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 6 | \n",
- " 2023-06-25 | \n",
- " 0.0 | \n",
- " 28.99 | \n",
- " 15.02 | \n",
- " 7.0 | \n",
- " 15.983333 | \n",
- " 540.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 287069 | \n",
- " 1589.9020 | \n",
- " 177 | \n",
- " 2734 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0.0 | \n",
- " 0 | \n",
- " none | \n",
- " 0.0 | \n",
- " ... | \n",
- " 15 | \n",
- " 0 | \n",
- " 2023-06-26 | \n",
- " 0.0 | \n",
- " 27.96 | \n",
- " 17.01 | \n",
- " 6.0 | \n",
- " 16.983333 | \n",
- " 660.0 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- "
\n",
- "
1464 rows × 22 columns
\n",
- "
"
- ],
- "text/plain": [
- " sales DayInYear time_idx Wahl Baustelle MontagLangesWE \\\n",
- "270300 1600.9030 177 2369 0.0 0.0 0.0 \n",
- "270301 1811.1958 178 2370 0.0 0.0 0.0 \n",
- "270302 1784.2916 179 2371 0.0 0.0 0.0 \n",
- "270303 1757.3488 180 2372 0.0 0.0 0.0 \n",
- "270304 1741.0982 181 2373 0.0 0.0 0.0 \n",
- "... ... ... ... ... ... ... \n",
- "287065 1643.1700 173 2730 0.0 0.0 0.0 \n",
- "287066 1597.3518 174 2731 0.0 0.0 0.0 \n",
- "287067 1683.6228 175 2732 0.0 0.0 0.0 \n",
- "287068 1785.2180 176 2733 0.0 0.0 0.0 \n",
- "287069 1589.9020 177 2734 0.0 0.0 0.0 \n",
- "\n",
- " FreitagLangesWE nosale holiday AufSommerzeit ... Branch Weekday \\\n",
- "270300 0.0 0 none 0.0 ... 15 6 \n",
- "270301 0.0 0 none 0.0 ... 15 0 \n",
- "270302 0.0 0 none 0.0 ... 15 1 \n",
- "270303 0.0 0 none 0.0 ... 15 2 \n",
- "270304 0.0 0 none 0.0 ... 15 3 \n",
- "... ... ... ... ... ... ... ... \n",
- "287065 0.0 0 none 0.0 ... 15 3 \n",
- "287066 0.0 0 none 0.0 ... 15 4 \n",
- "287067 0.0 0 none 0.0 ... 15 5 \n",
- "287068 0.0 0 none 0.0 ... 15 6 \n",
- "287069 0.0 0 none 0.0 ... 15 0 \n",
- "\n",
- " Date MTXWTH_Day_precip MTXWTH_Temp_max MTXWTH_Temp_min Start \\\n",
- "270300 2022-06-26 0.0 28.52 17.47 7.0 \n",
- "270301 2022-06-27 0.0 25.75 16.70 6.0 \n",
- "270302 2022-06-28 0.0 23.57 14.17 6.0 \n",
- "270303 2022-06-29 0.0 26.81 13.09 6.0 \n",
- "270304 2022-06-30 0.0 27.26 15.00 6.0 \n",
- "... ... ... ... ... ... \n",
- "287065 2023-06-22 0.0 26.93 13.06 6.0 \n",
- "287066 2023-06-23 1.0 23.99 15.98 6.0 \n",
- "287067 2023-06-24 0.0 25.99 12.04 6.0 \n",
- "287068 2023-06-25 0.0 28.99 15.02 7.0 \n",
- "287069 2023-06-26 0.0 27.96 17.01 6.0 \n",
- "\n",
- " End ShiftLength weight \n",
- "270300 10.983333 240.0 1 \n",
- "270301 13.983333 480.0 1 \n",
- "270302 13.983333 480.0 1 \n",
- "270303 13.983333 480.0 1 \n",
- "270304 13.983333 480.0 1 \n",
- "... ... ... ... \n",
- "287065 16.983333 660.0 1 \n",
- "287066 16.983333 660.0 1 \n",
- "287067 15.983333 600.0 1 \n",
- "287068 15.983333 540.0 1 \n",
- "287069 16.983333 660.0 1 \n",
- "\n",
- "[1464 rows x 22 columns]"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "id": "20cb91bd-f9e7-4a96-9da0-58ccd678c2ca",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "270300 1600.9030\n",
- "270301 1811.1958\n",
- "270302 1784.2916\n",
- "270303 1757.3488\n",
- "270304 1741.0982\n",
- " ... \n",
- "270661 1885.6552\n",
- "270662 1974.7440\n",
- "270663 1738.3962\n",
- "270664 1741.8702\n",
- "270665 1973.2386\n",
- "Name: sales, Length: 366, dtype: float64"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df.loc[df['Group'] == '1', \"sales\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "id": "8923a52d-c601-45a5-af1f-ef245d0ed7be",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "278502 1585.0384\n",
- "278503 0.0000\n",
- "278504 0.0000\n",
- "278505 1582.0276\n",
- "278506 1521.5414\n",
- " ... \n",
- "278863 1672.7762\n",
- "278864 1598.2010\n",
- "278865 1683.6228\n",
- "278866 1660.6944\n",
- "278867 0.0000\n",
- "Name: sales, Length: 366, dtype: float64"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df.loc[df['Group'] == '4', \"sales\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "id": "41cf50d2-7144-49de-9e6e-f155b0b74a1f",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "283970 1580.6766\n",
- "283971 2168.6318\n",
- "283972 2034.7284\n",
- "283973 2147.1702\n",
- "283974 2364.6812\n",
- " ... \n",
- "284331 2212.5200\n",
- "284332 2160.4100\n",
- "284333 2113.0478\n",
- "284334 2016.5864\n",
- "284335 2106.6402\n",
- "Name: sales, Length: 366, dtype: float64"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df.loc[df['Group'] == '6', \"sales\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "id": "6e4ed391-9044-41c8-8334-28d59b104c74",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Timestamp('2022-06-26 00:00:00')"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df['Date'].min()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "37af6fed-9b27-4d7e-a7e6-1f787ebecd5c",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0.0"
- ]
- },
- "execution_count": 23,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df['sales'].min()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "id": "4081b0fd-db08-492c-bf31-a9727b3cbcc2",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "2498.546"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df['sales'].max()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "id": "2971c877-6a2f-4410-9070-9de7b04c8026",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "",
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# One panel per article group: (group key in the data, panel title, marker colour)\n",
- "panels = [\n",
- "    ('4', 'Article Group 1', 'red'),\n",
- "    ('7', 'Article Group 2', 'blue'),\n",
- "    ('1', 'Article Group 3', 'green'),\n",
- "    ('6', 'Article Group 4', 'yellow'),\n",
- "]\n",
- "\n",
- "fig, axs = plt.subplots(2, 2, figsize=(8, 6))\n",
- "\n",
- "# axs.flat walks the grid row by row: [0, 0], [0, 1], [1, 0], [1, 1]\n",
- "for ax, (group_key, title, colour) in zip(axs.flat, panels):\n",
- "    in_group = df['Group'] == group_key\n",
- "    ax.scatter(df.loc[in_group, 'Date'], df.loc[in_group, 'sales'], color=colour, marker='o')\n",
- "    ax.set_title(title)\n",
- "\n",
- "# Adjust spacing between subplots\n",
- "plt.tight_layout()\n",
- "\n",
- "# Give every panel the same axis ranges so the groups can be compared directly\n",
- "date_range = (df['Date'].min(), df['Date'].max())\n",
- "sales_range = (df['sales'].min(), df['sales'].max())\n",
- "for ax in axs.flat:\n",
- "    ax.set_xlim(*date_range)\n",
- "    ax.set_ylim(*sales_range)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "c0dfb883-aa55-45b1-81ea-12afc089d5fd",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torchvision/io/image.py:11: UserWarning: Failed to load image Python extension: dlopen(/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torchvision/image.so, 0x0006): Library not loaded: @rpath/libpng16.16.dylib\n",
- " Referenced from: <5F6B6919-410D-397C-98F2-12C5934F9DBE> /Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torchvision/image.so\n",
- " Reason: tried: '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/usr/lib/libpng16.16.dylib' (no such file, not in dyld cache)\n",
- " warn(f\"Failed to load image Python extension: {e}\")\n"
- ]
- }
- ],
- "source": [
- "## Imports\n",
- "import pickle\n",
- "import warnings\n",
- "#import streamlit as st\n",
- "from pathlib import Path\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import matplotlib.pyplot as plt\n",
- "import datetime\n",
- "\n",
- "import torch\n",
- "from torch.distributions import Normal\n",
- "from pytorch_forecasting import (\n",
- " TimeSeriesDataSet,\n",
- " TemporalFusionTransformer,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "720b4e43-5b2c-489f-9b61-779d618b5f45",
- "metadata": {},
- "outputs": [],
- "source": [
- "def raw_preds_to_df(raw, quantiles=None):\n",
- "    \"\"\"\n",
- "    Flatten raw model predictions into a long DataFrame with one row per\n",
- "    (series, forecast step, quantile).\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    raw : sequence\n",
- "        Output of ``model.predict(..., mode=\"raw\", return_x=True, return_index=True)``.\n",
- "        ``raw[0].prediction`` is read as the prediction tensor and ``raw[2]`` as the index\n",
- "        DataFrame (one row per predicted series, with a ``time_idx`` column).\n",
- "        NOTE(review): the tensor is assumed to be (n_series, decoder_length, n_quantiles) - confirm.\n",
- "    quantiles : list of float, optional\n",
- "        E.g. ``[0.1, 0.5, 0.9]``. When given, the positional quantile index in column ``q``\n",
- "        is replaced by these values; positions without an entry become NaN.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    pandas.DataFrame\n",
- "        The index columns plus ``h`` (forecast step, starting at 1), ``q`` (quantile),\n",
- "        ``pred`` (predicted value) and ``pred_idx``. ``time_idx`` is the first prediction\n",
- "        time index (one step after knowledge cutoff); ``pred_idx`` is the index of the\n",
- "        predicted date, i.e. ``time_idx + h - 1``.\n",
- "    \"\"\"\n",
- "    index = raw[2]\n",
- "    preds = raw[0].prediction\n",
- "    dec_len = preds.shape[1]\n",
- "    n_quantiles = preds.shape[-1]\n",
- "    rows_per_series = dec_len * n_quantiles\n",
- "\n",
- "    # Repeat each index row by position. Going through ``index.values`` would upcast\n",
- "    # mixed-type columns (integer time_idx next to string group ids) to object dtype.\n",
- "    row_positions = np.repeat(np.arange(len(index)), rows_per_series)\n",
- "    preds_df = index.iloc[row_positions].reset_index(drop=True)\n",
- "\n",
- "    preds_df = preds_df.assign(\n",
- "        h=np.tile(np.repeat(np.arange(1, 1 + dec_len), n_quantiles), len(index)),\n",
- "        q=np.tile(np.arange(n_quantiles), len(index) * dec_len),\n",
- "        # detach()/cpu() so tensors that track gradients or sit on an accelerator convert too\n",
- "        pred=preds.detach().cpu().flatten().numpy(),\n",
- "    )\n",
- "    if quantiles is not None:\n",
- "        preds_df['q'] = preds_df['q'].map({i: q for i, q in enumerate(quantiles)})\n",
- "\n",
- "    preds_df['pred_idx'] = preds_df['time_idx'] + preds_df['h'] - 1\n",
- "    return preds_df"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "7a0abc6e-5310-4e86-b622-a27c757a05f4",
- "metadata": {},
- "outputs": [],
- "source": [
- "def prepare_dataset(\n",
- "    parameters,\n",
- "    df,\n",
- "    rain=\"Default\",\n",
- "    temperature=0.0,\n",
- "    datepicker=datetime.date(2022, 10, 24),\n",
- "    mapping={\"Yes\": 1, \"No\": 0},\n",
- "    lookback_days=35,\n",
- "    horizon_days=30,\n",
- "):\n",
- "    \"\"\"\n",
- "    Build an inference dataloader for a date window around ``datepicker``, optionally\n",
- "    overriding the weather columns for what-if scenarios.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    parameters : passed unchanged to ``TimeSeriesDataSet.from_parameters``.\n",
- "    df : pandas.DataFrame with a datetime ``Date`` column and the weather columns\n",
- "        ``MTXWTH_Day_precip``, ``MTXWTH_Temp_min`` and ``MTXWTH_Temp_max``. Not modified.\n",
- "    rain : ``\"Default\"`` keeps the recorded precipitation; any other value is looked up\n",
- "        in ``mapping`` and written to every row of ``MTXWTH_Day_precip``.\n",
- "    temperature : offset added to both temperature columns.\n",
- "    datepicker : ``datetime.date`` the window is anchored on.\n",
- "    mapping : rain label -> precipitation value. Only read, never mutated.\n",
- "    lookback_days, horizon_days : rows with\n",
- "        ``datepicker - lookback_days < Date <= datepicker + horizon_days`` are kept.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    The dataloader from ``to_dataloader(train=False, batch_size=256, num_workers=0)``.\n",
- "\n",
- "    Raises\n",
- "    ------\n",
- "    KeyError\n",
- "        If ``rain`` is neither ``\"Default\"`` nor a key of ``mapping``.\n",
- "    \"\"\"\n",
- "    # Work on a copy so the weather overrides never leak into the caller's frame\n",
- "    df = df.copy()\n",
- "\n",
- "    if rain != \"Default\":\n",
- "        df[\"MTXWTH_Day_precip\"] = mapping[rain]\n",
- "\n",
- "    df[\"MTXWTH_Temp_min\"] = df[\"MTXWTH_Temp_min\"] + temperature\n",
- "    df[\"MTXWTH_Temp_max\"] = df[\"MTXWTH_Temp_max\"] + temperature\n",
- "\n",
- "    lowerbound = datepicker - datetime.timedelta(days=lookback_days)\n",
- "    upperbound = datepicker + datetime.timedelta(days=horizon_days)\n",
- "\n",
- "    # Lower bound exclusive, upper bound inclusive\n",
- "    dates = df[\"Date\"].dt.date\n",
- "    window = df.loc[(dates > lowerbound) & (dates <= upperbound)]\n",
- "\n",
- "    dataset = TimeSeriesDataSet.from_parameters(parameters, window)\n",
- "    return dataset.to_dataloader(train=False, batch_size=256, num_workers=0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "62785ac5-d144-486a-aa15-790e8624fcf0",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " sales DayInYear time_idx Wahl Baustelle MontagLangesWE \\\n",
- "270386 1920.2408 263 2455 0.0 0.0 0.0 \n",
- "270387 1814.0908 264 2456 0.0 0.0 0.0 \n",
- "270388 1749.4744 265 2457 0.0 0.0 0.0 \n",
- "270389 1878.0510 266 2458 0.0 0.0 0.0 \n",
- "270390 1853.9646 267 2459 0.0 0.0 0.0 \n",
- "... ... ... ... ... ... ... \n",
- "286850 1693.3114 323 2515 0.0 0.0 0.0 \n",
- "286851 1898.4318 324 2516 0.0 0.0 0.0 \n",
- "286852 1581.6030 325 2517 0.0 0.0 0.0 \n",
- "286853 1569.0580 326 2518 0.0 0.0 0.0 \n",
- "286854 0.0000 327 2519 nan nan nan \n",
- "\n",
- " FreitagLangesWE nosale holiday AufSommerzeit ... \\\n",
- "270386 0.0 0 none 0.0 ... \n",
- "270387 0.0 0 none 0.0 ... \n",
- "270388 0.0 0 none 0.0 ... \n",
- "270389 0.0 0 none 0.0 ... \n",
- "270390 0.0 0 none 0.0 ... \n",
- "... ... ... ... ... ... \n",
- "286850 0.0 0 none 0.0 ... \n",
- "286851 0.0 0 NotCondensed_Totensonntag 0.0 ... \n",
- "286852 0.0 0 none 0.0 ... \n",
- "286853 0.0 0 none 0.0 ... \n",
- "286854 nan 1 none 0.0 ... \n",
- "\n",
- " Branch Weekday Date MTXWTH_Day_precip MTXWTH_Temp_max \\\n",
- "270386 15 1 2022-09-20 0.0 16.95 \n",
- "270387 15 2 2022-09-21 0.0 17.99 \n",
- "270388 15 3 2022-09-22 0.0 17.96 \n",
- "270389 15 4 2022-09-23 0.0 17.75 \n",
- "270390 15 5 2022-09-24 0.0 14.59 \n",
- "... ... ... ... ... ... \n",
- "286850 15 5 2022-11-19 0.0 4.47 \n",
- "286851 15 6 2022-11-20 1.0 2.53 \n",
- "286852 15 0 2022-11-21 0.0 1.44 \n",
- "286853 15 1 2022-11-22 0.0 1.99 \n",
- "286854 15 2 2022-11-23 0.0 4.82 \n",
- "\n",
- " MTXWTH_Temp_min Start End ShiftLength weight \n",
- "270386 7.05 6.0 13.983333 480.0 1 \n",
- "270387 6.05 6.0 13.983333 480.0 1 \n",
- "270388 10.03 6.0 13.983333 480.0 1 \n",
- "270389 8.39 6.0 13.983333 480.0 1 \n",
- "270390 11.31 6.0 11.983333 360.0 1 \n",
- "... ... ... ... ... ... \n",
- "286850 -4.95 6.0 11.983333 360.0 1 \n",
- "286851 -2.96 7.0 10.983333 240.0 1 \n",
- "286852 -3.72 6.0 13.983333 480.0 1 \n",
- "286853 -2.58 6.0 13.983333 480.0 1 \n",
- "286854 2.40 7.0 17.983333 660.0 0 \n",
- "\n",
- "[260 rows x 22 columns]\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator StandardScaler from version 1.2.2 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
- "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
- " warnings.warn(\n"
- ]
- }
- ],
- "source": [
- "# NOTE: pickle can execute arbitrary code on load - only open files from a trusted source\n",
- "with open('data/parameters.pkl', 'rb') as params_file:\n",
- "    parameters = pickle.load(params_file)\n",
- "\n",
- "# Keep only branch 15 and article groups 6, 7, 4 and 1\n",
- "df = pd.read_pickle('data/test_data.pkl')\n",
- "is_branch_15 = df[\"Branch\"] == \"15\"\n",
- "in_selected_groups = df[\"Group\"].isin([\"6\", \"7\", \"4\", \"1\"])\n",
- "df = df.loc[is_branch_15 & in_selected_groups]\n",
- "\n",
- "# Hand over a copy so `df` itself stays untouched for the cells below\n",
- "_dataloader = prepare_dataset(parameters, df.copy())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "fda01d15-fd6a-4b87-abfd-7d9011af53e4",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n",
- " rank_zero_warn(\n",
- "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.\n",
- " rank_zero_warn(\n"
- ]
- }
- ],
- "source": [
- "# Restore the checkpoint onto the CPU and run raw-mode prediction on the prepared dataloader\n",
- "cpu = torch.device('cpu')\n",
- "_model = TemporalFusionTransformer.load_from_checkpoint('model/tft_check.ckpt', map_location=cpu)\n",
- "out = _model.predict(_dataloader, mode=\"raw\", return_x=True, return_index=True)\n",
- "\n",
- "# Long-format frame of the predictions; quantiles stay as positional indices\n",
- "preds = raw_preds_to_df(out, quantiles=None)\n",
- "\n",
- "#preds = preds[[\"pred_idx\", \"Group\", \"pred\"]]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "bdcabe69-bc3b-46c7-a17f-138ec756c1c2",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " time_idx | \n",
- " Group | \n",
- " Branch | \n",
- " h | \n",
- " q | \n",
- " pred | \n",
- " pred_idx | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 2490 | \n",
- " 1 | \n",
- " 15 | \n",
- " 1 | \n",
- " 0 | \n",
- " 1826.949707 | \n",
- " 2490 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 2490 | \n",
- " 1 | \n",
- " 15 | \n",
- " 2 | \n",
- " 0 | \n",
- " 1856.215088 | \n",
- " 2491 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 2490 | \n",
- " 1 | \n",
- " 15 | \n",
- " 3 | \n",
- " 0 | \n",
- " 1871.929688 | \n",
- " 2492 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 2490 | \n",
- " 1 | \n",
- " 15 | \n",
- " 4 | \n",
- " 0 | \n",
- " 1866.095825 | \n",
- " 2493 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 2490 | \n",
- " 1 | \n",
- " 15 | \n",
- " 5 | \n",
- " 0 | \n",
- " 1787.610840 | \n",
- " 2494 | \n",
- "
\n",
- " \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " 115 | \n",
- " 2490 | \n",
- " 7 | \n",
- " 15 | \n",
- " 26 | \n",
- " 0 | \n",
- " 1655.261475 | \n",
- " 2515 | \n",
- "
\n",
- " \n",
- " 116 | \n",
- " 2490 | \n",
- " 7 | \n",
- " 15 | \n",
- " 27 | \n",
- " 0 | \n",
- " 1794.728027 | \n",
- " 2516 | \n",
- "
\n",
- " \n",
- " 117 | \n",
- " 2490 | \n",
- " 7 | \n",
- " 15 | \n",
- " 28 | \n",
- " 0 | \n",
- " 1600.507812 | \n",
- " 2517 | \n",
- "
\n",
- " \n",
- " 118 | \n",
- " 2490 | \n",
- " 7 | \n",
- " 15 | \n",
- " 29 | \n",
- " 0 | \n",
- " 1595.128540 | \n",
- " 2518 | \n",
- "
\n",
- " \n",
- " 119 | \n",
- " 2490 | \n",
- " 7 | \n",
- " 15 | \n",
- " 30 | \n",
- " 0 | \n",
- " 1557.557007 | \n",
- " 2519 | \n",
- "
\n",
- " \n",
- "
\n",
- "
120 rows × 7 columns
\n",
- "
"
- ],
- "text/plain": [
- " time_idx Group Branch h q pred pred_idx\n",
- "0 2490 1 15 1 0 1826.949707 2490\n",
- "1 2490 1 15 2 0 1856.215088 2491\n",
- "2 2490 1 15 3 0 1871.929688 2492\n",
- "3 2490 1 15 4 0 1866.095825 2493\n",
- "4 2490 1 15 5 0 1787.610840 2494\n",
- ".. ... ... ... .. .. ... ...\n",
- "115 2490 7 15 26 0 1655.261475 2515\n",
- "116 2490 7 15 27 0 1794.728027 2516\n",
- "117 2490 7 15 28 0 1600.507812 2517\n",
- "118 2490 7 15 29 0 1595.128540 2518\n",
- "119 2490 7 15 30 0 1557.557007 2519\n",
- "\n",
- "[120 rows x 7 columns]"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "preds"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "ab7c16c2-6fcc-40bd-a866-1fecd122dfc1",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " time_idx | \n",
- " sales | \n",
- " Group | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 270300 | \n",
- " 2369 | \n",
- " 1600.9030 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270301 | \n",
- " 2370 | \n",
- " 1811.1958 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270302 | \n",
- " 2371 | \n",
- " 1784.2916 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270303 | \n",
- " 2372 | \n",
- " 1757.3488 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 270304 | \n",
- " 2373 | \n",
- " 1741.0982 | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " 287065 | \n",
- " 2730 | \n",
- " 1643.1700 | \n",
- " 7 | \n",
- "
\n",
- " \n",
- " 287066 | \n",
- " 2731 | \n",
- " 1597.3518 | \n",
- " 7 | \n",
- "
\n",
- " \n",
- " 287067 | \n",
- " 2732 | \n",
- " 1683.6228 | \n",
- " 7 | \n",
- "
\n",
- " \n",
- " 287068 | \n",
- " 2733 | \n",
- " 1785.2180 | \n",
- " 7 | \n",
- "
\n",
- " \n",
- " 287069 | \n",
- " 2734 | \n",
- " 1589.9020 | \n",
- " 7 | \n",
- "
\n",
- " \n",
- "
\n",
- "
1464 rows × 3 columns
\n",
- "
"
- ],
- "text/plain": [
- " time_idx sales Group\n",
- "270300 2369 1600.9030 1\n",
- "270301 2370 1811.1958 1\n",
- "270302 2371 1784.2916 1\n",
- "270303 2372 1757.3488 1\n",
- "270304 2373 1741.0982 1\n",
- "... ... ... ...\n",
- "287065 2730 1643.1700 7\n",
- "287066 2731 1597.3518 7\n",
- "287067 2732 1683.6228 7\n",
- "287068 2733 1785.2180 7\n",
- "287069 2734 1589.9020 7\n",
- "\n",
- "[1464 rows x 3 columns]"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df[[\"time_idx\",\"sales\", \"Group\"]]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "b5812d1f-f42d-4031-b9b7-32fac2bbe425",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " sales DayInYear time_idx Wahl Baustelle MontagLangesWE \\\n",
- "0 1600.9030 177 2369 0.0 0.0 0.0 \n",
- "1 1811.1958 178 2370 0.0 0.0 0.0 \n",
- "2 1784.2916 179 2371 0.0 0.0 0.0 \n",
- "3 1757.3488 180 2372 0.0 0.0 0.0 \n",
- "4 1741.0982 181 2373 0.0 0.0 0.0 \n",
- "... ... ... ... ... ... ... \n",
- "1459 1643.1700 173 2730 0.0 0.0 0.0 \n",
- "1460 1597.3518 174 2731 0.0 0.0 0.0 \n",
- "1461 1683.6228 175 2732 0.0 0.0 0.0 \n",
- "1462 1785.2180 176 2733 0.0 0.0 0.0 \n",
- "1463 1589.9020 177 2734 0.0 0.0 0.0 \n",
- "\n",
- " FreitagLangesWE nosale holiday AufSommerzeit ... Date \\\n",
- "0 0.0 0 none 0.0 ... 2022-06-26 \n",
- "1 0.0 0 none 0.0 ... 2022-06-27 \n",
- "2 0.0 0 none 0.0 ... 2022-06-28 \n",
- "3 0.0 0 none 0.0 ... 2022-06-29 \n",
- "4 0.0 0 none 0.0 ... 2022-06-30 \n",
- "... ... ... ... ... ... ... \n",
- "1459 0.0 0 none 0.0 ... 2023-06-22 \n",
- "1460 0.0 0 none 0.0 ... 2023-06-23 \n",
- "1461 0.0 0 none 0.0 ... 2023-06-24 \n",
- "1462 0.0 0 none 0.0 ... 2023-06-25 \n",
- "1463 0.0 0 none 0.0 ... 2023-06-26 \n",
- "\n",
- " MTXWTH_Day_precip MTXWTH_Temp_max MTXWTH_Temp_min Start End \\\n",
- "0 0.0 28.52 17.47 7.0 10.983333 \n",
- "1 0.0 25.75 16.70 6.0 13.983333 \n",
- "2 0.0 23.57 14.17 6.0 13.983333 \n",
- "3 0.0 26.81 13.09 6.0 13.983333 \n",
- "4 0.0 27.26 15.00 6.0 13.983333 \n",
- "... ... ... ... ... ... \n",
- "1459 0.0 26.93 13.06 6.0 16.983333 \n",
- "1460 1.0 23.99 15.98 6.0 16.983333 \n",
- "1461 0.0 25.99 12.04 6.0 15.983333 \n",
- "1462 0.0 28.99 15.02 7.0 15.983333 \n",
- "1463 0.0 27.96 17.01 6.0 16.983333 \n",
- "\n",
- " ShiftLength weight pred_idx pred \n",
- "0 240.0 1 NaN NaN \n",
- "1 480.0 1 NaN NaN \n",
- "2 480.0 1 NaN NaN \n",
- "3 480.0 1 NaN NaN \n",
- "4 480.0 1 NaN NaN \n",
- "... ... ... ... ... \n",
- "1459 660.0 1 NaN NaN \n",
- "1460 660.0 1 NaN NaN \n",
- "1461 600.0 1 NaN NaN \n",
- "1462 540.0 1 NaN NaN \n",
- "1463 660.0 1 NaN NaN \n",
- "\n",
- "[1344 rows x 24 columns]\n"
- ]
- }
- ],
- "source": [
- "# Attach each forecast to the observed row it refers to. Only the join keys and the prediction\n",
- "# are taken from `preds`: its `time_idx` and `Branch` columns also exist in `df` and would\n",
- "# otherwise come back duplicated with _x/_y suffixes.\n",
- "new = pd.merge(\n",
- "    df,\n",
- "    preds[[\"pred_idx\", \"Group\", \"pred\"]],\n",
- "    left_on=[\"time_idx\", \"Group\"],\n",
- "    right_on=[\"pred_idx\", \"Group\"],\n",
- "    how=\"left\",\n",
- ")\n",
- "\n",
- "# Observed rows that received no forecast\n",
- "new[new[\"pred\"].isna()]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "f4577535-4d1c-4500-96df-7df83c61011e",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 1826.949707\n",
- "1 1856.215088\n",
- "2 1871.929688\n",
- "3 1866.095825\n",
- "4 1787.610840\n",
- "5 1659.851440\n",
- "6 1749.331543\n",
- "7 1878.842407\n",
- "8 1868.011841\n",
- "9 1886.265625\n",
- "10 1882.206055\n",
- "11 1794.047852\n",
- "12 1664.293335\n",
- "13 1855.808838\n",
- "14 1866.882690\n",
- "15 1880.703613\n",
- "16 1893.997437\n",
- "17 1886.795654\n",
- "18 1808.130371\n",
- "19 1699.440552\n",
- "20 1854.313721\n",
- "21 1858.726685\n",
- "22 1814.451782\n",
- "23 1863.013184\n",
- "24 1884.281982\n",
- "25 1781.215820\n",
- "26 1679.599243\n",
- "27 1804.118652\n",
- "28 1800.697144\n",
- "29 1794.221802\n",
- "Name: 1, dtype: float32\n",
- "30 1528.776611\n",
- "31 1543.648682\n",
- "32 1544.217407\n",
- "33 1533.632935\n",
- "34 1547.297485\n",
- "35 1568.251465\n",
- "36 1432.989502\n",
- "37 1559.411621\n",
- "38 1555.778076\n",
- "39 1562.257935\n",
- "40 1563.604248\n",
- "41 1542.455200\n",
- "42 1595.332031\n",
- "43 1542.713867\n",
- "44 1494.202271\n",
- "45 1556.931152\n",
- "46 1561.306152\n",
- "47 1575.157959\n",
- "48 1546.359741\n",
- "49 1665.563232\n",
- "50 1547.409302\n",
- "51 1537.488281\n",
- "52 1533.096191\n",
- "53 1506.796509\n",
- "54 1532.159912\n",
- "55 1509.309326\n",
- "56 1635.444092\n",
- "57 1473.974487\n",
- "58 1466.395752\n",
- "59 1451.778198\n",
- "Name: 4, dtype: float32\n",
- "60 2009.041016\n",
- "61 2052.925537\n",
- "62 2075.127686\n",
- "63 2073.555420\n",
- "64 1894.730835\n",
- "65 1712.648071\n",
- "66 2012.644165\n",
- "67 2078.718262\n",
- "68 2063.913574\n",
- "69 2102.091309\n",
- "70 2095.394043\n",
- "71 1910.086914\n",
- "72 1771.222168\n",
- "73 2049.833252\n",
- "74 2064.734375\n",
- "75 2089.499023\n",
- "76 2110.970947\n",
- "77 2104.520020\n",
- "78 1933.123291\n",
- "79 1808.569946\n",
- "80 2050.945801\n",
- "81 2056.272217\n",
- "82 2002.636841\n",
- "83 2061.285889\n",
- "84 2097.291260\n",
- "85 1900.992188\n",
- "86 1744.634521\n",
- "87 1993.816650\n",
- "88 1990.942261\n",
- "89 2046.452393\n",
- "Name: 6, dtype: float32\n",
- "90 1626.679443\n",
- "91 1618.315674\n",
- "92 1620.280151\n",
- "93 1627.040161\n",
- "94 1673.808228\n",
- "95 1787.030273\n",
- "96 1562.161865\n",
- "97 1642.268433\n",
- "98 1618.766113\n",
- "99 1619.909424\n",
- "100 1619.908691\n",
- "101 1666.063477\n",
- "102 1777.630371\n",
- "103 1620.351807\n",
- "104 1619.798706\n",
- "105 1613.335693\n",
- "106 1611.184326\n",
- "107 1621.388306\n",
- "108 1673.271851\n",
- "109 1810.889648\n",
- "110 1619.840698\n",
- "111 1615.521118\n",
- "112 1616.226562\n",
- "113 1593.204102\n",
- "114 1606.631958\n",
- "115 1655.261475\n",
- "116 1794.728027\n",
- "117 1600.507812\n",
- "118 1595.128540\n",
- "119 1557.557007\n",
- "Name: 7, dtype: float32\n"
- ]
- }
- ],
- "source": [
- "datepicker = datetime.date(2022, 10, 24)\n",
- "\n",
- "\n",
- "def add_dates(group):\n",
- "    \"\"\"\n",
- "    Translate forecast horizons into calendar dates.\n",
- "\n",
- "    group is a Series of horizons ``h`` (1 = first forecast step). Step ``h`` is dated\n",
- "    ``datepicker + h`` days: with the default window of prepare_dataset the data ends at\n",
- "    ``datepicker + 30`` days and the 30 forecast steps are its last 30 days.\n",
- "    NOTE(review): confirm this offset if the window or decoder length changes.\n",
- "\n",
- "    Returns a datetime Series aligned on the index of ``group``.\n",
- "    \"\"\"\n",
- "    return pd.to_timedelta(group, unit=\"D\") + pd.Timestamp(datepicker)\n",
- "\n",
- "\n",
- "preds[\"date_imputed\"] = add_dates(preds[\"h\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a8af2ca0-b8bb-4e08-8963-d463a7122033",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "TFT_HF",
- "language": "python",
- "name": "tft_hf"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}