# Hugging Face Space status: Running on T4
import io
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from PIL import Image
# ----------------------------
# Helper functions and precomputed forecast data
# ----------------------------
# Fixed seed for reproducibility. Nothing visible in this file draws random
# numbers, so this only matters if other code does.
torch.manual_seed(42)

# Precomputed quantile forecasts, shape = (n_series, pred_len, n_q).
# Presumably aligned row-for-row with the bundled stock presets -- TODO confirm.
# map_location="cpu" lets a tensor that was saved from a GPU session load on a
# CPU-only host, and keeps the later `.numpy()` conversions valid.
_forecast_tensor = torch.load("stocks_data_forecast.pt", map_location="cpu")
def model_forecast(input_data):
    """Return the precomputed quantile forecasts.

    Stand-in for a real model call: ``input_data`` is accepted so the call
    site looks like a model invocation, but it is never read. Every call
    returns the same module-level ``_forecast_tensor`` object.
    """
    precomputed = _forecast_tensor
    return precomputed
def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
    """Build an interactive Plotly figure for one series and its forecasts.

    Args:
        timeseries: Sequence of observed values, plotted at x = 0 .. len-1.
        quantile_predictions: 2-D array-like indexed as ``[step, column]``;
            every column is drawn as its own forecast line.
        timeseries_name: Label used in the title and in every legend entry.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    n_obs = len(timeseries)
    n_steps = len(quantile_predictions)
    figure = go.Figure()

    # Observed history.
    figure.add_trace(
        go.Scatter(
            x=list(range(n_obs)),
            y=timeseries,
            mode='lines+markers',
            name=f"{timeseries_name} - Given Data",
            line=dict(width=2),
        )
    )

    # Forecast lines start at the last observed index so they visually
    # continue from the end of the history.
    forecast_x = list(range(n_obs - 1, n_obs - 1 + n_steps))
    for q_idx in range(quantile_predictions.shape[1]):
        figure.add_trace(
            go.Scatter(
                x=forecast_x,
                y=quantile_predictions[:, q_idx],
                mode='lines',
                name=f"{timeseries_name} - Quantile {q_idx+1}",
                opacity=0.8,
            )
        )

    title_font = dict(size=16, family="Arial", color="#000")
    figure.update_layout(
        title=dict(text=f"Timeseries: {timeseries_name}", x=0.5, font=title_font),
        xaxis_title="Time",
        yaxis_title="Value",
        hovermode='x unified',
    )
    return figure
def load_table(file_path):
    """Read a tabular file into a DataFrame, choosing the reader by extension.

    Args:
        file_path: Path to the file, as a ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``). The extension is matched
            case-insensitively.

    Returns:
        The ``pandas.DataFrame`` produced by the matching pandas reader with
        its default options. Note that for CSV and Excel this means the first
        row is consumed as the column header.

    Raises:
        ValueError: If the extension is not csv, xls, xlsx, or parquet.
    """
    readers = {
        ".csv": pd.read_csv,
        ".xls": pd.read_excel,
        ".xlsx": pd.read_excel,
        ".parquet": pd.read_parquet,
    }
    # os.fspath accepts both str and PathLike; splitext looks only at the
    # final path component, so dots in directory names are ignored.
    ext = os.path.splitext(os.fspath(file_path))[1].lower()
    reader = readers.get(ext)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)
def extract_names_and_update(file, preset_filename):
    """Read the chosen table and refresh the series-name checkbox list.

    The uploaded file takes precedence over the preset. Names come from the
    first column when it is an object column that is not entirely numeric
    strings. Otherwise rows are labelled "Series 0", "Series 1", ...

    Args:
        file: Gradio upload object (its ``.name`` is the file path) or None.
        preset_filename: Selected preset path. It may be empty, None, or the
            "-- No preset selected --" placeholder.

    Returns:
        ``(checkbox_update, names)``, where every name is both offered and
        pre-selected. Returns ``(empty update, [])`` when nothing is selected
        or the file cannot be read.
    """
    empty = gr.update(choices=[], value=[]), []
    try:
        if file is not None:
            df = load_table(file.name)
        elif not preset_filename or preset_filename == "-- No preset selected --":
            # The placeholder is not a file; handle it explicitly instead of
            # letting load_table fail on it.
            return empty
        else:
            df = load_table(preset_filename)

        first_col = df.iloc[:, 0] if df.shape[1] > 0 else None
        if (
            first_col is not None
            and first_col.dtype == object
            and not first_col.str.isnumeric().all()
        ):
            names = first_col.tolist()
        else:
            names = [f"Series {i}" for i in range(len(df))]
        return gr.update(choices=names, value=names), names
    except Exception:
        # Best-effort UI refresh: clear the list rather than crash the
        # callback, but keep the traceback in the server logs.
        logging.getLogger(__name__).exception("Could not read series names")
        return empty
def filter_names(search_term, all_names):
    """Narrow the checkbox list to names containing ``search_term``.

    Matching is a case-insensitive substring test on ``str(name)``. Every
    matching name is both offered and selected. An empty search term restores
    the full list, and an empty ``all_names`` clears the checkbox group.
    """
    if not all_names:
        return gr.update(choices=[], value=[])
    if search_term:
        needle = search_term.lower()
        matches = list(filter(lambda name: needle in str(name).lower(), all_names))
    else:
        matches = all_names
    return gr.update(choices=matches, value=matches)
def check_all(names_list):
    """Select every name in ``names_list`` (the full, unfiltered name list).

    NOTE(review): this also selects names that an active search filter has
    hidden from the visible choices. Confirm that is the intended behavior.
    """
    selection = names_list
    return gr.update(value=selection)
def uncheck_all(_):
    """Clear the checkbox selection. The single argument is ignored."""
    cleared = []
    return gr.update(value=cleared)
def display_filtered_forecast(file, preset_filename, selected_names):
    """Plot the history and quantile forecasts of the selected series.

    The uploaded file takes precedence over the preset. Series names are
    detected the same way as in ``extract_names_and_update``: a first column
    of non-numeric text supplies the names and the remaining columns are the
    data. Otherwise every column is data and rows are labelled
    "Series 0", "Series 1", ...

    Args:
        file: Gradio upload object (its ``.name`` is the file path) or None.
        preset_filename: Selected preset path, or None / the
            "-- No preset selected --" placeholder.
        selected_names: Names ticked in the checkbox group. None is treated
            as an empty selection.

    Returns:
        ``(figure, "")`` on success, otherwise ``(None, message)`` explaining
        why nothing was plotted.
    """
    try:
        # If no file uploaded and no valid preset chosen, return early.
        if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
            return None, "No file selected."

        df = load_table(file.name if file is not None else preset_filename)

        if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
            all_names = df.iloc[:, 0].tolist()
            data_only = df.iloc[:, 1:].astype(float)
        else:
            all_names = [f"Series {i}" for i in range(len(df))]
            data_only = df.astype(float)

        # A set keeps the per-row selection test O(1).
        wanted = set(selected_names or [])
        mask = [name in wanted for name in all_names]
        if not any(mask):
            return None, "No timeseries chosen to plot."

        # _forecast_tensor is precomputed with one entry per row. It can only
        # be matched against a table with exactly that many series; otherwise
        # the boolean mask below would fail with a cryptic IndexError.
        n_forecasts = _forecast_tensor.shape[0]
        if len(mask) != n_forecasts:
            return None, (
                f"Forecasts are available for {n_forecasts} series, "
                f"but the selected data has {len(mask)}."
            )

        filtered_data = data_only.iloc[mask, :].to_numpy()
        filtered_names = [name for name, keep in zip(all_names, mask) if keep]
        # Use an explicit bool-tensor mask. .cpu() keeps .numpy() valid even if
        # the tensor was loaded onto a GPU.
        forecasts = _forecast_tensor[torch.tensor(mask)].detach().cpu().numpy()

        # All selected series are overlaid on a single shared figure.
        fig = go.Figure()
        for series_name, ts, qp in zip(filtered_names, filtered_data, forecasts):
            n_obs = len(ts)
            # Historical data
            fig.add_trace(go.Scatter(
                x=list(range(n_obs)),
                y=ts.tolist(),
                mode='lines+markers',
                name=f"{series_name} - Given Data",
            ))
            # Quantile forecasts start at the last observed index so they
            # join the end of the history.
            x_pred = list(range(n_obs - 1, n_obs - 1 + qp.shape[0]))
            for i in range(qp.shape[1]):
                fig.add_trace(go.Scatter(
                    x=x_pred,
                    y=qp[:, i],
                    mode='lines',
                    name=f"{series_name} - Quantile {i+1}",
                    opacity=0.6,
                ))

        fig.update_layout(
            title=dict(text="Forecasts for Selected Timeseries", x=0.5, font=dict(size=16, family="Arial", color="#000")),
            xaxis_title="Time",
            yaxis_title="Value",
            hovermode='x unified',
        )
        return fig, ""
    except Exception as e:
        # Show the actual cause in the UI. load_table's own message already
        # lists the supported formats when that is the problem. Keep the
        # traceback in the server logs.
        logging.getLogger(__name__).exception("Forecast plotting failed")
        return None, f"Error: {e}"
# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------
# Components are placed in the order they are created inside the `with`
# blocks, so the statement order below defines the on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
    with gr.Row():
        # Left column: controls
        with gr.Column():
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            # The first entry is a placeholder meaning "no preset".
            preset_choices = ["-- No preset selected --", "stocks_data_noindex.csv", "stocks_data.csv"]
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=preset_choices,
                value="-- No preset selected --"
            )
            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )
            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")
            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)
            with gr.Row():
                gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
                gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
        # Right column: interactive plot + instructions
        with gr.Column():
            gr.Markdown("## Forecast Plot")
            plot_output = gr.Plot()
            # Instruction text below plot
            gr.Markdown("## Instructions")
            # The header bullet reflects how load_table reads files:
            # pd.read_csv / pd.read_excel use the first row as column headers.
            gr.Markdown(
                """
                **How to format your data:**
                - Your file must be a table (CSV, XLS, XLSX, or Parquet).
                - For CSV and Excel files, the **first row is read as column headers**, not as a series. Put a header line (e.g. column labels like `t0, t1, t2, ...`) above your data.
                - **One row per timeseries.** Each row below the header is treated as a separate series.
                - If you want to **name** each series, put the name as the first value in **every** row:
                  - Example (CSV):
                    `AAPL, 120.5, 121.0, 119.8, ...`
                    `AMZN, 3300.0, 3310.5, 3295.2, ...`
                  - In that case, the first column is not numeric, so it will be used as the series name.
                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
                  - Example:
                    `120.5, 121.0, 119.8, ...`
                    `3300.0, 3310.5, 3295.2, ...`
                  - Then every row will be auto-named “Series 0, Series 1, …” in order.
                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
                - The rest of the columns (after the optional name) must be numeric data points for that series.
                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
                - Use “Check All” / “Uncheck All” to quickly select or deselect every series.
                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for 64 steps ahead).
                """
            )
    # Full (unfiltered) list of series names for the currently loaded table.
    names_state = gr.State([])
    # When file or preset changes, update names
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=[filter_checkbox]
    )
    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[plot_output, errbox]
    )
# Launches the app at import time (there is no __main__ guard).
demo.launch()
# '''
# 1. Prepared datasets
# 2. Plots of different quantiles (different colors)
# 3. Filters for plots...
# 4. Different input options
# 5. README.md in there (in UI) (contact us for fine-tuning)
# 6. Requirements for dimensions
# 7. Logos of NX-AI, xLSTM and TiRex
# 8. *Range of prediction length customizable
# 9. *Multivariate data (x_t is vector)
# '''