TiRex-demo / app.py
Nikita
waiting for the review
76a84b9
raw
history blame
10.3 kB
import io
import pandas as pd
import torch
import plotly.graph_objects as go
from PIL import Image
import numpy as np
import gradio as gr
# ----------------------------
# Helper functions (logic mostly unchanged)
# ----------------------------
# Seed torch's global RNG with a fixed value. Nothing else in the visible code
# draws random numbers, so this only matters if sampling is added later.
torch.manual_seed(42)
# The forecasts are loaded once, at import time, from a file that sits next to
# the app; every request reuses this tensor instead of running a model.
# NOTE(review): torch.load deserializes with pickle machinery — only ship a
# trusted artifact under this path.
_forecast_tensor = torch.load("stocks_data_forecast.pt") # shape = (n_series, pred_len, n_q)
def model_forecast(input_data):
    """Return the forecast tensor that was loaded at import time.

    Args:
        input_data: Accepted for interface compatibility but not read; the
            same module-level tensor is returned whatever is passed in.

    Returns:
        The module-level ``_forecast_tensor`` object (not a copy).
    """
    return _forecast_tensor
def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
    """Build an interactive Plotly figure for one series and its quantile forecasts.

    Args:
        timeseries: Sequence of observed values, plotted at x = 0 .. len - 1.
        quantile_predictions: 2-D array-like indexed as ``[step, quantile]``;
            it must support ``len()``, ``.shape`` and ``[:, i]`` indexing.
        timeseries_name: Label used in the figure title and legend entries.

    Returns:
        A ``plotly.graph_objects.Figure`` holding one trace for the observed
        data followed by one trace per quantile column.
    """
    n_observed = len(timeseries)
    horizon = len(quantile_predictions)
    n_quantiles = quantile_predictions.shape[1]

    # The forecast axis starts at the index of the last observed point, so the
    # first predicted value shares its x position with the final observation.
    # NOTE(review): confirm this overlap is intended rather than starting at
    # n_observed.
    forecast_x = list(range(n_observed - 1, n_observed - 1 + horizon))

    traces = [
        # Historical data trace
        go.Scatter(
            x=list(range(n_observed)),
            y=timeseries,
            mode='lines+markers',
            name=f"{timeseries_name} - Given Data",
            line=dict(width=2),
        )
    ]
    # One line per quantile column, labelled starting from 1
    for label_no, column in enumerate(range(n_quantiles), start=1):
        traces.append(
            go.Scatter(
                x=forecast_x,
                y=quantile_predictions[:, column],
                mode='lines',
                name=f"{timeseries_name} - Quantile {label_no}",
                opacity=0.8,
            )
        )

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    centered_title = dict(
        text=f"Timeseries: {timeseries_name}",
        x=0.5,
        font=dict(size=16, family="Arial", color="#000"),
    )
    fig.update_layout(
        title=centered_title,
        xaxis_title="Time",
        yaxis_title="Value",
        hovermode='x unified',
    )
    return fig
def load_table(file_path):
    """Load a tabular file into a pandas DataFrame, choosing the reader by extension.

    Args:
        file_path: Location of the file, as a string or any ``os.PathLike``.

    Returns:
        The ``pandas.DataFrame`` produced by the matching pandas reader.

    Raises:
        ValueError: If the extension (compared case-insensitively) is not one
            of csv, xls, xlsx or parquet.
    """
    # Function-scope import so the block does not depend on a file-level change.
    from pathlib import Path

    # Path.suffix inspects only the final path component and requires a real
    # ".ext" part, so an extension-less file that happens to be called "csv"
    # is no longer mistaken for a CSV file.
    extension = Path(file_path).suffix.lower()
    readers = {
        ".csv": pd.read_csv,
        ".xls": pd.read_excel,
        ".xlsx": pd.read_excel,
        ".parquet": pd.read_parquet,
    }
    reader = readers.get(extension)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)
def extract_names_and_update(file, preset_filename):
    """Read the series names from the chosen table and refresh the checkbox list.

    An uploaded file takes precedence over the preset. The first column is
    used as the names when it has object dtype and is not made up entirely of
    numeric strings; otherwise rows are named "Series 0", "Series 1", ...

    Args:
        file: Uploaded file object exposing its path as ``.name``, or None.
        preset_filename: Path of a preset table; ignored when ``file`` is set.

    Returns:
        A tuple ``(checkbox_update, names)``: a Gradio update that sets both
        the choices and the selected values to ``names``, plus the plain list.
        On any failure, or when there is nothing to load, both are empty.
    """
    try:
        if file is not None:
            source_path = file.name
        elif preset_filename:
            source_path = preset_filename
        else:
            return gr.update(choices=[], value=[]), []
        table = load_table(source_path)

        first_column_is_label = False
        if table.shape[1] > 0:
            first_column = table.iloc[:, 0]
            first_column_is_label = (
                first_column.dtype == object
                and not first_column.str.isnumeric().all()
            )

        if first_column_is_label:
            labels = first_column.tolist()
        else:
            labels = [f"Series {row}" for row in range(len(table))]
        return gr.update(choices=labels, value=labels), labels
    except Exception:
        # Every failure clears the list. This includes the ValueError that
        # load_table raises for a preset string without a supported extension.
        return gr.update(choices=[], value=[]), []
def filter_names(search_term, all_names):
    """Narrow the checkbox list to names containing the search term.

    Matching is a case-insensitive substring test on ``str(name)``.

    Args:
        search_term: Text typed by the user; empty or None shows every name.
        all_names: Complete list of series names; may be empty or None.

    Returns:
        A Gradio update whose choices and selected values are both the
        matching names (all names when there is no search term, an empty
        list when there are no names).
    """
    if not all_names:
        shown = []
    elif not search_term:
        shown = all_names
    else:
        needle = search_term.lower()
        shown = [name for name in all_names if needle in str(name).lower()]
    return gr.update(choices=shown, value=shown)
def check_all(names_list):
    """Return a Gradio update that marks every name in ``names_list`` as selected.

    Only the selected values are set; the available choices are not touched.
    """
    select_everything = gr.update(value=names_list)
    return select_everything
def uncheck_all(_):
    """Return a Gradio update that clears the selection.

    The single argument is required by the event wiring but is not read.
    """
    clear_selection = gr.update(value=[])
    return clear_selection
def display_filtered_forecast(file, preset_filename, selected_names):
    """Plot the observed data and stored quantile forecasts for the selected series.

    An uploaded file takes precedence over the preset. Each table row is one
    series; the first column supplies the names when it has object dtype and
    is not made up entirely of numeric strings, otherwise rows are named
    "Series 0", "Series 1", ... Forecasts are taken from the module-level
    ``_forecast_tensor`` by row position.
    NOTE(review): the tensor presumably matches the preset files only — an
    uploaded table with the same row count would be paired with forecasts that
    were not computed from it; confirm this is acceptable for the demo.

    Args:
        file: Uploaded file object exposing its path as ``.name``, or None.
        preset_filename: Path of a preset table, the placeholder
            "-- No preset selected --", an empty string, or None.
        selected_names: Names of the series to draw; may be empty or None.

    Returns:
        ``(figure, "")`` on success, otherwise ``(None, message)`` where the
        message explains why nothing was plotted. Exceptions are not raised
        to the caller; they are reported through the message.
    """
    try:
        # Nothing to load: no upload and no real preset (None, "", placeholder).
        no_preset = not preset_filename or preset_filename == "-- No preset selected --"
        if file is None and no_preset:
            return None, "No file selected."

        # Load data
        df = load_table(file.name if file is not None else preset_filename)

        if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
            all_names = df.iloc[:, 0].tolist()
            data_only = df.iloc[:, 1:].astype(float)
        else:
            all_names = [f"Series {i}" for i in range(len(df))]
            data_only = df.astype(float)

        # Set lookup keeps the selection test linear in the number of rows.
        wanted = set(selected_names or [])
        chosen_rows = [row for row, name in enumerate(all_names) if name in wanted]
        if not chosen_rows:
            return None, "No timeseries chosen to plot."

        # Forecasts are matched to table rows by position, so the counts must
        # agree; say so explicitly instead of surfacing an indexing error.
        n_forecasts = _forecast_tensor.shape[0]
        if len(all_names) != n_forecasts:
            return None, (
                f"Error: file has {len(all_names)} series but forecasts are "
                f"available for {n_forecasts}."
            )

        values = data_only.values

        # All selected series are drawn as traces on one shared figure.
        fig = go.Figure()
        for row in chosen_rows:
            ts = values[row].tolist()
            qp = _forecast_tensor[row].numpy()  # indexed as [step, quantile]
            series_name = all_names[row]

            # Historical data
            fig.add_trace(go.Scatter(
                x=list(range(len(ts))),
                y=ts,
                mode='lines+markers',
                name=f"{series_name} - Given Data"
            ))

            # Quantiles; the forecast axis starts at the last observed index.
            x_pred = list(range(len(ts) - 1, len(ts) - 1 + qp.shape[0]))
            for i in range(qp.shape[1]):
                fig.add_trace(go.Scatter(
                    x=x_pred,
                    y=qp[:, i],
                    mode='lines',
                    name=f"{series_name} - Quantile {i+1}",
                    opacity=0.6
                ))

        fig.update_layout(
            title=dict(text="Forecasts for Selected Timeseries", x=0.5, font=dict(size=16, family="Arial", color="#000")),
            xaxis_title="Time",
            yaxis_title="Value",
            hovermode='x unified'
        )
        return fig, ""
    except Exception as e:
        # UI boundary: report the failure in the error box. The format hint is
        # not appended here because load_table's own message already carries
        # it and it would be misleading for unrelated errors.
        return None, f"Error: {e}"
# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------
with gr.Blocks() as demo:
    # NOTE(review): the nesting below was reconstructed from the section
    # comments ("Left column" / "Right column") — confirm it against the
    # intended layout.
    gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
    with gr.Row():
        # Left column: controls
        with gr.Column():
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            # The first entry is a placeholder meaning "no preset"; the plot
            # callback compares against this exact string.
            preset_choices = ["-- No preset selected --", "stocks_data_noindex.csv", "stocks_data.csv"]
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=preset_choices,
                value="-- No preset selected --"
            )
            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            # Starts empty; filled by extract_names_and_update once a table is chosen.
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )
            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")
            plot_button = gr.Button("▶️ Plot Forecasts")
            # Receives the message half of display_filtered_forecast's return value.
            errbox = gr.Textbox(label="Error Message", interactive=False)
            with gr.Row():
                gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
                gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
        # Right column: interactive plot + instructions
        with gr.Column():
            gr.Markdown("## Forecast Plot")
            plot_output = gr.Plot()
            # Instruction text below plot
            gr.Markdown("## Instructions")
            # The string content below is kept at column 0 exactly as in the source.
            gr.Markdown(
                """
**How to format your data:**
- Your file must be a table (CSV, XLS, XLSX, or Parquet).
- **One row per timeseries.** Each row is treated as a separate series.
- If you want to **name** each series, put the name as the first value in **every** row:
- Example (CSV):
`AAPL, 120.5, 121.0, 119.8, ...`
`AMZN, 3300.0, 3310.5, 3295.2, ...`
- In that case, the first column is not numeric, so it will be used as the series name.
- If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
- Example:
`120.5, 121.0, 119.8, ...`
`3300.0, 3310.5, 3295.2, ...`
- Then every row will be auto-named “Series 0, Series 1, …” in order.
- **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
- The rest of the columns (after the optional name) must be numeric data points for that series.
- You can filter by typing in the search box. Then check or uncheck individual names before plotting.
- Use “Check All” / “Uncheck All” to quickly select or deselect every series.
- Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for 64 steps ahead).
"""
            )
    # Holds the full list of series names of the currently loaded table.
    names_state = gr.State([])
    # When file or preset changes, update names
    # (both events pass both inputs; the callback prefers the uploaded file).
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=[filter_checkbox]
    )
    # Check All / Uncheck All
    # NOTE(review): check_all receives the full names_state, while a search
    # narrows the checkbox choices to the filtered subset — after filtering,
    # "Check All" sets values that may not be among the visible choices.
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[plot_output, errbox]
    )
# Runs at import time; there is no __main__ guard in this file.
demo.launch()
# '''
# 1. Prepared datasets
# 2. Plots of different quantiles (different colors)
# 3. Filters for plots...
# 4. Different input options
# 5. README.md in there (in UI) (contact us for fine-tuning)
# 6. Requirements for dimensions
# 7. LOGO of NX-AI and xLSTM and tirex
# 8. *Range of prediction length customizable
# 9. *Multivariate data (x_t is vector)
# '''