Spaces:
Running
Running
import dash | |
from dash import dcc, html | |
from dash.dependencies import Input, Output, State | |
import plotly.graph_objects as go | |
import numpy as np | |
from typing import List, Dict, Literal | |
from tqdm import tqdm | |
import time | |
def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_progress=True):
    """
    Create a Plotly figure for pipeline parallelism scheduling.

    Every device in ``schedule`` gets one horizontal row (the first device in
    iteration order is drawn at the top). Each task is drawn as a rectangle
    spanning ``start_time`` to ``start_time + duration`` with its batch
    number printed in the middle.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks.
            Each task is a dictionary with keys:
            - 'type': 'forward', 'backward', or 'optimizer'
              (any other value is drawn like an optimizer step)
            - 'batch': batch number
            - 'start_time': start time of the task
            - 'duration': duration of the task
        max_time: Optional maximum time to display. When None, the latest
            task end time is used (0 if there are no tasks).
        show_progress: Whether to show a progress bar

    Returns:
        The populated ``plotly.graph_objects.Figure``.
    """
    # Colors for task types
    forward_color = "royalblue"
    backward_color = "sandybrown"
    optimizer_color = "#FFEFCF"

    # (fill color, label text color) per task type. Unknown types fall back
    # to the optimizer style, matching the original if/elif/else chain.
    task_styles = {
        "forward": (forward_color, "white"),
        "backward": (backward_color, "black"),
    }
    default_style = (optimizer_color, "black")

    legend_items = [
        dict(name="Forward", color=forward_color),
        dict(name="Backward", color=backward_color),
        dict(name="Optimizer step", color=optimizer_color),
    ]

    # Number of stages (devices) == number of rows in the chart
    num_stages = len(schedule)

    # Find the maximum time in the schedule if not provided (never below 0)
    if max_time is None:
        end_times = (
            task["start_time"] + task["duration"]
            for tasks in schedule.values()
            for task in tasks
        )
        max_time = max(0, max(end_times, default=0))

    fig = go.Figure()

    # One tick per background row, per task, per legend entry, plus the layout.
    total_tasks = sum(len(tasks) for tasks in schedule.values())
    total_steps = total_tasks + num_stages + len(legend_items) + 1

    # `disable=` turns the bar into a no-op when progress is not wanted, and
    # the context manager guarantees it is closed even if drawing fails.
    with tqdm(
        total=total_steps,
        desc="Creating visualization",
        disable=not show_progress,
    ) as progress_bar:
        # Thin guide line per device row
        for device_idx in range(num_stages):
            row = num_stages - device_idx - 1  # Reverse for plotting
            fig.add_trace(go.Scatter(
                x=[0, max_time],
                y=[row, row],
                mode='lines',
                line=dict(color='lightgray', width=0.5),
                showlegend=False,
                hoverinfo='none',
            ))
            progress_bar.update(1)

        # Task rectangles and batch labels are collected here and applied in
        # a single update_layout call below; appending them one at a time via
        # add_shape/add_annotation gets slow for large schedules.
        shapes = []
        annotations = []
        for device_idx, device in enumerate(schedule):
            row = num_stages - device_idx - 1
            for task in schedule[device]:
                color, text_color = task_styles.get(task["type"], default_style)
                start_time = task["start_time"]
                duration = task["duration"]

                shapes.append(dict(
                    type="rect",
                    x0=start_time,
                    y0=row - 0.4,
                    x1=start_time + duration,
                    y1=row + 0.4,
                    line=dict(color="black", width=0.5),
                    fillcolor=color,
                    layer="above",
                ))
                # Batch number centered inside the rectangle
                annotations.append(dict(
                    x=start_time + duration / 2,
                    y=row,
                    text=str(task["batch"]),
                    showarrow=False,
                    font=dict(color=text_color, size=10, family="Arial, bold"),
                ))
                progress_bar.update(1)

        # Custom legend: empty marker traces that only contribute legend entries
        for item in legend_items:
            fig.add_trace(go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(size=10, color=item['color']),
                name=item['name'],
                showlegend=True,
            ))
            progress_bar.update(1)

        # Row labels, reversed so that "Device 1" ends up at the top
        device_labels = [f"Device {i+1}" for i in reversed(range(num_stages))]

        fig.update_layout(
            shapes=shapes,
            annotations=annotations,
            xaxis=dict(
                showticklabels=False,
                showgrid=False,
                zeroline=False,
                title="Time →",
                range=[0, max_time + 0.5],
            ),
            yaxis=dict(
                tickmode="array",
                tickvals=list(range(num_stages)),
                ticktext=device_labels,
                showgrid=False,
                zeroline=False,
                range=[-0.5, num_stages - 0.5],
            ),
            margin=dict(l=50, r=50, t=50, b=50),
            plot_bgcolor="white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=-0.2,
                xanchor="center",
                x=0.5,
            ),
        )
        progress_bar.update(1)  # Final update for layout

    return fig
def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
    """
    Create a Dash app for interactive visualization of pipeline scheduling.

    The layout contains a graph (wrapped in a loading spinner) and a
    "Download PNG" button. Two callbacks are registered on the app: one that
    fills the graph when the page loads and one that serves the PNG export.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks
        schedule_type: Type of scheduling algorithm used (shown in the heading)

    Returns:
        The configured ``dash.Dash`` application (not yet running).
    """
    app = dash.Dash(__name__, title="Pipeline Parallelism Visualization")

    app.layout = html.Div([
        html.H1(f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
                style={'textAlign': 'center'}),
        html.Div(id="loading-container", children=[
            dcc.Loading(
                id="loading-graph",
                type="circle",
                children=[
                    html.Div(id="graph-container", children=[
                        dcc.Graph(
                            id='pipeline-graph',
                            style={'height': '600px'}
                        )
                    ])
                ]
            )
        ]),
        html.Div([
            html.Button("Download PNG", id="btn-download",
                        style={'margin': '10px'}),
            dcc.Download(id="download-image")
        ], style={'textAlign': 'center', 'marginTop': '20px'})
    ])

    # Without registration these functions are dead code and the graph stays
    # empty. The container's static `id` is used purely as a trigger so the
    # callback fires once on the initial page load.
    @app.callback(
        Output("pipeline-graph", "figure"),
        Input("graph-container", "id"),
    )
    def load_graph(_):
        # Create the figure when the app loads
        return create_pipeline_figure(schedule, show_progress=True)

    # prevent_initial_call: only export when the button is actually clicked.
    @app.callback(
        Output("download-image", "data"),
        Input("btn-download", "n_clicks"),
        prevent_initial_call=True,
    )
    def download_image(n_clicks):
        # Show progress in terminal for downloads
        fig = create_pipeline_figure(schedule, show_progress=True)
        img_bytes = fig.to_image(format="png", scale=3)
        # Raw bytes are not JSON-serializable; send_bytes base64-encodes the
        # payload into the dict shape dcc.Download expects.
        return dcc.send_bytes(img_bytes, "pipeline_visualization.png")

    return app
def visualize_pipeline_parallelism_dash(
    schedule: Dict[int, List[Dict]],
    schedule_type: Literal["simple", "1f1b"] = "1f1b",
    port: int = 8050,
    debug: bool = False
):
    """
    Create an interactive Dash visualization for pipeline parallelism scheduling.

    Builds the app with ``create_dash_app`` and starts its development
    server; the call blocks until the server is stopped.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks
        schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
        port: Port number to run the Dash app
        debug: Whether to run the app in debug mode
    """
    app = create_dash_app(schedule, schedule_type)
    print(f"Starting Dash app on http://localhost:{port}/")

    # Newer Dash releases expose `app.run` and no longer support the
    # deprecated `run_server`; older releases only have `run_server`.
    # Prefer `run` and fall back so both keep working.
    run = getattr(app, "run", None)
    if run is None:
        run = app.run_server
    run(debug=debug, port=port)
def save_pipeline_visualization_plotly(
    schedule: Dict[int, List[Dict]],
    schedule_type: Literal["simple", "1f1b"] = "1f1b",
    output_file: str = "pipeline_visualization_plotly.png",
):
    """
    Render the schedule with Plotly and write it to disk as a static image.

    Progress messages are printed to stdout before rendering, before
    saving, and after the file has been written.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks
        schedule_type: Type of scheduling algorithm used; upper-cased and
            shown in the figure title
        output_file: Path to save the visualization
    """
    print(f"Creating visualization for {len(schedule)} devices...")
    figure = create_pipeline_figure(schedule, show_progress=True)

    # The static export has no surrounding page, so put a centered title
    # on the figure itself.
    heading = f"Pipeline Parallelism Visualization ({schedule_type.upper()})"
    figure.update_layout(title=heading, title_x=0.5)

    print(f"Saving image to {output_file}...")
    figure.write_image(output_file, scale=3)  # scale=3 -> higher-resolution output
    print(f"Visualization saved to {output_file}")
if __name__ == "__main__":
    # Command-line demo: build an example 1F1B schedule and either serve it
    # in the interactive Dash app or save it as a static image.
    import argparse

    from pipeline import create_1f1b_schedule

    cli = argparse.ArgumentParser(description="Pipeline Parallelism Visualizer")
    cli.add_argument("--num-stages", type=int, default=4, help="Number of pipeline stages")
    cli.add_argument("--num-batches", type=int, default=8, help="Number of microbatches")
    cli.add_argument("--interactive", action="store_true", help="Run interactive Dash app")
    cli.add_argument("--port", type=int, default=8050, help="Port for Dash app")
    cli.add_argument("--output", type=str, default="pipeline_visualization_plotly.png", help="Output file for static image")
    options = cli.parse_args()

    # Uniform per-stage timings: forward takes 1.0, backward takes 2.0.
    stage_count = options.num_stages
    demo_schedule = create_1f1b_schedule(
        num_stages=stage_count,
        num_batches=options.num_batches,
        forward_times=[1.0 for _ in range(stage_count)],
        backward_times=[2.0 for _ in range(stage_count)],
    )

    if not options.interactive:
        save_pipeline_visualization_plotly(demo_schedule, output_file=options.output)
    else:
        visualize_pipeline_parallelism_dash(demo_schedule, port=options.port)