import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import argparse
from typing import List, Dict, Literal, Optional
from tqdm import tqdm
import base64
from src.execution_model import Schedule
def convert_schedule_to_visualization_format(schedule: Schedule):
"""
Converts a Schedule object to the format needed for visualization.
Returns:
Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
"""
# Make sure all operations have start and end times
for op in schedule.ops.values():
if op.start_time is None or op.end_time is None:
raise ValueError("Operations must have start and end times. Run ScheduleExecutor.execute() first.")
visualization_data = {}
# Organize operations by device
for device_id, device_queue in enumerate(schedule.dev_queues):
visualization_data[device_id] = []
for op in device_queue.ops:
visualization_data[device_id].append({
"type": op.op_type,
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
"stage": op.stage_id,
"start_time": op.start_time,
"duration": op.end_time - op.start_time
})
return visualization_data
def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
"""
Create a Plotly figure for pipeline parallelism scheduling.
Args:
schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
max_time: Optional maximum time to display
show_progress: Whether to show a progress bar
"""
# Find the number of devices
num_devices = len(schedule_data)
empty_color = "whitesmoke"
# Colors for task types
def get_color(op_type: str, stage_id: int):
# Base colors
forward_base_color = "royalblue"
backward_base_color = "lightgreen" # Changed from sandybrown to match your visualization
virtual_stage = stage_id // num_devices
if op_type == "forward":
if virtual_stage == 0:
return forward_base_color
else:
# Lighter shade for virtual_stage > 0
return "lightskyblue"
elif op_type == "backward":
if virtual_stage == 0:
return backward_base_color
else:
# Lighter shade for virtual_stage > 0
return "lightseagreen"
else:
raise ValueError(f"Invalid operation type: {op_type}")
# Find the maximum time in the schedule if not provided
if max_time is None:
max_time = 0
for device in schedule_data:
for task in schedule_data[device]:
end_time = task["start_time"] + task["duration"]
if end_time > max_time:
max_time = end_time
# Create a figure
fig = go.Figure()
# Initialize progress tracking
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
tasks_processed = 0
if show_progress:
progress_bar = tqdm(total=total_tasks + num_devices + 3, desc="Creating visualization")
# Create a custom y-axis with no gaps between devices
y_spacing = 1.0 # Use 1.0 for no gaps
# Add rectangles for each task
for device_idx, device in enumerate(schedule_data):
device_idx_reversed = num_devices - device_idx - 1
# Sort tasks by start time to ensure correct rendering
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
for task in sorted_tasks:
# Determine task color and text color
if task["type"] == "forward":
color = get_color(task["type"], task["stage"])
text_color = "white"
name = "Forward"
elif task["type"] == "backward":
color = get_color(task["type"], task["stage"])
text_color = "black"
name = "Backward"
else:
color = empty_color
text_color = "black"
name = "Unknown"
# Add rectangle for the task
start_time = task["start_time"]
duration = task["duration"]
# Calculate y positions with no gaps
y_pos = device_idx_reversed * y_spacing
# Create rectangle using shape
fig.add_shape(
type="rect",
x0=start_time,
y0=y_pos - 0.5,
x1=start_time + duration,
y1=y_pos + 0.5,
line=dict(color="black", width=0.5),
fillcolor=color,
layer="above",
)
# Add batch number text
fig.add_annotation(
x=start_time + duration / 2,
y=y_pos,
text=f"{task['batch']}", # Only show batch ID
showarrow=False,
font=dict(color=text_color, size=12, family="Arial, bold"), # Increased font size
)
# Add hover data with additional details
fig.add_trace(go.Scatter(
x=[start_time + duration / 2],
y=[y_pos],
mode='markers',
marker=dict(opacity=0), # Invisible marker
hoverinfo='text',
text=f"Batch: {task['batch']}
Stage: {task['stage']}
Type: {name}
Start: {task['start_time']:.2f}
End: {task['start_time'] + task['duration']:.2f}
Duration: {task['duration']:.2f}",
showlegend=False
))
# Update progress
if show_progress:
tasks_processed += 1
progress_bar.update(1)
# Add custom legend
legend_items = [
dict(name="Forward", color=get_color("forward", 0)),
dict(name="Backward", color=get_color("backward", 0)),
]
for i, item in enumerate(legend_items):
fig.add_trace(go.Scatter(
x=[None],
y=[None],
mode='markers',
marker=dict(size=10, color=item['color']),
name=item['name'],
showlegend=True
))
if show_progress and i < len(legend_items) - 1:
progress_bar.update(1)
# Set axis properties
device_labels = [f"Device {i}" for i in range(num_devices)]
device_labels.reverse() # Reverse to put Device 0 at the top
# Calculate tick positions with no gaps
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
# Adjust the range to ensure there are no empty spaces at the end
x_end = max_time * 1.05 # Add a small margin
fig.update_layout(
yaxis=dict(
tickmode="array",
tickvals=tick_positions,
ticktext=device_labels,
showgrid=False,
zeroline=False,
),
margin=dict(l=50, r=20, t=40, b=40),
plot_bgcolor="white",
title=dict(
text="Pipeline Parallelism Schedule",
x=0.5,
y=0.98, # Move title position closer to the top
font=dict(size=20)
),
legend=dict(
orientation="h",
yanchor="top",
y=-0.1, # Position below the plot
xanchor="center",
x=0.5
),
width=1600,
height=400, # Reduce height to make the visualization more compact
bargap=0,
bargroupgap=0,
)
if show_progress:
progress_bar.update(1)
progress_bar.close()
return fig
def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
"""
Create a Dash app to visualize the pipeline schedule.
Args:
schedule: Schedule object to visualize
schedule_type: Type of schedule ("1f1b" or other)
"""
# Convert schedule to visualization format
schedule_data = convert_schedule_to_visualization_format(schedule)
# Create the app
app = dash.Dash(__name__, title=f"Pipeline Parallelism Visualizer - {schedule_type}")
app.layout = html.Div([
html.H1(f"Pipeline Parallelism Visualizer - {schedule_type}", style={'textAlign': 'center'}),
html.Div([
html.Div([
html.H3("Schedule Configuration:"),
html.Ul([
html.Li(f"Number of devices: {schedule.config.num_devices}"),
html.Li(f"Number of stages: {schedule.config.num_stages}"),
html.Li(f"Number of batches: {schedule.config.num_batches}"),
]),
], className="config-section"),
html.Button("Download Image", id="btn-download",
style={
'marginTop': '20px',
'padding': '10px',
'backgroundColor': '#007BFF',
'color': 'white',
'border': 'none',
'borderRadius': '5px',
'cursor': 'pointer'
}),
dcc.Download(id="download-image"),
], style={'margin': '20px'}),
html.Div(id="graph-container", children=[]),
dcc.Graph(
id="pipeline-graph",
config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
),
])
@app.callback(
Output("pipeline-graph", "figure"),
Input("graph-container", "children"),
prevent_initial_call=False,
)
def load_graph(_):
# Create the figure when the app loads
return create_pipeline_figure(schedule_data, show_progress=True)
@app.callback(
Output("download-image", "data"),
Input("btn-download", "n_clicks"),
prevent_initial_call=True,
)
def download_image(n_clicks):
# Generate the figure for download
fig = create_pipeline_figure(schedule_data, show_progress=True)
# Convert to base64 image
img_bytes = fig.to_image(format="png", width=1600, height=1000, scale=2)
img_base64 = base64.b64encode(img_bytes).decode('ascii')
# Return the download data
return dict(
content=img_base64,
filename=f"pipeline_visualization_{schedule_type}.png",
type="image/png",
base64=True
)
return app
def visualize_pipeline_parallelism_dash(
schedule: Schedule,
port: int = 8050,
debug: bool = False
):
"""
Launch a Dash app to visualize the pipeline schedule interactively.
Args:
schedule: Schedule object to visualize
port: Port to run the Dash app on
debug: Whether to run the Dash app in debug mode
"""
app = create_dash_app(schedule)
print(f"Starting Dash app on http://localhost:{port}/")
app.run_server(debug=debug, port=port)
def save_pipeline_visualization_plotly(
schedule: Schedule,
output_file: str = "pipeline_visualization_plotly.png",
):
"""
Save a static image of the pipeline schedule visualization.
Args:
schedule: Schedule object to visualize
output_file: Path to save the image to
"""
schedule_data = convert_schedule_to_visualization_format(schedule)
fig = create_pipeline_figure(schedule_data, show_progress=True)
print(f"Saving visualization to {output_file}...")
fig.write_image(output_file, width=1600, height=400, scale=2)
print(f"Visualization saved to {output_file}")