# Source: commit 7c49b8c by Victarry - "Update UI."
import dash
import dash_bootstrap_components as dbc
from dash import dcc, html, Input, Output, State, callback_context, ALL, ClientsideFunction
import plotly.graph_objects as go
from src.execution_model import ScheduleConfig, Schedule
from src.strategies import (
generate_1f1b_schedule,
generate_zero_bubble_1p_schedule,
generate_1f1b_overlap_schedule,
generate_1f1b_interleave_schedule,
generate_1f1b_interleave_overlap_schedule,
generate_dualpipe_schedule
)
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
# Registry of scheduling strategies: maps the strategy key used throughout the
# UI and callbacks to the function that builds the schedule for it.
STRATEGIES = dict(
    (
        ("1f1b", generate_1f1b_schedule),
        ("zb1p", generate_zero_bubble_1p_schedule),
        ("1f1b_overlap", generate_1f1b_overlap_schedule),
        ("1f1b_interleave", generate_1f1b_interleave_schedule),
        ("1f1b_interleave_overlap", generate_1f1b_interleave_overlap_schedule),
        ("dualpipe", generate_dualpipe_schedule),
    )
)
# Display metadata for each strategy key. One row per strategy:
# (key, short display name, one-line description, Bootstrap Icons class).
_STRATEGY_INFO_ROWS = (
    (
        "1f1b",
        "1F1B",
        "One Forward One Backward - Standard pipeline parallelism",
        "bi-arrow-left-right",
    ),
    (
        "zb1p",
        "ZB-1P",
        "Zero Bubble 1-stage Pipeline - Minimizes pipeline bubbles",
        "bi-circle",
    ),
    (
        "1f1b_overlap",
        "1F1B Overlap",
        "1F1B with overlapped forward and backward passes",
        "bi-layers",
    ),
    (
        "1f1b_interleave",
        "1F1B Interleave",
        "Interleaved pipeline stages for better efficiency",
        "bi-shuffle",
    ),
    (
        "1f1b_interleave_overlap",
        "1F1B Interleave + Overlap",
        "Combines interleaving with overlapped execution",
        "bi-layers-fill",
    ),
    (
        "dualpipe",
        "DualPipe",
        "Dual pipeline execution for enhanced parallelism",
        "bi-diagram-2",
    ),
)

# Same mapping the UI reads everywhere: key -> {"name", "description", "icon"}.
STRATEGY_INFO = {
    key: {"name": name, "description": description, "icon": icon}
    for key, name, description, icon in _STRATEGY_INFO_ROWS
}
# Stylesheets loaded by the page: the Bootstrap theme, the Bootstrap icon font,
# and the Inter web font.
_EXTERNAL_STYLESHEETS = [
    dbc.themes.BOOTSTRAP,
    dbc.icons.BOOTSTRAP,
    "https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap",
]

# suppress_callback_exceptions is enabled because some callbacks reference
# components (the strategy tabs) that only exist after schedules are generated.
app = dash.Dash(
    __name__,
    external_stylesheets=_EXTERNAL_STYLESHEETS,
    suppress_callback_exceptions=True,
)
app.title = "Pipeline Parallelism Schedule Visualizer"
# Values the form is pre-populated with on first load. The two advanced
# options default to None, i.e. "not specified".
default_values = dict(
    num_devices=4,
    num_stages=8,
    num_batches=16,
    p2p_latency=0.0,
    op_time_forward=1.0,
    op_time_backward_d=1.0,
    op_time_backward_w=1.0,
    op_time_backward=2.0,
    strategy=["1f1b_interleave"],
    op_time_overlapped_fwd_bwd=None,
    microbatch_group_size_per_vp_stage=None,
)
# Inline style shared by every control card (uniform spacing below each card).
card_style = dict(marginBottom="20px")

# Page banner: icon + title on the first line, tagline underneath.
_header_title = html.H1(
    [
        html.I(className="bi bi-diagram-3 me-3"),
        "Pipeline Parallelism Schedule Visualizer",
    ],
    className="text-center mb-0",
)
_header_tagline = html.P(
    "Visualize and compare different pipeline parallelism scheduling strategies",
    className="text-center mt-2 mb-0 lead",
)
header = html.Div(
    [html.Div([_header_title, _header_tagline], className="container")],
    className="main-header",
)
# Basic parameters card: device, stage and microbatch counts.
def _count_input_col(label, input_id, icon_class, width):
    """Return a grid column holding a labelled, required positive-integer input.

    Args:
        label: Text shown above the input.
        input_id: Component id of the input; also the key into
            ``default_values`` that supplies the initial value.
        icon_class: Full ``className`` for the icon shown in the input-group
            prefix.
        width: Column width (in 12ths) on ``md`` screens and up.
    """
    return dbc.Col(
        [
            dbc.Label(label, html_for=input_id, className="form-label"),
            dbc.InputGroup(
                [
                    dbc.InputGroupText(html.I(className=icon_class)),
                    dbc.Input(
                        id=input_id,
                        type='number',
                        value=default_values[input_id],
                        min=1,
                        step=1,
                        required=True,
                    ),
                    # The feedback must be a sibling of the input: Bootstrap
                    # reveals `.invalid-feedback` via a sibling selector on the
                    # `.is-invalid` control, so placing it outside the input
                    # group (as before) meant it was never displayed.
                    dbc.FormFeedback("Please provide a positive integer.", type="invalid"),
                ],
                # Keeps the input's rounded corners correct when the feedback
                # element lives inside the group.
                className="has-validation",
            ),
        ],
        md=width,
    )


basic_params_card = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H5(
                    [
                        html.I(className="bi bi-sliders section-icon"),
                        "Basic Configuration",
                    ],
                    className="section-title",
                ),
                dbc.Row(
                    [
                        _count_input_col("Number of Devices (GPUs)", 'num_devices', "bi bi-gpu-card", 6),
                        _count_input_col("Number of Stages", 'num_stages', "bi bi-stack", 6),
                    ],
                    className="mb-3",
                ),
                dbc.Row(
                    [
                        _count_input_col("Number of Microbatches", 'num_batches', "bi bi-collection", 12),
                    ]
                ),
            ]
        )
    ],
    style=card_style,
)
# Strategy picker: one clickable card per registered strategy, a store holding
# the current selection, and a slot for the validation message.
def _strategy_card_col(strategy):
    """Return the grid column containing the clickable card for ``strategy``."""
    info = STRATEGY_INFO[strategy]
    # Cards for the default selection start out highlighted.
    selected_suffix = 'selected' if strategy in default_values['strategy'] else ''
    card = html.Div(
        [
            html.I(className=f"{info['icon']} mb-2"),
            html.H6(info['name'], className="mb-1"),
            html.Small(info['description'], className="text-muted"),
        ],
        # Pattern-matching id: the clientside selection callback listens on
        # every component of type "strategy-card".
        id={"type": "strategy-card", "index": strategy},
        className=f"strategy-card p-3 text-center {selected_suffix}",
    )
    return dbc.Col([html.Div([card], className="mb-3")], lg=4, md=6)


scheduling_params_card = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H5(
                    [
                        html.I(className="bi bi-diagram-2 section-icon"),
                        "Scheduling Strategy",
                    ],
                    className="section-title",
                ),
                html.P("Select one or more strategies to compare:", className="text-muted mb-3"),
                dbc.Row(
                    [_strategy_card_col(strategy) for strategy in STRATEGIES],
                    className="g-3",
                ),
                dcc.Store(id='selected-strategies-store', data=default_values["strategy"]),
                html.Div(id='strategy-selection-feedback', className='invalid-feedback d-block mt-2'),
            ]
        )
    ],
    style=card_style,
)
# Timing parameters card: per-operation durations plus a collapsible section
# of advanced, strategy-specific options.
def _tooltip_label(text, input_id, badge_id):
    """Return a form label for ``input_id`` followed by a "?" badge.

    The badge carries ``badge_id`` so a tooltip can target it.
    """
    return dbc.Label(
        [
            text,
            dbc.Badge(
                "?",
                pill=True,
                color="secondary",
                className="ms-1",
                id=badge_id,
                style={"cursor": "pointer", "fontSize": "0.75rem"},
            ),
        ],
        html_for=input_id,
        className="form-label",
    )


def _help_tooltip(text, badge_id):
    """Return the tooltip shown to the right of the badge with id ``badge_id``."""
    return dbc.Tooltip(text, target=badge_id, placement="right")


def _split_backward_col(label, input_id):
    """Return a half-width column with one optional split-backward time input."""
    return dbc.Col(
        [
            dbc.Label(label, html_for=input_id),
            dbc.Input(id=input_id, type='number', value=default_values[input_id], min=0.01, step=0.01),
            dbc.FormText("For strategies with split backward"),
            dbc.FormFeedback("Must be > 0", type="invalid"),
        ],
        md=6,
    )


_p2p_latency_field = html.Div(
    [
        _tooltip_label("P2P Latency (ms)", 'p2p_latency', "tooltip-p2p"),
        # Zero latency is allowed, hence min=0 rather than 0.01.
        dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"],
                  min=0, step=0.01, required=True),
        dbc.FormFeedback("Must be ≥ 0", type="invalid"),
        _help_tooltip("Time for point-to-point communication between adjacent devices.", "tooltip-p2p"),
    ]
)

_forward_time_field = html.Div(
    [
        _tooltip_label("Forward Pass Time (ms)", 'op_time_forward', "tooltip-fwd"),
        dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"],
                  min=0.01, step=0.01, required=True),
        dbc.FormFeedback("Must be > 0", type="invalid"),
        _help_tooltip("Time for a forward pass of one microbatch through one stage.", "tooltip-fwd"),
    ]
)

# Not marked `required`: whether it is needed depends on the selected strategies
# and is enforced by the validation callback instead.
_backward_time_field = html.Div(
    [
        _tooltip_label("Backward Pass Time (ms)", 'op_time_backward', "tooltip-bwd"),
        dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"],
                  min=0.01, step=0.01),
        dbc.FormText("Combined backward pass time (data + weight gradients)", className="mt-1"),
        dbc.FormFeedback("Must be > 0", type="invalid"),
        _help_tooltip("Time for combined backward pass (data + weight gradients).", "tooltip-bwd"),
    ],
    className="mb-3",
)

_advanced_timing_toggle = dbc.Button(
    [
        html.I(className="bi bi-gear-fill me-2"),
        "Advanced Timing Options",
    ],
    id="advanced-timing-toggle",
    color="light",
    className="mb-3",
    size="sm",
)

# Collapsed by default; opened by the toggle button above.
_advanced_timing_collapse = dbc.Collapse(
    [
        dbc.Alert(
            [
                html.I(className="bi bi-info-circle-fill me-2"),
                "These options are for advanced users and specific scheduling strategies.",
            ],
            color="info",
            className="mb-3",
        ),
        dbc.Row(
            [
                _split_backward_col("Backward D (Data Gradient)", 'op_time_backward_d'),
                _split_backward_col("Backward W (Weight Gradient)", 'op_time_backward_w'),
            ],
            className="mb-3",
        ),
        html.Div(
            [
                dbc.Label("Overlapped Forward+Backward Time", html_for='op_time_overlapped_fwd_bwd'),
                dbc.Input(id='op_time_overlapped_fwd_bwd', type='number',
                          placeholder="Auto-calculated if not specified", min=0.01, step=0.01),
                dbc.FormText("Time when forward and backward can be fully overlapped"),
                dbc.FormFeedback("Must be > 0", type="invalid"),
            ],
            className="mb-3",
        ),
        html.Div(
            [
                dbc.Label("Microbatch Group Size per VP Stage", html_for='microbatch_group_size_per_vp_stage'),
                # Plain string: the previous f-string had no placeholders.
                dbc.Input(id='microbatch_group_size_per_vp_stage', type='number',
                          placeholder="Defaults to number of devices", min=1, step=1),
                dbc.FormText("For interleave strategies only"),
                dbc.FormFeedback("Must be a positive integer", type="invalid"),
            ]
        ),
    ],
    id="advanced-timing-collapse",
    is_open=False,
)

timing_params_card = dbc.Card(
    [
        dbc.CardBody(
            [
                html.H5(
                    [
                        html.I(className="bi bi-clock section-icon"),
                        "Operation Timing Configuration",
                    ],
                    className="section-title",
                ),
                # Basic timing parameters
                dbc.Row(
                    [
                        dbc.Col([_p2p_latency_field], md=6),
                        dbc.Col([_forward_time_field], md=6),
                    ],
                    className="mb-3",
                ),
                # Backward timing
                _backward_time_field,
                # Advanced options
                html.Hr(className="my-4"),
                _advanced_timing_toggle,
                _advanced_timing_collapse,
            ]
        )
    ],
    style=card_style,
)
# Page layout: header, a two-column body (results on the left, controls on the
# right), a fixed-position toast area, and the footer.

# Shown in the results area until the first schedules are generated.
_welcome_message = dbc.Alert(
    [
        html.H4(
            [
                html.I(className="bi bi-lightbulb me-2"),
                "Welcome to Pipeline Parallelism Schedule Visualizer",
            ],
            className="alert-heading",
        ),
        html.Hr(),
        html.P(
            [
                "This tool helps you visualize and compare different pipeline parallelism scheduling strategies. ",
                "To get started:",
            ],
            className="mb-3",
        ),
        html.Ol(
            [
                html.Li("Configure your basic parameters (devices, stages, microbatches)"),
                html.Li("Select one or more scheduling strategies to compare"),
                html.Li("Set the operation timing parameters"),
                html.Li("Click 'Generate Schedules' to visualize the results"),
            ],
            className="mb-3",
        ),
        html.P(
            [
                html.Strong("Tip: "),
                "Hover over the ",
                html.I(className="bi bi-question-circle"),
                " icons for detailed explanations of each parameter.",
            ],
            className="mb-0",
        ),
    ],
    color="info",
    className="text-start",
    id="welcome-message",
)

# Left column: the container the main callback fills with graphs, wrapped in a
# loading spinner.
_visualization_column = dbc.Col(
    [
        dbc.Card(
            [
                dbc.CardBody(
                    [
                        html.H5(
                            [
                                html.I(className="bi bi-graph-up section-icon"),
                                "Visualization Results",
                            ],
                            className="section-title",
                        ),
                        dcc.Loading(
                            id="loading-graph-area",
                            type="circle",
                            children=html.Div(
                                [_welcome_message],
                                id='graph-output-container',
                                style={"minHeight": "400px"},
                            ),
                        ),
                    ]
                )
            ]
        )
    ],
    lg=8,
    md=7,
    className="mb-4",
)

_generate_button = dbc.Button(
    [
        html.I(className="bi bi-play-fill me-2"),
        "Generate Schedules",
    ],
    id='generate-button',
    color="primary",
    size="lg",
    className="w-100 mt-3",
    disabled=False,
)

# Right column: the three parameter cards followed by the generate button.
_controls_column = dbc.Col(
    [
        basic_params_card,
        scheduling_params_card,
        timing_params_card,
        _generate_button,
    ],
    lg=4,
    md=5,
)

# Pinned to the top-right corner; the main callback writes its toasts here.
_toast_container = html.Div(
    id="toast-container",
    style={
        "position": "fixed",
        "top": 20,
        "right": 20,
        "zIndex": 1050,
        "maxWidth": "400px",
    },
)

_footer_about = dbc.Col(
    [
        html.H6("About Pipeline Parallelism", className="text-muted mb-3"),
        html.P(
            [
                "Pipeline parallelism is a distributed training technique that splits a model across multiple devices. ",
                "This visualizer helps you understand different scheduling strategies and their performance characteristics.",
            ],
            className="small text-muted",
        ),
    ],
    md=4,
)

_footer_strategy_blurbs = (
    "1F1B: Standard pipeline with one forward, one backward",
    "ZB-1P: Zero bubble optimization",
    "Interleave: Virtual pipeline stages",
    "Overlap: Concurrent forward/backward",
)
_footer_strategies = dbc.Col(
    [
        html.H6("Scheduling Strategies", className="text-muted mb-3"),
        html.Ul(
            [html.Li(blurb, className="small text-muted") for blurb in _footer_strategy_blurbs],
            className="list-unstyled",
        ),
    ],
    md=4,
)

# All three links currently point at "#" (placeholders).
_footer_resources = dbc.Col(
    [
        html.H6("Resources", className="text-muted mb-3"),
        html.Div(
            [
                html.A(
                    [html.I(className="bi bi-github me-2"), "View on GitHub"],
                    href="#",
                    className="small text-muted d-block mb-2",
                ),
                html.A(
                    [html.I(className="bi bi-book me-2"), "Documentation"],
                    href="#",
                    className="small text-muted d-block mb-2",
                ),
                html.A(
                    [html.I(className="bi bi-question-circle me-2"), "Report an Issue"],
                    href="#",
                    className="small text-muted d-block",
                ),
            ]
        ),
    ],
    md=4,
)

_footer = html.Footer(
    [
        dbc.Container(
            [
                html.Hr(className="mt-5"),
                dbc.Row([_footer_about, _footer_strategies, _footer_resources]),
                html.Hr(),
                html.P(
                    [
                        "© 2024 Pipeline Parallelism Schedule Visualizer. ",
                        "Built with ",
                        html.I(className="bi bi-heart-fill text-danger"),
                        " using Dash and Plotly.",
                    ],
                    className="text-center text-muted small mb-3",
                ),
            ],
            fluid=True,
        )
    ],
    className="mt-5 bg-light py-4",
)

app.layout = html.Div(
    [
        header,
        dbc.Container(
            [dbc.Row([_visualization_column, _controls_column])],
            fluid=True,
        ),
        _toast_container,
        _footer,
        # Keep the existing store for backward compatibility
        dcc.Store(id='advanced-timing-switch', data=False),
    ]
)
# --- Callback for Input Validation and Generate Button State ---
def _is_invalid_count(value, required=True):
    """Return True when ``value`` is not an acceptable whole-number count.

    A missing value (None) is invalid only when ``required``. A present value
    must be at least 1 and have no fractional part, because the schedule
    builder converts counts with ``int()`` and would silently truncate.
    """
    if value is None:
        return required
    return value < 1 or value % 1 != 0


def _is_invalid_duration(value, required, allow_zero=False):
    """Return True when ``value`` is not an acceptable duration.

    A missing value (None) is invalid only when ``required``. A present value
    must be positive, or non-negative when ``allow_zero`` is set.
    """
    if value is None:
        return required
    return value < 0 if allow_zero else value <= 0


@app.callback(
    Output('generate-button', 'disabled'),
    # Outputs to control the 'invalid' state of Inputs
    Output('num_devices', 'invalid'),
    Output('num_stages', 'invalid'),
    Output('num_batches', 'invalid'),
    Output('p2p_latency', 'invalid'),
    Output('op_time_forward', 'invalid'),
    Output('op_time_backward', 'invalid'),
    Output('op_time_backward_d', 'invalid'),
    Output('op_time_backward_w', 'invalid'),
    Output('op_time_overlapped_fwd_bwd', 'invalid'),
    Output('microbatch_group_size_per_vp_stage', 'invalid'),
    # Each Input's `invalid` prop toggles its associated FormFeedback, so only
    # the strategy selection needs an explicit message output.
    Output('strategy-selection-feedback', 'children', allow_duplicate=True),
    # Inputs: Trigger validation whenever any relevant input changes
    Input('num_devices', 'value'),
    Input('num_stages', 'value'),
    Input('num_batches', 'value'),
    Input('p2p_latency', 'value'),
    Input('op_time_forward', 'value'),
    Input('op_time_backward', 'value'),
    Input('op_time_backward_d', 'value'),
    Input('op_time_backward_w', 'value'),
    Input('op_time_overlapped_fwd_bwd', 'value'),
    Input('microbatch_group_size_per_vp_stage', 'value'),
    Input('selected-strategies-store', 'data'),  # Validate strategy selection
    prevent_initial_call=True  # Prevent callback running on page load before user interaction
)
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
                    op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                    op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage, selected_strategies):
    """Validate the form and decide whether the generate button is enabled.

    Args:
        num_devices, num_stages, num_batches: Required whole-number counts >= 1.
        p2p_latency: Required, must be >= 0.
        op_time_forward: Required, must be > 0.
        op_time_backward: Must be > 0; required when any selected strategy uses
            a combined backward pass.
        op_time_backward_d, op_time_backward_w: Must be > 0; required when any
            selected strategy splits the backward pass ("zb1p", "dualpipe").
        op_time_overlapped_fwd_bwd: Optional, must be > 0 when given.
        microbatch_group_size_per_vp_stage: Optional whole number >= 1.
        selected_strategies: List of selected strategy keys (may be empty/None).

    Returns:
        A tuple matching the callback outputs: the button's ``disabled`` flag,
        one ``invalid`` flag per input (in output order), and the strategy
        feedback message (empty string when a strategy is selected).
    """
    split_backward_strategies = ("zb1p", "dualpipe")
    has_selection = bool(selected_strategies)

    # Which backward timings are mandatory depends on the selected strategies.
    needs_split_backward = has_selection and any(
        s in split_backward_strategies for s in selected_strategies
    )
    needs_combined_backward = has_selection and any(
        s not in split_backward_strategies for s in selected_strategies
    )

    is_invalid = {
        "num_devices": _is_invalid_count(num_devices),
        "num_stages": _is_invalid_count(num_stages),
        "num_batches": _is_invalid_count(num_batches),
        "p2p_latency": _is_invalid_duration(p2p_latency, required=True, allow_zero=True),
        "op_time_forward": _is_invalid_duration(op_time_forward, required=True),
        "op_time_backward": _is_invalid_duration(op_time_backward, required=needs_combined_backward),
        "op_time_backward_d": _is_invalid_duration(op_time_backward_d, required=needs_split_backward),
        "op_time_backward_w": _is_invalid_duration(op_time_backward_w, required=needs_split_backward),
        "op_time_overlapped_fwd_bwd": _is_invalid_duration(op_time_overlapped_fwd_bwd, required=False),
        "microbatch_group_size_per_vp_stage": _is_invalid_count(
            microbatch_group_size_per_vp_stage, required=False
        ),
    }

    strategy_feedback = "" if has_selection else "Please select at least one strategy."

    # Disable the button if any field fails or no strategy is selected.
    disable_button = (not has_selection) or any(is_invalid.values())

    return (
        disable_button,
        is_invalid["num_devices"],
        is_invalid["num_stages"],
        is_invalid["num_batches"],
        is_invalid["p2p_latency"],
        is_invalid["op_time_forward"],
        is_invalid["op_time_backward"],
        is_invalid["op_time_backward_d"],
        is_invalid["op_time_backward_w"],
        is_invalid["op_time_overlapped_fwd_bwd"],
        is_invalid["microbatch_group_size_per_vp_stage"],
        strategy_feedback,
    )
# --- Callback to toggle Advanced Options Collapse ---
@app.callback(
    Output("advanced-timing-collapse", "is_open"),
    Input("advanced-timing-toggle", "n_clicks"),
    State("advanced-timing-collapse", "is_open"),
    prevent_initial_call=True,
)
def toggle_advanced_options(n_clicks, is_open):
    """Return the new open state of the advanced-options collapse.

    The state is flipped when the toggle has a truthy click count; a falsy
    count (None or 0) leaves it unchanged.
    """
    return (not is_open) if n_clicks else is_open
# --- Client-side Callback for Strategy Card Selection ---
# Runs in the browser: clicking a strategy card toggles that strategy in the
# selection store and re-computes the CSS class of every card so the selected
# ones are highlighted.
#
# The list of cards is read from the callback context rather than hard-coded,
# so it always matches the cards actually rendered (and their order, which is
# the order Dash uses for the ALL-pattern output) when strategies are added or
# removed on the Python side.
app.clientside_callback(
    """
    function(n_clicks_list, current_strategies) {
        const ctx = dash_clientside.callback_context;
        if (!ctx.triggered || ctx.triggered.length === 0) {
            return [dash_clientside.no_update, dash_clientside.no_update];
        }

        // prop_id looks like '{"index":"1f1b","type":"strategy-card"}.n_clicks'
        const triggered = ctx.triggered[0];
        const clickedIndex = JSON.parse(triggered.prop_id.split('.')[0]).index;

        // Toggle the clicked strategy in a copy of the current selection.
        let newStrategies = current_strategies ? [...current_strategies] : [];
        if (newStrategies.includes(clickedIndex)) {
            newStrategies = newStrategies.filter(s => s !== clickedIndex);
        } else {
            newStrategies.push(clickedIndex);
        }

        // One class string per rendered card, in pattern-match order.
        const allStrategies = ctx.inputs_list[0].map(item => item.id.index);
        const cardClasses = allStrategies.map(strategy =>
            newStrategies.includes(strategy)
                ? 'strategy-card p-3 text-center selected'
                : 'strategy-card p-3 text-center'
        );

        return [newStrategies, cardClasses];
    }
    """,
    Output('selected-strategies-store', 'data'),
    Output({'type': 'strategy-card', 'index': ALL}, 'className'),
    Input({'type': 'strategy-card', 'index': ALL}, 'n_clicks'),
    State('selected-strategies-store', 'data'),
    prevent_initial_call=True
)
# --- Main Graph Update Callback ---
def _make_toast(message, header, icon, duration):
    """Return an open, auto-dismissing toast whose border colour matches ``icon``."""
    return dbc.Toast(
        message,
        header=header,
        icon=icon,
        duration=duration,
        is_open=True,
        className=f"border-{icon}",
    )


def _simulate_strategy(strategy, *, num_devices, num_stages, num_batches, p2p_latency,
                       op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                       op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
                       adjustments):
    """Build and execute the schedule for a single strategy.

    Any automatic parameter adjustment made for this strategy is described by
    a message appended to ``adjustments`` (mutated in place).

    Returns:
        ``(result, error_message)``. On success ``result`` is
        ``(vis_data, total_execution_time)`` and ``error_message`` is ``""``;
        on failure ``result`` is None and ``error_message`` explains why.
    """
    current_num_stages = num_stages

    # These strategies are run with exactly one stage per device.
    if strategy in ("1f1b", "1f1b_overlap", "zb1p", "dualpipe") and num_stages != num_devices:
        current_num_stages = num_devices
        adjustments.append(
            f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
        )

    split_backward = strategy in ("zb1p", "dualpipe")
    if split_backward and not all([op_time_backward_d, op_time_backward_w]):
        return None, f"Strategy '{strategy}': Backward D and Backward W times are required."
    if not split_backward and not op_time_backward:
        return None, f"Strategy '{strategy}': Combined Backward time is required."

    placement_strategy = ""
    if strategy in ("1f1b", "1f1b_overlap", "zb1p"):
        placement_strategy = "standard"
    elif strategy in ("1f1b_interleave", "1f1b_interleave_overlap"):
        placement_strategy = "interleave"
        if current_num_stages % num_devices != 0:
            return None, f"Strategy '{strategy}': Requires Stages divisible by Devices."
    elif strategy == "dualpipe":
        placement_strategy = "dualpipe"
        if current_num_stages % 2 != 0:
            return None, f"Strategy '{strategy}': Requires an even number of stages."

    try:
        # The entered op times are divided evenly among the stages that share
        # a device.
        stages_per_device = current_num_stages // num_devices
        time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
        if stages_per_device > 1:
            scaling_msg = (
                f"Strategy '{strategy}': Op times scaled by 1/{stages_per_device} "
                f"({stages_per_device} stages/device)."
            )
            if scaling_msg not in adjustments:
                adjustments.append(scaling_msg)

        op_times = {"forward": float(op_time_forward) * time_scale_factor}
        if split_backward:
            backward_d = float(op_time_backward_d)
            backward_w = float(op_time_backward_w)
            op_times["backward_D"] = backward_d * time_scale_factor
            op_times["backward_W"] = backward_w * time_scale_factor
            op_times["backward"] = (backward_d + backward_w) * time_scale_factor
        else:
            op_times["backward"] = float(op_time_backward) * time_scale_factor

        if op_time_overlapped_fwd_bwd is not None:
            # Best effort: an unusable or non-positive override is ignored and
            # the key is simply not passed on.
            try:
                overlapped_val = float(op_time_overlapped_fwd_bwd)
            except (ValueError, TypeError):
                overlapped_val = 0.0
            if overlapped_val > 0:
                op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor

        config = ScheduleConfig(
            num_devices=int(num_devices),
            num_stages=int(current_num_stages),
            num_batches=int(num_batches),
            p2p_latency=float(p2p_latency),
            placement_strategy=placement_strategy,
            split_backward=split_backward,
            op_times=op_times,
            microbatch_group_size_per_vp_stage=(
                int(microbatch_group_size_per_vp_stage)
                if microbatch_group_size_per_vp_stage is not None
                else None
            ),
        )

        schedule_func = STRATEGIES.get(strategy)
        if not schedule_func:
            raise ValueError(f"Invalid strategy function for: {strategy}")

        schedule = schedule_func(config)
        schedule.execute()
        vis_data = convert_schedule_to_visualization_format(schedule)
        total_time = schedule.get_total_execution_time()
    except (AssertionError, ValueError, TypeError) as e:
        return None, f"Error for '{strategy}': {e}"
    except Exception as e:
        # Callback boundary: one failing strategy must not prevent the others
        # from being rendered, so report it to the user instead of raising.
        return None, f"Unexpected error for '{strategy}': {e}"

    return (vis_data, total_time), ""


def _build_strategy_tabs(valid_results):
    """Return the tab bar and per-strategy graph panels.

    ``valid_results`` is a non-empty list of ``(strategy, vis_data, total_time)``.
    All figures share one x-axis range so strategies can be compared visually.
    """
    display_order = ("1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p")
    display_rank = {name: position for position, name in enumerate(display_order)}
    ordered_results = sorted(valid_results, key=lambda r: display_rank.get(r[0], float('inf')))

    max_execution_time = max(total_time for _, _, total_time in valid_results)
    margin = max_execution_time * 0.05

    tabs = []
    tab_panels = []
    for idx, (strategy, vis_data, _) in enumerate(ordered_results):
        strategy_info = STRATEGY_INFO[strategy]

        fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
        fig.update_layout(
            xaxis=dict(range=[0, max_execution_time + margin]),
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(family="Inter, sans-serif"),
            height=400,
            autosize=True,
            margin=dict(l=60, r=20, t=40, b=60),
        )

        tabs.append(
            dbc.Tab(
                label=strategy_info['name'],
                tab_id=f"tab-{strategy}",
                activeTabClassName="fw-bold",
            )
        )
        tab_panels.append(
            html.Div(
                [
                    dbc.Alert(
                        [
                            html.I(className=f"{strategy_info['icon']} me-2"),
                            strategy_info['description'],
                        ],
                        color="light",
                        className="mb-3",
                    ),
                    html.Div(
                        [
                            dcc.Graph(
                                figure=fig,
                                config={'displayModeBar': False, 'responsive': True},
                                className="dash-graph",
                            )
                        ],
                        className="graph-container",
                    ),
                ],
                # The clientside tab callback toggles panels by this id prefix.
                id=f"content-{strategy}",
                # Only the first panel is visible initially.
                style={"display": "block" if idx == 0 else "none"},
            )
        )

    return [
        dbc.Tabs(
            tabs,
            id="strategy-tabs",
            active_tab=f"tab-{ordered_results[0][0]}",
            className="mb-3",
        ),
        html.Div(tab_panels, id="tab-content-container"),
    ]


def _build_execution_summary(execution_times):
    """Return the performance summary for a non-empty list of ``(strategy, time)``.

    Shows metric cards for the three fastest strategies and a table of all
    strategies with their efficiency relative to the fastest one.
    """
    sorted_times = sorted(execution_times, key=lambda item: item[1])
    min_time = sorted_times[0][1]
    badge_colors = ("success", "primary", "secondary")

    metric_cards = []
    for rank, (strategy, elapsed) in enumerate(sorted_times[:3]):
        strategy_info = STRATEGY_INFO[strategy]
        metric_cards.append(
            dbc.Col(
                [
                    html.Div(
                        [
                            html.Div(
                                [
                                    html.I(className=f"{strategy_info['icon']} mb-2",
                                           style={"fontSize": "2rem", "color": "#667eea"}),
                                    html.H3(f"{elapsed:.2f} ms", className="metric-value"),
                                    html.P(strategy_info['name'], className="metric-label mb-0"),
                                    dbc.Badge(f"#{rank+1}", color=badge_colors[rank], className="mt-2"),
                                ],
                                className="metric-card",
                            )
                        ]
                    )
                ],
                lg=4,
                md=6,
                className="mb-3",
            )
        )

    table_rows = []
    for strategy, elapsed in sorted_times:
        strategy_info = STRATEGY_INFO[strategy]
        # A zero fastest time would make the ratio undefined; show 100% then.
        efficiency = (min_time / elapsed * 100) if min_time else 100
        if efficiency >= 95:
            bar_color = "success"
        elif efficiency >= 80:
            bar_color = "warning"
        else:
            bar_color = "danger"
        table_rows.append(
            html.Tr(
                [
                    html.Td(
                        [
                            html.I(className=f"{strategy_info['icon']} me-2"),
                            strategy_info['name'],
                        ]
                    ),
                    html.Td(f"{elapsed:.2f} ms"),
                    html.Td(
                        [
                            dbc.Progress(
                                value=efficiency,
                                color=bar_color,
                                className="mb-0",
                                style={"height": "10px"},
                            ),
                            html.Small(f"{efficiency:.1f}%", className="ms-2 text-muted"),
                        ]
                    ),
                ],
                className="align-middle",
            )
        )

    return html.Div(
        [
            html.H4(
                [
                    html.I(className="bi bi-speedometer2 section-icon"),
                    "Performance Summary",
                ],
                className="section-title mt-5",
            ),
            dbc.Row(metric_cards, className="mb-4"),
            dbc.Card(
                [
                    dbc.CardBody(
                        [
                            html.H5("Detailed Comparison", className="mb-3"),
                            dbc.Table(
                                [
                                    html.Thead(
                                        [
                                            html.Tr(
                                                [
                                                    html.Th("Strategy"),
                                                    html.Th("Execution Time"),
                                                    html.Th("Relative Efficiency", style={"width": "40%"}),
                                                ]
                                            )
                                        ]
                                    ),
                                    html.Tbody(table_rows),
                                ],
                                hover=True,
                                responsive=True,
                                className="mb-0",
                            ),
                        ]
                    )
                ]
            ),
        ],
        className="execution-summary",
    )


@app.callback(
    # Output graph container and toast container separately
    Output('graph-output-container', 'children'),
    Output('toast-container', 'children'),
    Input('generate-button', 'n_clicks'),
    State('num_devices', 'value'),
    State('num_stages', 'value'),
    State('num_batches', 'value'),
    State('p2p_latency', 'value'),
    State('op_time_forward', 'value'),
    State('op_time_backward', 'value'),
    State('op_time_backward_d', 'value'),
    State('op_time_backward_w', 'value'),
    State('op_time_overlapped_fwd_bwd', 'value'),
    State('microbatch_group_size_per_vp_stage', 'value'),
    State('selected-strategies-store', 'data'),
    prevent_initial_call=True
)
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                 op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
                 selected_strategies):
    """Generate schedules for the selected strategies and render the results.

    Args:
        n_clicks: Click count of the generate button (only used as trigger).
        num_devices, num_stages, num_batches, p2p_latency, op_time_forward,
        op_time_backward, op_time_backward_d, op_time_backward_w,
        op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage:
            Current raw values of the corresponding form inputs.
        selected_strategies: List of selected strategy keys.

    Returns:
        ``(graph_children, toasts)``: the children for the graph container
        (tabs, graph panels and a performance summary, or an empty list when no
        strategy produced a schedule) and the list of toasts reporting input
        errors, automatic adjustments and per-strategy failures.
    """
    if not selected_strategies:
        return [], [
            _make_toast(
                "Please select at least one scheduling strategy.",
                "Input Error", "warning", 4000,
            )
        ]
    # p2p_latency is checked against None separately because 0 is a valid value.
    if p2p_latency is None or not all([num_devices, num_stages, num_batches, op_time_forward]):
        return [], [
            _make_toast(
                "Missing required basic input values (Devices, Stages, Batches, P2P Latency, Forward Time).",
                "Input Error", "danger", 4000,
            )
        ]

    valid_results = []  # (strategy, vis_data, total_execution_time)
    error_messages = []  # (strategy, message)
    automatic_adjustments = []

    for strategy in selected_strategies:
        result, error_message = _simulate_strategy(
            strategy,
            num_devices=num_devices,
            num_stages=num_stages,
            num_batches=num_batches,
            p2p_latency=p2p_latency,
            op_time_forward=op_time_forward,
            op_time_backward=op_time_backward,
            op_time_backward_d=op_time_backward_d,
            op_time_backward_w=op_time_backward_w,
            op_time_overlapped_fwd_bwd=op_time_overlapped_fwd_bwd,
            microbatch_group_size_per_vp_stage=microbatch_group_size_per_vp_stage,
            adjustments=automatic_adjustments,
        )
        if result is None:
            error_messages.append((strategy, error_message))
        else:
            vis_data, total_time = result
            valid_results.append((strategy, vis_data, total_time))

    # --- Generate Toasts ---
    # Adjustments first, then errors; errors stay on screen longer.
    toast_components = [
        _make_toast(adjustment, "Parameter Adjustment", "info", 5000)
        for adjustment in automatic_adjustments
    ]
    toast_components.extend(
        _make_toast(msg, f"Error: {strategy}", "danger", 8000)
        for strategy, msg in error_messages
    )

    # --- Generate Graphs and Summary ---
    # When every strategy failed the graph area is cleared; the error toasts
    # (always present in that case) explain why.
    output_content = []
    if valid_results:
        execution_times = [(strategy, total_time) for strategy, _, total_time in valid_results]
        output_content = _build_strategy_tabs(valid_results)
        output_content.append(_build_execution_summary(execution_times))

    return output_content, toast_components
# --- Client-side Callback for Tab Switching ---
# Runs in the browser whenever the active strategy tab changes. The JavaScript
# derives the strategy key from the tab id ("tab-<strategy>"), then shows the
# element whose id is "content-<strategy>" and hides every other element whose
# id starts with "content-". Visibility is changed directly on the DOM nodes;
# the function always returns no_update, so the declared output is never
# actually written.
app.clientside_callback(
"""
function(activeTab) {
if (!activeTab) return window.dash_clientside.no_update;
// Extract strategy from tab id (format: "tab-strategy")
const activeStrategy = activeTab.replace("tab-", "");
// Get all tab content divs
const contentDivs = document.querySelectorAll('[id^="content-"]');
contentDivs.forEach(div => {
const strategy = div.id.replace("content-", "");
if (strategy === activeStrategy) {
div.style.display = "block";
} else {
div.style.display = "none";
}
});
return window.dash_clientside.no_update;
}
""",
Output("tab-content-container", "children"), # Dummy output
Input("strategy-tabs", "active_tab"),
# Skip the initial call; the first panel is already rendered visible.
prevent_initial_call=True
)
# For Hugging Face Spaces deployment: expose the underlying Flask server so a
# WSGI host can import it as `server`.
server = app.server

if __name__ == '__main__':
    # `app.run` replaces the deprecated `app.run_server`, which newer Dash
    # releases no longer provide. Listens on all interfaces, port 7860.
    app.run(debug=False, host='0.0.0.0', port=7860)