Spaces:
Running
Running
from pathlib import Path | |
from typing import Literal | |
import numpy as np | |
import pandas as pd | |
import plotly.colors as pcolors | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import streamlit as st | |
from plotly.subplots import make_subplots | |
from mlip_arena.models import REGISTRY | |
st.title("Stability")

# Root directory holding one sub-directory of parquet results per model family.
DATA_DIR = Path(__file__).parents[2] / "benchmarks" / "stability"

st.markdown("### Methods")
container = st.container(border=True)

# Only offer models whose heating results exist on disk.
valid_models = [
    model
    for model in REGISTRY
    if (
        DATA_DIR / REGISTRY[str(model)]["family"].lower() / f"{model}-heating.parquet"
    ).exists()
]

# Models pre-selected when the page first loads.
preferred_defaults = [
    "MACE-MP(M)",
    "CHGNet",
    "SevenNet",
    "ORBv2",
    "eqV2(OMat)",
    "M3GNet",
    "MatterSim",
    "MACE-MPA",
]

# Streamlit rejects default values that are not among the options, so drop any
# preferred model that has no results available in this checkout.
default_models = [name for name in preferred_defaults if name in valid_models]

models = container.multiselect("MLIPs", valid_models, default_models)
st.markdown("### Settings")
vis = st.container(border=True)

# Collect every list-valued attribute of plotly.colors.qualitative as a palette.
color_palettes = {
    attr: getattr(pcolors.qualitative, attr)
    for attr in dir(pcolors.qualitative)
    if isinstance(getattr(pcolors.qualitative, attr), list)
}
# ``__all__`` is a list too, but it is not a color sequence.
color_palettes.pop("__all__", None)

palette_names = list(color_palettes)
# NOTE(review): position 22 presumably corresponds to "Plotly" with current
# Plotly releases -- confirm. The index is clamped so that an installation
# exposing fewer palettes does not make the selectbox raise.
default_palette_index = min(22, len(palette_names) - 1)
palette_name = vis.selectbox(
    "Color sequence", options=palette_names, index=default_palette_index
)
color_sequence = color_palettes[palette_name]

# Nothing to plot until at least one model is selected.
if not models:
    st.stop()
def get_data(model_list, run_type: Literal["heating", "compression"]) -> pd.DataFrame:
    """Read and stack the parquet results of the given models for one run type.

    Each model's file is looked up at
    ``DATA_DIR / <family, lower-cased> / "<model>-<run_type>.parquet"``, where
    the family comes from ``REGISTRY``. Models without such a file are skipped.
    Every loaded frame gets a ``method`` column holding the model name.

    Args:
        model_list: Iterable of model names (keys of ``REGISTRY``).
        run_type: Which simulation results to read.

    Returns:
        The row-wise concatenation of all loaded frames with a fresh index, or
        an empty, column-less DataFrame when nothing could be loaded.
    """
    frames = []
    for model in model_list:
        model_name = str(model)
        family_dir = DATA_DIR / REGISTRY[model_name]["family"].lower()
        parquet_path = family_dir / f"{model}-{run_type}.parquet"
        if parquet_path.exists():
            frame = pd.read_parquet(parquet_path)
            frame["method"] = model_name
            frames.append(frame)

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
# Load both benchmark result sets for the current model selection.
df_nvt = get_data(models, run_type="heating")
df_npt = get_data(models, run_type="compression")

# Give every method found in the heating data a color, cycling through the
# chosen palette when there are more methods than colors.
method_color_mapping = {}
for position, method in enumerate(df_nvt["method"].unique()):
    method_color_mapping[method] = color_sequence[position % len(color_sequence)]
def prepare_scatter_df(df_in: pd.DataFrame, max_points: int = 20000) -> pd.DataFrame:
    """Prepare the scatter dataframe, adding a ``_marker_size`` column.

    Rows missing ``natoms`` or ``steps_per_second`` are dropped, and the result
    is randomly downsampled (fixed seed) to at most ``max_points`` rows.
    When a ``total_steps`` column with at least one valid value is present,
    marker sizes are a min-max scaling of it onto ``[5, 45]`` (missing values
    are replaced by the median first). Otherwise every marker gets size 8.

    Args:
        df_in: Frame with at least ``natoms`` and ``steps_per_second`` columns.
        max_points: Maximum number of rows to keep.

    Returns:
        A copy of the cleaned rows. If no rows survive the cleaning, the empty
        frame is returned as-is, without a ``_marker_size`` column.
    """
    default_size = 8

    dfp = df_in.dropna(subset=["natoms", "steps_per_second"]).copy()
    if dfp.empty:
        return dfp

    # Downsample if too many points; the fixed seed keeps reruns stable.
    if len(dfp) > max_points:
        dfp = dfp.sample(max_points, random_state=1)

    if "total_steps" not in dfp.columns:
        dfp["_marker_size"] = default_size
        return dfp

    median_steps = dfp["total_steps"].median()
    if pd.isna(median_steps):
        # No usable total_steps value at all: scaling would yield NaN sizes,
        # so fall back to the constant size instead.
        dfp["_marker_size"] = default_size
        return dfp

    ts_local = dfp["total_steps"].fillna(median_steps).astype(float)
    ts_range = ts_local.max() - ts_local.min()
    # A zero range (all values equal) would divide by zero; use 1.0 instead.
    scaled = (ts_local - ts_local.min()) / (ts_range if ts_range != 0 else 1.0)
    dfp["_marker_size"] = (scaled * 40) + 5
    return dfp
def compute_power_law_fits(df_in: pd.DataFrame) -> dict:
    """Fit power-law scaling ``steps/s ~ a * N^(-n)`` for every method.

    The fit is a straight line through ``log(steps_per_second)`` versus
    ``log(natoms)``, using only rows where both values are present and
    strictly positive.

    Args:
        df_in: Frame with ``method``, ``natoms`` and ``steps_per_second``
            columns. An empty frame, or one missing any of these columns,
            yields no fits.

    Returns:
        Mapping ``method -> (a, n)``. Methods are left out when they have
        fewer than three usable rows, fewer than two distinct ``natoms``
        values (the slope is undetermined), or a failed / non-finite fit.
    """
    fits = {}
    required_columns = {"method", "natoms", "steps_per_second"}
    if df_in.empty or not required_columns.issubset(df_in.columns):
        return fits

    for name, grp in df_in.groupby("method"):
        grp_clean = grp.dropna(subset=["natoms", "steps_per_second"])
        # Logarithms are only defined for strictly positive values.
        grp_clean = grp_clean[
            (grp_clean["natoms"] > 0) & (grp_clean["steps_per_second"] > 0)
        ]
        if len(grp_clean) < 3 or grp_clean["natoms"].nunique() < 2:
            continue

        try:
            logsx = np.log(grp_clean["natoms"].astype(float))
            logsy = np.log(grp_clean["steps_per_second"].astype(float))
            slope, intercept = np.polyfit(logsx, logsy, 1)
        except (TypeError, ValueError, np.linalg.LinAlgError):
            # Best effort: a method whose data cannot be fitted gets no line.
            continue

        if not (np.isfinite(slope) and np.isfinite(intercept)):
            continue
        fits[name] = (float(np.exp(intercept)), float(-slope))  # (a, n)
    return fits
def build_speed_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool
) -> go.Figure:
    """Build a log-log plot of inference speed vs number of atoms.

    Args:
        df_in: Frame with ``method``, ``natoms`` and ``steps_per_second``
            columns. An empty frame produces a figure without traces.
        color_map: Mapping ``method -> color``. Fit lines of methods missing
            from the mapping are drawn in black.
        show_scatter: Whether to draw the individual data points in addition
            to the per-method power-law fit lines.

    Returns:
        A figure holding the optional scatter traces followed by one fit line
        per method that could be fitted.
    """
    fig = go.Figure()
    has_data = not df_in.empty

    # Optionally add scatter points. The prepared frame has no `_marker_size`
    # column when no row is usable, so plotting is skipped in that case.
    scatter_drawn = False
    if show_scatter and has_data:
        dfp = prepare_scatter_df(df_in)
        if not dfp.empty:
            scatter_fig = px.scatter(
                dfp,
                x="natoms",
                y="steps_per_second",
                color="method",
                size="_marker_size",
                hover_data=[c for c in ["material_id", "formula"] if c in dfp.columns],
                color_discrete_map=color_map,
                log_x=True,
                log_y=True,
                render_mode="webgl",
                labels={
                    "steps_per_second": "Steps per second",
                    "natoms": "Number of atoms",
                },
            )
            for trace in scatter_fig.data:
                fig.add_trace(trace)
            scatter_drawn = True

    # Overlay fits
    fits = compute_power_law_fits(df_in) if has_data else {}
    for method, (a, n) in fits.items():
        natoms = df_in.loc[df_in["method"] == method, "natoms"].dropna()
        # The fit only uses positive atom counts, and log10 of a value <= 0
        # would turn the whole line into NaN, so apply the same restriction.
        natoms = natoms[natoms > 0]
        if natoms.empty:
            continue
        xs = np.logspace(np.log10(natoms.min()), np.log10(natoms.max()), 200)
        ys = a * xs ** (-n)
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode="lines",
                line=dict(color=color_map.get(method, "black"), width=2),
                # The scatter traces already carry the legend entries; the
                # lines only provide them when no scatter was drawn.
                showlegend=not scatter_drawn,
                name=f"{method}",
            )
        )

    fig.update_layout(
        height=520,
        title="Inference speed (steps/s)",
        xaxis=dict(type="log", title="Number of atoms"),
        yaxis=dict(type="log", title="Steps per second"),
    )
    return fig
def _build_stability_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool, total_runs: int
) -> go.Figure:
    """Build the two-panel figure shared by the heating and compression views.

    The left panel shows, per method, the percentage of runs still valid as a
    function of normalized time; the right panel shows the speed scaling plot
    from ``build_speed_figure``. An empty ``df_in`` yields the bare axes.
    """
    fig = make_subplots(
        rows=1,
        cols=2,
        column_widths=[0.4, 0.6],
        subplot_titles=("Valid runs", "Inference speed: steps/s vs N"),
    )

    if not df_in.empty:
        # Right panel: speed scaling
        speed_fig = build_speed_figure(df_in, color_map, show_scatter)
        for trace in speed_fig.data:
            fig.add_trace(trace, row=1, col=2)

        # Left panel: cumulative valid runs
        bins = np.linspace(0, 1, 50)
        for method, df_model in df_in.groupby("method"):
            # Count every formula once per method.
            df_model_grp = df_model.drop_duplicates(["formula"])
            hist, bin_edges = np.histogram(
                df_model_grp["normalized_final_step"], bins=bins
            )
            cumulative_population = np.cumsum(hist)
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            # Runs whose final step lies beyond each bin, as a percentage of
            # `total_runs`.
            still_valid = cumulative_population[-1] - cumulative_population[:-1]
            fig.add_trace(
                go.Scatter(
                    x=bin_centers[:-1],
                    y=still_valid / total_runs * 100,
                    mode="lines",
                    line=dict(color=color_map.get(method)),
                    name=str(method),
                    showlegend=False,
                ),
                row=1,
                col=1,
            )

    fig.update_xaxes(title_text="Normalized time", row=1, col=1, range=[0, 1])
    fig.update_yaxes(title_text="Valid runs (%)", row=1, col=1)
    fig.update_xaxes(type="log", row=1, col=2, title_text="Number of atoms")
    fig.update_yaxes(type="log", row=1, col=2, title_text="Steps per second")
    fig.update_layout(height=520, width=1000)
    return fig


def build_nvt_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool, total_runs: int = 120
) -> go.Figure:
    """Build subplot: NVT valid runs (cumulative) + speed scaling plot.

    Args:
        df_in: Heating results with ``method``, ``formula``,
            ``normalized_final_step``, ``natoms`` and ``steps_per_second``.
        color_map: Mapping ``method -> color``.
        show_scatter: Whether the speed panel shows individual data points.
        total_runs: Denominator for the valid-run percentage.
            NOTE(review): 120 is presumably the number of heating runs per
            model -- confirm against the benchmark definition.

    Returns:
        The two-panel Plotly figure.
    """
    return _build_stability_figure(df_in, color_map, show_scatter, total_runs)


def build_npt_figure(
    df_in: pd.DataFrame, color_map: dict, show_scatter: bool, total_runs: int = 80
) -> go.Figure:
    """Build subplot: NPT valid runs (cumulative) + speed scaling plot.

    Args:
        df_in: Compression results with ``method``, ``formula``,
            ``normalized_final_step``, ``natoms`` and ``steps_per_second``.
        color_map: Mapping ``method -> color``.
        show_scatter: Whether the speed panel shows individual data points.
        total_runs: Denominator for the valid-run percentage.
            NOTE(review): 80 is presumably the number of compression runs per
            model -- confirm against the benchmark definition.

    Returns:
        The two-panel Plotly figure.
    """
    return _build_stability_figure(df_in, color_map, show_scatter, total_runs)
# Render the two benchmark sections. Each section is guarded on its own data:
# a selection may have heating results but no compression results (or vice
# versa), and the figure builders cannot work on an empty frame.
if df_nvt.empty and df_npt.empty:
    st.info("No data available to display for selected models.")
else:
    st.markdown("""
## Heating
Isochoric-isothermal (NVT) MD simulations on RM24 structures, with temperature ramp from 300K to 3000K over 10 ps.
""")
    if df_nvt.empty:
        st.info("No heating data available for the selected models.")
    else:
        show_scatter_nvt = st.toggle(
            "Show scatter points", key="show_scatter_nvt", value=True
        )
        st.plotly_chart(
            build_nvt_figure(df_nvt, method_color_mapping, show_scatter_nvt),
            use_container_width=True,
        )

    st.markdown("""
## Compression
Isothermal-isobaric (NPT) MD simulations on RM24 structures, with pressure ramp from 0 GPa to 500 GPa and temperature ramp from 300K to 3000K over 10 ps.
""")
    if df_npt.empty:
        st.info("No compression data available for the selected models.")
    else:
        show_scatter_npt = st.toggle(
            "Show scatter points", key="show_scatter_npt", value=True
        )
        st.plotly_chart(
            build_npt_figure(df_npt, method_color_mapping, show_scatter_npt),
            use_container_width=True,
        )