"""Streamlit app for time series forecasting with NeuralForecast models.

The page offers two modes on the same dataset:

* "Transfer Learning Forecasting" predicts with models previously saved under
  ``./M4/<MODEL>/<frequency>``.
* "Dynamic Forecasting" trains a fresh model on the dataset and predicts.
"""
import time

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from neuralforecast.core import NeuralForecast
from neuralforecast.losses.pytorch import HuberMQLoss
from neuralforecast.models import LSTM, NHITS, TFT, TimesNet
from neuralforecast.utils import AirPassengersDF

# Paths for saving models
nhits_paths = {
    'D': './M4/NHITS/daily',
    'M': './M4/NHITS/monthly',
    'H': './M4/NHITS/hourly',
    'W': './M4/NHITS/weekly',
    'Y': './M4/NHITS/yearly',
}
timesnet_paths = {
    'D': './M4/TimesNet/daily',
    'M': './M4/TimesNet/monthly',
    'H': './M4/TimesNet/hourly',
    'W': './M4/TimesNet/weekly',
    'Y': './M4/TimesNet/yearly',
}
lstm_paths = {
    'D': './M4/LSTM/daily',
    'M': './M4/LSTM/monthly',
    'H': './M4/LSTM/hourly',
    'W': './M4/LSTM/weekly',
    'Y': './M4/LSTM/yearly',
}
tft_paths = {
    'D': './M4/TFT/daily',
    'M': './M4/TFT/monthly',
    'H': './M4/TFT/hourly',
    'W': './M4/TFT/weekly',
    'Y': './M4/TFT/yearly',
}

# pandas spells the same sampling rate differently depending on how the data
# is anchored and on the pandas version ('M'/'ME'/'MS', 'H'/'h',
# 'Y'/'YE'/'YS'/'A'/'AS'). Map the un-anchored base alias onto the keys used
# by the *_paths dictionaries above.
_FREQ_ALIASES = {
    'D': 'D',
    'H': 'H', 'h': 'H',
    'M': 'M', 'ME': 'M', 'MS': 'M',
    'W': 'W',
    'Y': 'Y', 'YE': 'Y', 'YS': 'Y', 'A': 'Y', 'AS': 'Y',
}

# Series identifier added to data that does not carry its own 'unique_id'.
_DEFAULT_SERIES_ID = 'series_0'


@st.cache_resource
def load_model(path, freq):
    """Load a saved NeuralForecast object from ``path``.

    ``freq`` is not used for loading; it only takes part in the
    ``st.cache_resource`` cache key.
    """
    nf = NeuralForecast.load(path=path)
    return nf


nhits_models = {freq: load_model(path, freq) for freq, path in nhits_paths.items()}
timesnet_models = {freq: load_model(path, freq) for freq, path in timesnet_paths.items()}
lstm_models = {freq: load_model(path, freq) for freq, path in lstm_paths.items()}
tft_models = {freq: load_model(path, freq) for freq, path in tft_paths.items()}


def _prepare_frame(df):
    """Return a cleaned copy of ``df`` without touching the caller's frame.

    'ds' is parsed to datetimes, rows whose timestamp cannot be parsed are
    dropped, and a constant 'unique_id' column is added when absent so that
    single-series files containing only 'ds' and 'y' are accepted.
    """
    frame = df.copy()
    frame['ds'] = pd.to_datetime(frame['ds'], errors='coerce')
    frame = frame.dropna(subset=['ds'])
    if 'unique_id' not in frame.columns:
        frame['unique_id'] = _DEFAULT_SERIES_ID
    return frame


def _normalize_frequency(freq):
    """Map a pandas frequency alias to 'D', 'M', 'H', 'W' or 'Y'.

    Anchors such as '-SUN' or '-DEC' are ignored. Returns ``None`` when the
    alias is not a string or has no matching key.
    """
    if not isinstance(freq, str):
        return None
    base = freq.split('-', 1)[0]
    return _FREQ_ALIASES.get(base)


def generate_forecast(model, df):
    """Return ``model.predict(df=df)``.

    ``model`` must expose ``predict(df=...)``; in this file it is always a
    NeuralForecast object that has been loaded from disk or fitted.
    """
    forecast_df = model.predict(df=df)
    return forecast_df


def determine_frequency(df):
    """Infer the pandas frequency alias of the 'ds' column.

    The input frame is not modified. When a 'unique_id' column is present the
    frequency is inferred from the rows of the first series only, because the
    timestamps of several stacked series repeat.

    Returns:
        The alias reported by ``pd.infer_freq``, or ``None`` when fewer than
        three timestamps are available or no regular frequency is found.
    """
    stamps = pd.to_datetime(df['ds'])
    if 'unique_id' in df.columns and len(df) > 0:
        first_id = df['unique_id'].iloc[0]
        stamps = stamps[df['unique_id'] == first_id]
    # pd.infer_freq raises ValueError when given fewer than three dates.
    if len(stamps) < 3:
        return None
    freq = pd.infer_freq(pd.DatetimeIndex(stamps))
    return freq


def plot_forecasts(forecast_df, train_df, title):
    """Plot history, median forecast and the 90% band, and render it.

    Args:
        forecast_df: Frame with a 'ds' column and a column whose name contains
            'median'; columns containing 'lo-90' / 'hi-90' are optional.
        train_df: Frame with 'ds' and 'y' columns.
        title: Plot title.

    Raises:
        KeyError: If no column name contains 'median'.
    """
    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    try:
        plot_df = pd.concat([train_df, forecast_df]).set_index('ds')

        historical_col = 'y'
        forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
        lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
        hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)

        if forecast_col is None:
            raise KeyError("No forecast column found in the data.")

        plot_df[[historical_col, forecast_col]].plot(
            ax=ax, linewidth=2, label=['Historical', 'Forecast']
        )

        if lo_col and hi_col:
            ax.fill_between(
                plot_df.index,
                plot_df[lo_col],
                plot_df[hi_col],
                color='blue',
                alpha=0.3,
                label='90% Confidence Interval',
            )

        ax.set_title(title, fontsize=22)
        ax.set_ylabel('Value', fontsize=20)
        ax.set_xlabel('Timestamp [t]', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid()

        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # new figure would be retained on each Streamlit rerun.
        plt.close(fig)


def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the pre-trained model of each family matching ``freq``.

    Args:
        freq: pandas frequency alias, e.g. 'D', 'MS', 'W-MON', 'YE-DEC'.
        nhits_models, timesnet_models, lstm_models, tft_models: Dictionaries
            keyed by 'D', 'M', 'H', 'W' and 'Y'.

    Returns:
        Tuple ``(nhits, timesnet, lstm, tft)``.

    Raises:
        ValueError: If ``freq`` maps to none of the supported keys.
    """
    key = _normalize_frequency(freq)
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]


def select_model(horizon, model_type, max_steps=200):
    """Build an untrained model of the requested family.

    Args:
        horizon: Number of future steps to forecast.
        model_type: One of 'NHITS', 'TimesNet', 'LSTM', 'TFT'.
        max_steps: Training steps passed to the model.

    Raises:
        ValueError: If ``model_type`` is not one of the four names above.
    """
    if model_type == 'NHITS':
        return NHITS(
            input_size=5 * horizon,
            h=horizon,
            max_steps=max_steps,
            stack_types=3 * ['identity'],
            n_blocks=3 * [1],
            mlp_units=[[256, 256] for _ in range(3)],
            n_pool_kernel_size=3 * [1],
            batch_size=32,
            scaler_type='standard',
            n_freq_downsample=[12, 4, 1],
            loss=HuberMQLoss(level=[90]),
        )
    elif model_type == 'TimesNet':
        return TimesNet(
            h=horizon,
            input_size=horizon * 5,
            hidden_size=16,
            conv_hidden_size=32,
            loss=HuberMQLoss(level=[90]),
            scaler_type='standard',
            learning_rate=1e-3,
            max_steps=max_steps,
            val_check_steps=200,
            valid_batch_size=64,
            windows_batch_size=128,
            inference_windows_batch_size=512,
        )
    elif model_type == 'LSTM':
        return LSTM(
            h=horizon,
            input_size=horizon * 5,
            loss=HuberMQLoss(level=[90]),
            scaler_type='standard',
            encoder_n_layers=2,
            encoder_hidden_size=64,
            context_size=10,
            decoder_hidden_size=64,
            decoder_layers=2,
            max_steps=max_steps,
        )
    elif model_type == 'TFT':
        return TFT(
            h=horizon,
            input_size=horizon,
            hidden_size=16,
            loss=HuberMQLoss(level=[90]),
            learning_rate=0.005,
            scaler_type='standard',
            windows_batch_size=128,
            max_steps=max_steps,
            val_check_steps=200,
            valid_batch_size=64,
            enable_progress_bar=True,
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")


def forecast_time_series(df, model_type, freq, horizon, max_steps=200):
    """Train a fresh model on ``df``, forecast ``horizon`` steps and plot.

    The caller's frame is left untouched. Existing timestamps are kept;
    synthetic ones starting at 1970-01-01 are generated only when ``df`` has
    no 'ds' column.

    Args:
        df: Frame with a 'y' column and, normally, a 'ds' column.
        model_type: One of 'NHITS', 'TimesNet', 'LSTM', 'TFT'.
        freq: pandas frequency alias of the data; inferred from 'ds' when
            falsy.
        horizon: Number of future steps to forecast.
        max_steps: Training steps for the model.

    Raises:
        ValueError: If the frequency is neither given nor inferable, or if
            ``model_type`` is unsupported.
    """
    start_time = time.time()  # Start timing

    if 'ds' not in df.columns:
        if not freq:
            raise ValueError(
                "A frequency is required when the data has no 'ds' column."
            )
        df = df.copy()
        df['ds'] = pd.date_range(start='1970-01-01', periods=len(df), freq=freq)

    df = _prepare_frame(df)

    if not freq:
        freq = determine_frequency(df)
        st.write(f"Determined frequency: {freq}")
    if not freq:
        raise ValueError("Could not determine the frequency of the time series.")

    model = select_model(horizon, model_type, max_steps)
    # select_model returns a bare, untrained model; it has to be wrapped in a
    # NeuralForecast object and fitted before it can predict.
    nf = NeuralForecast(models=[model], freq=freq)

    st.write(f"Generating forecast using {model_type} model...")
    nf.fit(df=df)
    forecast_df = generate_forecast(nf, df)

    plot_forecasts(forecast_df, df, f'{model_type} Forecast Comparison')

    end_time = time.time()  # End timing
    time_taken = end_time - start_time
    st.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")


# Streamlit App
st.title("Dynamic and Automatic Time Series Forecasting")

# Upload dataset
uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
if uploaded_file:
    df = pd.read_csv(uploaded_file)
else:
    st.warning("Using default data")
    df = AirPassengersDF.copy()

# Everything below reads the 'ds' and 'y' columns, so fail early and readably.
missing_columns = [col for col in ('ds', 'y') if col not in df.columns]
if missing_columns:
    st.error(f"Data is missing required column(s): {', '.join(missing_columns)}")
    st.stop()
df = _prepare_frame(df)

# Model selection and forecasting
st.subheader("Transfer Learning Forecasting")
model_choice = st.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
horizon = st.slider("Forecast horizon", 1, 100, 10)

# Determine frequency of data
frequency = determine_frequency(df)
st.write(f"Detected frequency: {frequency}")

# Load pre-trained models
try:
    nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(
        frequency, nhits_models, timesnet_models, lstm_models, tft_models
    )
except ValueError as err:
    st.error(str(err))
    st.stop()

pretrained_by_name = {
    "NHITS": nhits_model,
    "TimesNet": timesnet_model,
    "LSTM": lstm_model,
    "TFT": tft_model,
}

forecast_results = {}

start_time = time.time()  # Start timing
forecast_results[model_choice] = generate_forecast(pretrained_by_name[model_choice], df)

for model_name, forecast_df in forecast_results.items():
    plot_forecasts(forecast_df, df, f'{model_name} Forecast')

end_time = time.time()  # End timing
time_taken = end_time - start_time
st.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")

# Dynamic forecasting
st.subheader("Dynamic Forecasting")
dynamic_model_choice = st.selectbox(
    "Select model for dynamic forecasting",
    ["NHITS", "TimesNet", "LSTM", "TFT"],
    key="dynamic_model_choice",
)
dynamic_horizon = st.slider(
    "Forecast horizon for dynamic forecasting", 1, 100, 10, key="dynamic_horizon"
)

forecast_time_series(df, dynamic_model_choice, frequency, dynamic_horizon)