# PP-schedule-visualizer / dash_visualizer.py
# Commit e178784 (Victarry): "Add dash backend visualizer." -- 10.4 kB
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
import numpy as np
from typing import List, Dict, Literal
from tqdm import tqdm
import time
def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_progress=True):
    """
    Create a Plotly figure for pipeline parallelism scheduling.

    Every entry of ``schedule`` becomes one horizontal row (the first entry,
    in iteration order, is drawn at the top and labelled "Device 1"). Every
    task becomes a rectangle spanning ``start_time`` to
    ``start_time + duration`` with its batch number centred inside.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks.
            Each task is a dictionary with keys:
            - 'type': 'forward', 'backward', or 'optimizer' (any other value
              is coloured like 'optimizer')
            - 'batch': batch number
            - 'start_time': start time of the task
            - 'duration': duration of the task
        max_time: Optional maximum time to display. When None, the latest
            task end time found in the schedule is used (0 if there are no
            tasks).
        show_progress: Whether to show a tqdm progress bar while building.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    # Colors for task types
    forward_color = "royalblue"
    backward_color = "sandybrown"
    optimizer_color = "#FFEFCF"

    # task type -> (rectangle fill, batch-label text colour)
    task_styles = {
        "forward": (forward_color, "white"),
        "backward": (backward_color, "black"),
    }
    default_style = (optimizer_color, "black")  # optimizer or any other type

    legend_items = [
        dict(name="Forward", color=forward_color),
        dict(name="Backward", color=backward_color),
        dict(name="Optimizer step", color=optimizer_color),
    ]

    # Number of stages (devices) == number of rows in the chart
    num_stages = len(schedule)

    # Find the maximum time in the schedule if not provided. The leading 0
    # keeps the result from ever dropping below zero.
    if max_time is None:
        end_times = [
            task["start_time"] + task["duration"]
            for tasks in schedule.values()
            for task in tasks
        ]
        max_time = max([0, *end_times])

    fig = go.Figure()

    # One progress step per row, per task, per legend entry, plus one for the
    # final layout update.
    total_tasks = sum(len(tasks) for tasks in schedule.values())
    progress_bar = None
    if show_progress:
        progress_bar = tqdm(
            total=total_tasks + num_stages + len(legend_items) + 1,
            desc="Creating visualization",
        )

    def advance():
        if progress_bar is not None:
            progress_bar.update(1)

    try:
        # Thin horizontal guide line for every row, top row first.
        for row in reversed(range(num_stages)):
            fig.add_trace(go.Scatter(
                x=[0, max_time],
                y=[row, row],
                mode='lines',
                line=dict(color='lightgray', width=0.5),
                showlegend=False,
                hoverinfo='none'
            ))
            advance()

        # Collect every rectangle / label first and hand them to the layout
        # in a single assignment: calling fig.add_shape / fig.add_annotation
        # once per task re-assigns the whole layout list on each call, which
        # becomes very slow for large schedules.
        shapes = []
        annotations = []
        for device_idx, tasks in enumerate(schedule.values()):
            row = num_stages - device_idx - 1  # Reverse so the first device is on top
            for task in tasks:
                fill_color, text_color = task_styles.get(task["type"], default_style)
                start_time = task["start_time"]
                duration = task["duration"]

                shapes.append(dict(
                    type="rect",
                    x0=start_time,
                    y0=row - 0.4,
                    x1=start_time + duration,
                    y1=row + 0.4,
                    line=dict(color="black", width=0.5),
                    fillcolor=fill_color,
                    layer="above",
                ))
                # Batch number centred in the rectangle
                annotations.append(dict(
                    x=start_time + duration / 2,
                    y=row,
                    text=str(task["batch"]),
                    showarrow=False,
                    font=dict(color=text_color, size=10, family="Arial, bold"),
                ))
                advance()

        # Custom legend: empty marker traces that exist only to create the
        # legend entries (the rectangles are shapes, which have none).
        for item in legend_items:
            fig.add_trace(go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(size=10, color=item['color']),
                name=item['name'],
                showlegend=True
            ))
            advance()

        # Axis properties; labels are reversed to put Device 1 at the top.
        device_labels = [f"Device {i+1}" for i in reversed(range(num_stages))]

        fig.update_layout(
            shapes=shapes,
            annotations=annotations,
            xaxis=dict(
                showticklabels=False,
                showgrid=False,
                zeroline=False,
                title="Time →",
                range=[0, max_time + 0.5]
            ),
            yaxis=dict(
                tickmode="array",
                tickvals=list(range(num_stages)),
                ticktext=device_labels,
                showgrid=False,
                zeroline=False,
                range=[-0.5, num_stages - 0.5]
            ),
            margin=dict(l=50, r=50, t=50, b=50),
            plot_bgcolor="white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=-0.2,
                xanchor="center",
                x=0.5
            )
        )
        advance()  # Final update for layout
    finally:
        # Always release the progress bar, even if building the figure failed.
        if progress_bar is not None:
            progress_bar.close()

    return fig
def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
    """
    Create a Dash app for interactive visualization of pipeline scheduling.

    The page shows a title, the pipeline graph wrapped in a loading spinner,
    and a "Download PNG" button. Two callbacks are registered: one that
    builds the figure for the graph, and one that renders the figure to PNG
    and sends it to the browser as a download.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks
        schedule_type: Type of scheduling algorithm used (shown upper-cased
            in the page heading)

    Returns:
        The configured ``dash.Dash`` application (not yet running).
    """
    app = dash.Dash(__name__, title="Pipeline Parallelism Visualization")

    app.layout = html.Div([
        html.H1(f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
                style={'textAlign': 'center'}),
        html.Div(id="loading-container", children=[
            dcc.Loading(
                id="loading-graph",
                type="circle",
                children=[
                    html.Div(id="graph-container", children=[
                        dcc.Graph(
                            id='pipeline-graph',
                            style={'height': '600px'}
                        )
                    ])
                ]
            )
        ]),
        html.Div([
            html.Button("Download PNG", id="btn-download",
                        style={'margin': '10px'}),
            dcc.Download(id="download-image")
        ], style={'textAlign': 'center', 'marginTop': '20px'})
    ])

    @app.callback(
        Output("pipeline-graph", "figure"),
        Input("graph-container", "children"),
        prevent_initial_call=False,
    )
    def load_graph(_):
        # Create the figure when the app loads. The input value itself is
        # ignored; it presumably serves only as the initial trigger.
        return create_pipeline_figure(schedule, show_progress=True)

    @app.callback(
        Output("download-image", "data"),
        Input("btn-download", "n_clicks"),
        prevent_initial_call=True,
    )
    def download_image(n_clicks):
        # Show progress in terminal for downloads
        fig = create_pipeline_figure(schedule, show_progress=True)
        img_bytes = fig.to_image(format="png", scale=3)
        # Callback results are serialised to JSON, so raw PNG bytes cannot be
        # returned directly; dcc.send_bytes base64-encodes them and marks the
        # payload as base64 for dcc.Download.
        return dcc.send_bytes(img_bytes, "pipeline_visualization.png")

    return app
def visualize_pipeline_parallelism_dash(
    schedule: Dict[int, List[Dict]],
    schedule_type: Literal["simple", "1f1b"] = "1f1b",
    port: int = 8050,
    debug: bool = False
):
    """
    Create an interactive Dash visualization for pipeline parallelism scheduling.

    Builds the app with ``create_dash_app``, prints the local URL and then
    starts the Dash server with the given port and debug flag.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks
        schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
        port: Port number to run the Dash app
        debug: Whether to run the app in debug mode
    """
    app = create_dash_app(schedule, schedule_type)
    print(f"Starting Dash app on http://localhost:{port}/")
    # Newer Dash releases start the server with ``app.run`` and no longer
    # accept the legacy ``run_server`` name; older releases only have
    # ``run_server``. Prefer the new entry point and fall back to the old one.
    start_server = getattr(app, "run", None) or app.run_server
    start_server(debug=debug, port=port)
def save_pipeline_visualization_plotly(
    schedule: Dict[int, List[Dict]],
    schedule_type: Literal["simple", "1f1b"] = "1f1b",
    output_file: str = "pipeline_visualization_plotly.png",
):
    """
    Render the pipeline schedule with Plotly and write it to an image file.

    The figure comes from ``create_pipeline_figure`` (with its progress bar
    enabled), gets a centred title naming the schedule type, and is written
    with ``scale=3``. Status messages are printed before and after each step.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks
        schedule_type: Type of scheduling algorithm used (upper-cased in the
            figure title)
        output_file: Path to save the visualization
    """
    device_count = len(schedule)
    print(f"Creating visualization for {device_count} devices...")
    figure = create_pipeline_figure(schedule, show_progress=True)

    # The static image has no surrounding page heading, so put the title
    # on the figure itself, horizontally centred.
    heading = f"Pipeline Parallelism Visualization ({schedule_type.upper()})"
    figure.update_layout(title=heading, title_x=0.5)

    print(f"Saving image to {output_file}...")
    figure.write_image(output_file, scale=3)
    print(f"Visualization saved to {output_file}")
if __name__ == "__main__":
    # Example usage: build a 1F1B schedule from the command-line options and
    # either serve it interactively or save it as a static image.
    import argparse
    from pipeline import create_1f1b_schedule

    cli = argparse.ArgumentParser(description="Pipeline Parallelism Visualizer")
    cli.add_argument("--num-stages", type=int, default=4, help="Number of pipeline stages")
    cli.add_argument("--num-batches", type=int, default=8, help="Number of microbatches")
    cli.add_argument("--interactive", action="store_true", help="Run interactive Dash app")
    cli.add_argument("--port", type=int, default=8050, help="Port for Dash app")
    cli.add_argument(
        "--output",
        type=str,
        default="pipeline_visualization_plotly.png",
        help="Output file for static image",
    )
    args = cli.parse_args()

    # Example schedule: every stage takes 1.0 for a forward pass and 2.0 for
    # a backward pass.
    stage_count = args.num_stages
    schedule = create_1f1b_schedule(
        num_stages=stage_count,
        num_batches=args.num_batches,
        forward_times=[1.0] * stage_count,
        backward_times=[2.0] * stage_count,
    )

    if not args.interactive:
        save_pipeline_visualization_plotly(schedule, output_file=args.output)
    else:
        visualize_pipeline_parallelism_dash(schedule, port=args.port)