# Hugging Face Space (status at capture time: Running)
import dash | |
import dash_bootstrap_components as dbc | |
from dash import dcc, html, Input, Output, State, callback_context, ALL, ClientsideFunction | |
import plotly.graph_objects as go | |
from src.execution_model import ScheduleConfig, Schedule | |
from src.strategies import ( | |
generate_1f1b_schedule, | |
generate_zero_bubble_1p_schedule, | |
generate_1f1b_overlap_schedule, | |
generate_1f1b_interleave_schedule, | |
generate_1f1b_interleave_overlap_schedule, | |
generate_dualpipe_schedule | |
) | |
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure | |
# Registry mapping each strategy key to the function that builds its schedule
# from a ScheduleConfig (see update_graph, which calls the looked-up function
# with the config).
# Insertion order matters: the strategy cards are laid out by iterating this
# dict, and the client-side selection callback hard-codes the same key order
# when it writes the cards' className values.
STRATEGIES = {
    "1f1b": generate_1f1b_schedule,
    "zb1p": generate_zero_bubble_1p_schedule,
    "1f1b_overlap": generate_1f1b_overlap_schedule,
    "1f1b_interleave": generate_1f1b_interleave_schedule,
    "1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
    "dualpipe": generate_dualpipe_schedule,
}
# Display metadata for every strategy key: the short name shown on cards and
# tabs, a one-line description, and the Bootstrap icon class.
# Each row is (key, name, description, icon); the dict below is built from it.
_STRATEGY_INFO_FIELDS = ("name", "description", "icon")
_STRATEGY_INFO_ROWS = (
    (
        "1f1b",
        "1F1B",
        "One Forward One Backward - Standard pipeline parallelism",
        "bi-arrow-left-right",
    ),
    (
        "zb1p",
        "ZB-1P",
        "Zero Bubble 1-stage Pipeline - Minimizes pipeline bubbles",
        "bi-circle",
    ),
    (
        "1f1b_overlap",
        "1F1B Overlap",
        "1F1B with overlapped forward and backward passes",
        "bi-layers",
    ),
    (
        "1f1b_interleave",
        "1F1B Interleave",
        "Interleaved pipeline stages for better efficiency",
        "bi-shuffle",
    ),
    (
        "1f1b_interleave_overlap",
        "1F1B Interleave + Overlap",
        "Combines interleaving with overlapped execution",
        "bi-layers-fill",
    ),
    (
        "dualpipe",
        "DualPipe",
        "Dual pipeline execution for enhanced parallelism",
        "bi-diagram-2",
    ),
)
STRATEGY_INFO = {
    key: dict(zip(_STRATEGY_INFO_FIELDS, values))
    for key, *values in _STRATEGY_INFO_ROWS
}
# Dash application instance. Loads the Bootstrap theme and Bootstrap icon set
# shipped with dash-bootstrap-components, plus the Inter web font.
# suppress_callback_exceptions is enabled because some callbacks in this file
# target components ("strategy-tabs", "tab-content-container") that are only
# created at runtime by update_graph and are absent from the initial layout.
app = dash.Dash(
    __name__,
    external_stylesheets=[
        dbc.themes.BOOTSTRAP,
        dbc.icons.BOOTSTRAP,
        "https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap"
    ],
    suppress_callback_exceptions=True
)
# Title assigned to the Dash app (app.title).
app.title = "Pipeline Parallelism Schedule Visualizer"
# Initial default values used to pre-fill the form inputs and the
# selected-strategies store.
# "strategy" is a list because several strategies can be selected at once.
# The two None entries have no matching `value=` in the layout: those inputs
# start empty and show a placeholder instead.
default_values = {
    "num_devices": 4,
    "num_stages": 8,
    "num_batches": 16,
    "p2p_latency": 0.0,
    "op_time_forward": 1.0,
    "op_time_backward_d": 1.0,
    "op_time_backward_w": 1.0,
    "op_time_backward": 2.0,
    "strategy": ["1f1b_interleave"],
    "op_time_overlapped_fwd_bwd": None,
    "microbatch_group_size_per_vp_stage": None,
}
# Shared inline style applied to every parameter card in the controls column.
card_style = {"marginBottom": "20px"}

# Page header: an icon + title line with a one-sentence subtitle underneath.
_header_title = html.H1(
    children=[
        html.I(className="bi bi-diagram-3 me-3"),
        "Pipeline Parallelism Schedule Visualizer",
    ],
    className="text-center mb-0",
)
_header_subtitle = html.P(
    "Visualize and compare different pipeline parallelism scheduling strategies",
    className="text-center mt-2 mb-0 lead",
)
header = html.Div(
    children=[
        html.Div(
            children=[_header_title, _header_subtitle],
            className="container",
        ),
    ],
    className="main-header",
)
# Basic parameters card: devices, stages and microbatch counts.
def _count_input_col(label, input_id, icon, width):
    """Build one column holding a labelled positive-integer input.

    `input_id` doubles as the key into `default_values` for the initial value.
    """
    return dbc.Col(
        [
            dbc.Label(label, html_for=input_id, className="form-label"),
            dbc.InputGroup([
                dbc.InputGroupText(html.I(className=icon)),
                dbc.Input(
                    id=input_id,
                    type='number',
                    value=default_values[input_id],
                    min=1,
                    step=1,
                    required=True,
                ),
            ]),
            dbc.FormFeedback("Please provide a positive integer.", type="invalid"),
        ],
        md=width,
    )


_basic_params_title = html.H5(
    [
        html.I(className="bi bi-sliders section-icon"),
        "Basic Configuration",
    ],
    className="section-title",
)
# First row: devices and stages side by side; second row: microbatches full width.
_basic_params_top_row = dbc.Row(
    [
        _count_input_col("Number of Devices (GPUs)", 'num_devices', "bi bi-gpu-card", 6),
        _count_input_col("Number of Stages", 'num_stages', "bi bi-stack", 6),
    ],
    className="mb-3",
)
_basic_params_bottom_row = dbc.Row([
    _count_input_col("Number of Microbatches", 'num_batches', "bi bi-collection", 12),
])
basic_params_card = dbc.Card(
    [
        dbc.CardBody([
            _basic_params_title,
            _basic_params_top_row,
            _basic_params_bottom_row,
        ]),
    ],
    style=card_style,
)
# Scheduling strategy selection: one clickable card per registered strategy.
def _strategy_card_col(strategy):
    """Build the grid column containing the clickable card for `strategy`."""
    info = STRATEGY_INFO[strategy]
    # Strategies that are selected by default start with the "selected" class.
    selected_suffix = "selected" if strategy in default_values["strategy"] else ""
    card = html.Div(
        [
            html.I(className=f"{info['icon']} mb-2"),
            html.H6(info['name'], className="mb-1"),
            html.Small(info['description'], className="text-muted"),
        ],
        # Pattern-matching id so a single callback can address every card.
        id={"type": "strategy-card", "index": strategy},
        className=f"strategy-card p-3 text-center {selected_suffix}",
    )
    return dbc.Col([html.Div([card], className="mb-3")], lg=4, md=6)


scheduling_params_card = dbc.Card(
    [
        dbc.CardBody([
            html.H5(
                [
                    html.I(className="bi bi-diagram-2 section-icon"),
                    "Scheduling Strategy",
                ],
                className="section-title",
            ),
            html.P("Select one or more strategies to compare:", className="text-muted mb-3"),
            # Cards appear in the insertion order of STRATEGIES.
            dbc.Row(
                [_strategy_card_col(strategy) for strategy in STRATEGIES],
                className="g-3",
            ),
            # Holds the list of currently selected strategy keys.
            dcc.Store(id='selected-strategies-store', data=default_values["strategy"]),
            html.Div(id='strategy-selection-feedback', className='invalid-feedback d-block mt-2'),
        ]),
    ],
    style=card_style,
)
# Timing parameters card: P2P latency, forward/backward times and a collapsible
# group of advanced options.
def _help_label(text, badge_id, input_id):
    """Build a form label followed by a small "?" badge that anchors a tooltip."""
    badge = dbc.Badge(
        "?",
        pill=True,
        color="secondary",
        className="ms-1",
        id=badge_id,
        style={"cursor": "pointer", "fontSize": "0.75rem"},
    )
    return dbc.Label([text, badge], html_for=input_id, className="form-label")


def _help_tooltip(text, badge_id):
    """Build the tooltip shown next to the "?" badge with id `badge_id`."""
    return dbc.Tooltip(text, target=badge_id, placement="right")


_p2p_field = html.Div([
    _help_label("P2P Latency (ms)", "tooltip-p2p", 'p2p_latency'),
    dbc.Input(
        id='p2p_latency',
        type='number',
        value=default_values["p2p_latency"],
        min=0,
        step=0.01,
        required=True,
    ),
    dbc.FormFeedback("Must be ≥ 0", type="invalid"),
    _help_tooltip(
        "Time for point-to-point communication between adjacent devices.",
        "tooltip-p2p",
    ),
])

_forward_field = html.Div([
    _help_label("Forward Pass Time (ms)", "tooltip-fwd", 'op_time_forward'),
    dbc.Input(
        id='op_time_forward',
        type='number',
        value=default_values["op_time_forward"],
        min=0.01,
        step=0.01,
        required=True,
    ),
    dbc.FormFeedback("Must be > 0", type="invalid"),
    _help_tooltip(
        "Time for a forward pass of one microbatch through one stage.",
        "tooltip-fwd",
    ),
])

# Combined backward time; unlike the two fields above it is not marked required.
_backward_field = html.Div(
    [
        _help_label("Backward Pass Time (ms)", "tooltip-bwd", 'op_time_backward'),
        dbc.Input(
            id='op_time_backward',
            type='number',
            value=default_values["op_time_backward"],
            min=0.01,
            step=0.01,
        ),
        dbc.FormText("Combined backward pass time (data + weight gradients)", className="mt-1"),
        dbc.FormFeedback("Must be > 0", type="invalid"),
        _help_tooltip(
            "Time for combined backward pass (data + weight gradients).",
            "tooltip-bwd",
        ),
    ],
    className="mb-3",
)

_advanced_toggle_button = dbc.Button(
    [
        html.I(className="bi bi-gear-fill me-2"),
        "Advanced Timing Options",
    ],
    id="advanced-timing-toggle",
    color="light",
    className="mb-3",
    size="sm",
)

_split_backward_row = dbc.Row(
    [
        dbc.Col(
            [
                dbc.Label("Backward D (Data Gradient)", html_for='op_time_backward_d'),
                dbc.Input(
                    id='op_time_backward_d',
                    type='number',
                    value=default_values["op_time_backward_d"],
                    min=0.01,
                    step=0.01,
                ),
                dbc.FormText("For strategies with split backward"),
                dbc.FormFeedback("Must be > 0", type="invalid"),
            ],
            md=6,
        ),
        dbc.Col(
            [
                dbc.Label("Backward W (Weight Gradient)", html_for='op_time_backward_w'),
                dbc.Input(
                    id='op_time_backward_w',
                    type='number',
                    value=default_values["op_time_backward_w"],
                    min=0.01,
                    step=0.01,
                ),
                dbc.FormText("For strategies with split backward"),
                dbc.FormFeedback("Must be > 0", type="invalid"),
            ],
            md=6,
        ),
    ],
    className="mb-3",
)

# These two inputs start empty (no value=) and only show a placeholder.
_overlapped_field = html.Div(
    [
        dbc.Label("Overlapped Forward+Backward Time", html_for='op_time_overlapped_fwd_bwd'),
        dbc.Input(
            id='op_time_overlapped_fwd_bwd',
            type='number',
            placeholder="Auto-calculated if not specified",
            min=0.01,
            step=0.01,
        ),
        dbc.FormText("Time when forward and backward can be fully overlapped"),
        dbc.FormFeedback("Must be > 0", type="invalid"),
    ],
    className="mb-3",
)
_microbatch_group_field = html.Div([
    dbc.Label("Microbatch Group Size per VP Stage", html_for='microbatch_group_size_per_vp_stage'),
    dbc.Input(
        id='microbatch_group_size_per_vp_stage',
        type='number',
        placeholder="Defaults to number of devices",
        min=1,
        step=1,
    ),
    dbc.FormText("For interleave strategies only"),
    dbc.FormFeedback("Must be a positive integer", type="invalid"),
])

# Collapsed by default; its id pairs with the "advanced-timing-toggle" button.
_advanced_collapse = dbc.Collapse(
    [
        dbc.Alert(
            [
                html.I(className="bi bi-info-circle-fill me-2"),
                "These options are for advanced users and specific scheduling strategies.",
            ],
            color="info",
            className="mb-3",
        ),
        _split_backward_row,
        _overlapped_field,
        _microbatch_group_field,
    ],
    id="advanced-timing-collapse",
    is_open=False,
)

timing_params_card = dbc.Card(
    [
        dbc.CardBody([
            html.H5(
                [
                    html.I(className="bi bi-clock section-icon"),
                    "Operation Timing Configuration",
                ],
                className="section-title",
            ),
            # Basic timing parameters
            dbc.Row(
                [
                    dbc.Col([_p2p_field], md=6),
                    dbc.Col([_forward_field], md=6),
                ],
                className="mb-3",
            ),
            # Backward timing
            _backward_field,
            # Advanced options
            html.Hr(className="my-4"),
            _advanced_toggle_button,
            _advanced_collapse,
        ]),
    ],
    style=card_style,
)
# Application layout, assembled from named sections: header, a two-column body
# (visualization on the left, controls on the right), a fixed toast container,
# the footer, and a trailing store kept for backward compatibility.

# Welcome message shown in the graph area before anything has been generated.
_welcome_steps = (
    "Configure your basic parameters (devices, stages, microbatches)",
    "Select one or more scheduling strategies to compare",
    "Set the operation timing parameters",
    "Click 'Generate Schedules' to visualize the results",
)
_welcome_alert = dbc.Alert(
    [
        html.H4(
            [
                html.I(className="bi bi-lightbulb me-2"),
                "Welcome to Pipeline Parallelism Schedule Visualizer",
            ],
            className="alert-heading",
        ),
        html.Hr(),
        html.P(
            [
                "This tool helps you visualize and compare different pipeline parallelism scheduling strategies. ",
                "To get started:",
            ],
            className="mb-3",
        ),
        html.Ol([html.Li(step) for step in _welcome_steps], className="mb-3"),
        html.P(
            [
                html.Strong("Tip: "),
                "Hover over the ",
                html.I(className="bi bi-question-circle"),
                " icons for detailed explanations of each parameter.",
            ],
            className="mb-0",
        ),
    ],
    color="info",
    className="text-start",
    id="welcome-message",
)

# Left column - visualization area wrapped in a loading spinner.
_visualization_column = dbc.Col(
    [
        dbc.Card([
            dbc.CardBody([
                html.H5(
                    [
                        html.I(className="bi bi-graph-up section-icon"),
                        "Visualization Results",
                    ],
                    className="section-title",
                ),
                dcc.Loading(
                    id="loading-graph-area",
                    type="circle",
                    children=html.Div(
                        [_welcome_alert],
                        id='graph-output-container',
                        style={"minHeight": "400px"},
                    ),
                ),
            ]),
        ]),
    ],
    lg=8,
    md=7,
    className="mb-4",
)

# Right column - parameter cards followed by the generate button.
_generate_button = dbc.Button(
    [
        html.I(className="bi bi-play-fill me-2"),
        "Generate Schedules",
    ],
    id='generate-button',
    color="primary",
    size="lg",
    className="w-100 mt-3",
    disabled=False,
)
_controls_column = dbc.Col(
    [
        basic_params_card,
        scheduling_params_card,
        timing_params_card,
        _generate_button,
    ],
    lg=4,
    md=5,
)

# Toast container pinned to the top-right corner of the viewport.
_toast_container = html.Div(
    id="toast-container",
    style={
        "position": "fixed",
        "top": 20,
        "right": 20,
        "zIndex": 1050,
        "maxWidth": "400px",
    },
)

# Footer: three informational columns and a credits line.
_footer_strategy_notes = (
    "1F1B: Standard pipeline with one forward, one backward",
    "ZB-1P: Zero bubble optimization",
    "Interleave: Virtual pipeline stages",
    "Overlap: Concurrent forward/backward",
)
_footer_about_col = dbc.Col(
    [
        html.H6("About Pipeline Parallelism", className="text-muted mb-3"),
        html.P(
            [
                "Pipeline parallelism is a distributed training technique that splits a model across multiple devices. ",
                "This visualizer helps you understand different scheduling strategies and their performance characteristics.",
            ],
            className="small text-muted",
        ),
    ],
    md=4,
)
_footer_strategies_col = dbc.Col(
    [
        html.H6("Scheduling Strategies", className="text-muted mb-3"),
        html.Ul(
            [html.Li(note, className="small text-muted") for note in _footer_strategy_notes],
            className="list-unstyled",
        ),
    ],
    md=4,
)
_footer_resources_col = dbc.Col(
    [
        html.H6("Resources", className="text-muted mb-3"),
        html.Div([
            html.A(
                [html.I(className="bi bi-github me-2"), "View on GitHub"],
                href="#",
                className="small text-muted d-block mb-2",
            ),
            html.A(
                [html.I(className="bi bi-book me-2"), "Documentation"],
                href="#",
                className="small text-muted d-block mb-2",
            ),
            html.A(
                [html.I(className="bi bi-question-circle me-2"), "Report an Issue"],
                href="#",
                className="small text-muted d-block",
            ),
        ]),
    ],
    md=4,
)
_footer = html.Footer(
    [
        dbc.Container(
            [
                html.Hr(className="mt-5"),
                dbc.Row([
                    _footer_about_col,
                    _footer_strategies_col,
                    _footer_resources_col,
                ]),
                html.Hr(),
                html.P(
                    [
                        "© 2024 Pipeline Parallelism Schedule Visualizer. ",
                        "Built with ",
                        html.I(className="bi bi-heart-fill text-danger"),
                        " using Dash and Plotly.",
                    ],
                    className="text-center text-muted small mb-3",
                ),
            ],
            fluid=True,
        ),
    ],
    className="mt-5 bg-light py-4",
)

app.layout = html.Div([
    header,
    dbc.Container(
        [dbc.Row([_visualization_column, _controls_column])],
        fluid=True,
    ),
    _toast_container,
    _footer,
    # Keep the existing store for backward compatibility
    dcc.Store(id='advanced-timing-switch', data=False),
])
# --- Callback for Input Validation and Generate Button State ---
# NOTE(review): no @app.callback decorator is visible above this function in
# this copy of the file - confirm how it is registered.
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
                    op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                    op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage, selected_strategies):
    """Validate the form values and decide whether generation is allowed.

    Rules:
      * devices, stages and microbatches must be whole numbers >= 1
        (fractional values are rejected, matching the "positive integer"
        feedback text; they would otherwise be truncated later by ``int()``);
      * P2P latency must be present and >= 0;
      * forward time must be present and > 0;
      * the combined backward time is required only when a non-split strategy
        is selected, and the D/W backward times only when "zb1p" or
        "dualpipe" is selected; whenever provided they must be > 0;
      * the overlapped time is optional but must be > 0 when provided;
      * the microbatch group size is optional but must be a whole number >= 1;
      * at least one strategy must be selected.

    Args:
        selected_strategies: list of strategy keys; ``None`` is treated like
            an empty list.

    Returns:
        A 12-tuple: the "disable generate button" flag, then the invalid
        flags for num_devices, num_stages, num_batches, p2p_latency,
        op_time_forward, op_time_backward, op_time_backward_d,
        op_time_backward_w, op_time_overlapped_fwd_bwd and
        microbatch_group_size_per_vp_stage (in that order), and finally the
        strategy feedback text ("" when a strategy is selected).
    """
    split_backward_strategies = ("zb1p", "dualpipe")
    # A missing selection must not crash the generator expressions below.
    strategies = selected_strategies or []

    def _bad_count(value):
        """Return True unless `value` is a whole number >= 1."""
        return value is None or value < 1 or value % 1 != 0

    def _bad_positive(value, required):
        """Return True when `value` is <= 0, or is missing while required."""
        if value is None:
            return required
        return value <= 0

    # Which backward timings are mandatory depends on the selected strategies.
    needs_split_backward = any(s in split_backward_strategies for s in strategies)
    needs_combined_backward = any(s not in split_backward_strategies for s in strategies)

    is_invalid = {
        "num_devices": _bad_count(num_devices),
        "num_stages": _bad_count(num_stages),
        "num_batches": _bad_count(num_batches),
        "p2p_latency": p2p_latency is None or p2p_latency < 0,
        "op_time_forward": _bad_positive(op_time_forward, True),
        "op_time_backward": _bad_positive(op_time_backward, needs_combined_backward),
        "op_time_backward_d": _bad_positive(op_time_backward_d, needs_split_backward),
        "op_time_backward_w": _bad_positive(op_time_backward_w, needs_split_backward),
        "op_time_overlapped_fwd_bwd": _bad_positive(op_time_overlapped_fwd_bwd, False),
        "microbatch_group_size_per_vp_stage": (
            microbatch_group_size_per_vp_stage is not None
            and _bad_count(microbatch_group_size_per_vp_stage)
        ),
        "strategies": not strategies,
    }
    strategy_feedback = "" if strategies else "Please select at least one strategy."

    # Disable the button if any validation fails.
    disable_button = any(is_invalid.values())

    return (
        disable_button,
        is_invalid["num_devices"],
        is_invalid["num_stages"],
        is_invalid["num_batches"],
        is_invalid["p2p_latency"],
        is_invalid["op_time_forward"],
        is_invalid["op_time_backward"],
        is_invalid["op_time_backward_d"],
        is_invalid["op_time_backward_w"],
        is_invalid["op_time_overlapped_fwd_bwd"],
        is_invalid["microbatch_group_size_per_vp_stage"],
        strategy_feedback,
    )
# --- Callback to toggle Advanced Options Collapse ---
# NOTE(review): no @app.callback decorator is visible above this function in
# this copy of the file - confirm how it is registered.
def toggle_advanced_options(n_clicks, is_open):
    """Return the next open state of the advanced-options collapse.

    The state is flipped when `n_clicks` is truthy (at least one click);
    otherwise `is_open` is returned unchanged.
    """
    return (not is_open) if n_clicks else is_open
# --- Client-side Callback for Strategy Card Selection ---
# Runs in the browser whenever any strategy card's n_clicks changes. The
# clicked card's key is read from the triggering pattern-matching id and
# toggled in the stored selection list; the callback then returns the new list
# together with one className per card ("selected" appended for chosen ones).
# NOTE: the `allStrategies` array inside the JavaScript is a hard-coded copy of
# the STRATEGIES key order. The className list is written to the ALL-matched
# cards, so it must stay in sync with the order the cards are laid out in.
# prevent_initial_call=True keeps the default selection untouched on page load.
app.clientside_callback(
    """
    function(n_clicks_list, current_strategies) {
        const ctx = dash_clientside.callback_context;
        if (!ctx.triggered || ctx.triggered.length === 0) {
            return [dash_clientside.no_update, dash_clientside.no_update];
        }
        const triggered = ctx.triggered[0];
        const clickedIndex = JSON.parse(triggered.prop_id.split('.')[0]).index;
        let newStrategies = current_strategies ? [...current_strategies] : [];
        if (newStrategies.includes(clickedIndex)) {
            newStrategies = newStrategies.filter(s => s !== clickedIndex);
        } else {
            newStrategies.push(clickedIndex);
        }
        // Update card classes
        const allStrategies = ['1f1b', 'zb1p', '1f1b_overlap', '1f1b_interleave', '1f1b_interleave_overlap', 'dualpipe'];
        const cardClasses = allStrategies.map(strategy =>
            newStrategies.includes(strategy)
                ? 'strategy-card p-3 text-center selected'
                : 'strategy-card p-3 text-center'
        );
        return [newStrategies, cardClasses];
    }
    """,
    Output('selected-strategies-store', 'data'),
    Output({'type': 'strategy-card', 'index': ALL}, 'className'),
    Input({'type': 'strategy-card', 'index': ALL}, 'n_clicks'),
    State('selected-strategies-store', 'data'),
    prevent_initial_call=True
)
# --- Main Graph Update Callback ---
# NOTE(review): no @app.callback decorator is visible above this function in
# this copy of the file - confirm how it is registered and which components
# the two returned values are bound to (presumably 'graph-output-container'
# and 'toast-container').
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                 op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
                 selected_strategies):
    """Build schedules for the selected strategies and render the results.

    For every selected strategy the function adjusts parameters where needed,
    builds a ScheduleConfig, runs the strategy's generator and executes the
    schedule. A failure in one strategy is reported as an error toast and does
    not prevent the remaining strategies from being rendered.

    Args:
        n_clicks: click counter of the generate button; not read in the body.
        num_devices, num_stages, num_batches: basic pipeline dimensions.
        p2p_latency, op_time_*: timing inputs; values may be None when the
            corresponding field is empty.
        microbatch_group_size_per_vp_stage: optional; passed through as int.
        selected_strategies: list of strategy keys from STRATEGIES.

    Returns:
        A pair ``(output_content, toast_components)``: the components for the
        graph area (tabs with one figure per successful strategy plus a
        performance summary) and a list of dbc.Toast notifications. When the
        initial validation fails, ``([], [toast])`` is returned instead.
    """
    # Fixed left-to-right order of the result tabs, independent of click order.
    strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
    graph_components = []
    toast_components = []
    valid_results = []
    error_messages = []
    automatic_adjustments = []
    execution_times = []  # (strategy, total execution time) per successful strategy
    # Use a variable to track if initial validation fails
    initial_validation_error = None
    if not selected_strategies:
        initial_validation_error = dbc.Toast(
            "Please select at least one scheduling strategy.",
            header="Input Error",
            icon="warning",
            duration=4000,
            is_open=True,
            className="border-warning"
        )
    # Truthiness check: a value of None or 0 counts as missing here.
    elif not all([num_devices, num_stages, num_batches, op_time_forward]):
        initial_validation_error = dbc.Toast(
            "Missing required basic input values (Devices, Stages, Batches, Forward Time).",
            header="Input Error",
            icon="danger",
            duration=4000,
            is_open=True,
            className="border-danger"
        )
    if initial_validation_error:
        # Return empty graph list and the validation error toast
        return [], [initial_validation_error]
    for strategy in selected_strategies:
        error_message = ""
        placement_strategy = ""
        # Use local copies of params that might be adjusted for this strategy
        current_num_stages = num_stages
        current_num_devices = num_devices
        # Apply automatic adjustments for dualpipe
        if strategy == "dualpipe" and num_stages != num_devices:
            current_num_stages = num_devices
            adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
            automatic_adjustments.append(adjustment_msg)
        # Apply automatic adjustments for strategies that require num_stages == num_devices
        if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
            current_num_stages = num_devices
            adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
            automatic_adjustments.append(adjustment_msg)
        # Strategies that take separate data-gradient / weight-gradient times.
        split_backward = strategy in ["zb1p", "dualpipe"]
        if split_backward and not all([op_time_backward_d, op_time_backward_w]):
            error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
        elif not split_backward and not op_time_backward:
            error_message = f"Strategy '{strategy}': Combined Backward time is required."
        if not error_message:
            # Choose the placement strategy and check its stage-count constraint.
            if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
                placement_strategy = "standard"
            elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
                placement_strategy = "interleave"
                if current_num_stages % current_num_devices != 0:
                    error_message = f"Strategy '{strategy}': Requires Stages divisible by Devices."
            elif strategy == "dualpipe":
                placement_strategy = "dualpipe"
                if current_num_stages % 2 != 0:
                    error_message = f"Strategy '{strategy}': Requires an even number of stages."
        # Create adjusted operation times based on placement strategy
        if not error_message:
            try:
                # Each op time entered in the form is divided by the number of
                # stages hosted on one device.
                stages_per_device = current_num_stages // current_num_devices
                time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
                if stages_per_device > 1:
                    adjustment_msg = f"Strategy '{strategy}': Op times scaled by 1/{stages_per_device} ({stages_per_device} stages/device)."
                    # Avoid adding duplicate adjustment messages if already added above
                    if adjustment_msg not in automatic_adjustments:
                        automatic_adjustments.append(adjustment_msg)
                op_times = { "forward": float(op_time_forward) * time_scale_factor }
                if split_backward:
                    op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
                    op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
                    # Combined backward is the sum of the two split parts.
                    op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
                else:
                    op_times["backward"] = float(op_time_backward) * time_scale_factor
                if op_time_overlapped_fwd_bwd is not None:
                    # Best effort: an unusable or non-positive value is ignored,
                    # so no "overlapped_forward_backward" entry is set.
                    try:
                        overlapped_val = float(op_time_overlapped_fwd_bwd)
                        if overlapped_val > 0:
                            op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
                    except (ValueError, TypeError):
                        pass
                config = ScheduleConfig(
                    num_devices=int(current_num_devices),
                    num_stages=int(current_num_stages),
                    num_batches=int(num_batches),
                    p2p_latency=float(p2p_latency),
                    placement_strategy=placement_strategy,
                    split_backward=split_backward,
                    op_times=op_times,
                    microbatch_group_size_per_vp_stage=int(microbatch_group_size_per_vp_stage) if microbatch_group_size_per_vp_stage is not None else None,
                )
                schedule_func = STRATEGIES.get(strategy)
                if not schedule_func:
                    raise ValueError(f"Invalid strategy function for: {strategy}")
                schedule = schedule_func(config)
                schedule.execute()
                vis_data = convert_schedule_to_visualization_format(schedule)
                valid_results.append((strategy, schedule, vis_data))
                # Store execution time
                execution_times.append((strategy, schedule.get_total_execution_time()))
            except (AssertionError, ValueError, TypeError) as e:
                error_message = f"Error for '{strategy}': {e}"
            # Any other failure is also turned into an error toast so the
            # remaining strategies can still be processed.
            except Exception as e:
                error_message = f"Unexpected error for '{strategy}': {e}"
        if error_message:
            error_messages.append((strategy, error_message))
    # --- Generate Toasts ---
    # Add toasts for automatic adjustments
    for adjustment in automatic_adjustments:
        toast_components.append(
            dbc.Toast(
                adjustment,
                header="Parameter Adjustment",
                icon="info",
                duration=5000,  # Slightly longer duration for info
                is_open=True,
                className="border-info"
            )
        )
    # Add toasts for errors
    for strategy, msg in error_messages:
        toast_components.append(
            dbc.Toast(
                msg,
                header=f"Error: {strategy}",
                icon="danger",
                duration=8000,  # Longer duration for errors
                is_open=True,
                className="border-danger"
            )
        )
    # --- Generate Graphs with improved layout ---
    if valid_results:
        # All figures share one x-axis range so the strategies are visually comparable.
        max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
        sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
        # Create tabs for multiple strategies
        tabs = []
        tab_panels = []
        for idx, (strategy, _, vis_data) in enumerate(sorted_valid_results):
            strategy_info = STRATEGY_INFO[strategy]
            # Create figure
            fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
            # Leave 5% of headroom to the right of the longest schedule.
            margin = max_execution_time * 0.05
            fig.update_layout(
                xaxis=dict(range=[0, max_execution_time + margin]),
                paper_bgcolor='rgba(0,0,0,0)',
                plot_bgcolor='rgba(0,0,0,0)',
                font=dict(family="Inter, sans-serif"),
                height=400,  # Set explicit height
                autosize=True,  # Enable autosize
                margin=dict(l=60, r=20, t=40, b=60),  # Set proper margins
            )
            # Create tab
            tab_label = strategy_info['name']
            tab_value = f"tab-{strategy}"
            tabs.append(
                dbc.Tab(
                    label=tab_label,
                    tab_id=tab_value,
                    activeTabClassName="fw-bold"
                )
            )
            # Create tab panel content. The "content-<strategy>" id is what the
            # client-side tab-switch callback looks up to show or hide panels.
            tab_panels.append(
                html.Div([
                    dbc.Alert([
                        html.I(className=f"{strategy_info['icon']} me-2"),
                        strategy_info['description']
                    ], color="light", className="mb-3"),
                    html.Div([
                        dcc.Graph(
                            figure=fig,
                            config={
                                'displayModeBar': False,
                                'responsive': True  # Make graph responsive
                            },
                            className="dash-graph"
                        )
                    ], className="graph-container")
                ],
                    id=f"content-{strategy}",
                    style={"display": "block" if idx == 0 else "none"}  # Show first tab by default
                )
            )
        # Create tabbed interface with callback to switch content
        graph_components = [
            dbc.Tabs(
                tabs,
                id="strategy-tabs",
                active_tab=f"tab-{sorted_valid_results[0][0]}",
                className="mb-3"
            ),
            html.Div(tab_panels, id="tab-content-container")
        ]
    # If there are graphs, use the component list, otherwise show a message
    output_content = []
    if graph_components:
        output_content = graph_components
    elif not toast_components:
        output_content = dbc.Alert([
            html.I(className="bi bi-info-circle-fill me-2"),
            "Click 'Generate Schedules' to visualize pipeline parallelism strategies."
        ], color="info", className="text-center")
    # Add execution time summary with improved design
    if execution_times:
        # Fastest strategy first; it is the baseline for relative efficiency.
        sorted_times = sorted(execution_times, key=lambda x: x[1])
        min_time = sorted_times[0][1] if sorted_times else None
        # Create metric cards for top strategies
        metric_cards = []
        for i, (strategy, time) in enumerate(sorted_times[:3]):  # Show top 3
            strategy_info = STRATEGY_INFO[strategy]
            badge_color = "success" if i == 0 else "primary" if i == 1 else "secondary"
            metric_cards.append(
                dbc.Col([
                    html.Div([
                        html.Div([
                            html.I(className=f"{strategy_info['icon']} mb-2",
                                   style={"fontSize": "2rem", "color": "#667eea"}),
                            html.H3(f"{time:.2f} ms", className="metric-value"),
                            html.P(strategy_info['name'], className="metric-label mb-0"),
                            dbc.Badge(f"#{i+1}", color=badge_color, className="mt-2")
                        ], className="metric-card")
                    ])
                ], lg=4, md=6, className="mb-3")
            )
        # Create detailed comparison table
        table_rows = []
        for strategy, time in sorted_times:
            strategy_info = STRATEGY_INFO[strategy]
            # Efficiency relative to the fastest strategy (which scores 100%).
            efficiency = (min_time / time * 100) if min_time else 100
            table_rows.append(
                html.Tr([
                    html.Td([
                        html.I(className=f"{strategy_info['icon']} me-2"),
                        strategy_info['name']
                    ]),
                    html.Td(f"{time:.2f} ms"),
                    html.Td([
                        dbc.Progress(
                            value=efficiency,
                            color="success" if efficiency >= 95 else "warning" if efficiency >= 80 else "danger",
                            className="mb-0",
                            style={"height": "10px"}
                        ),
                        html.Small(f"{efficiency:.1f}%", className="ms-2 text-muted")
                    ])
                ], className="align-middle")
            )
        execution_summary = html.Div([
            html.H4([
                html.I(className="bi bi-speedometer2 section-icon"),
                "Performance Summary"
            ], className="section-title mt-5"),
            # Metric cards
            dbc.Row(metric_cards, className="mb-4"),
            # Detailed table
            dbc.Card([
                dbc.CardBody([
                    html.H5("Detailed Comparison", className="mb-3"),
                    dbc.Table([
                        html.Thead([
                            html.Tr([
                                html.Th("Strategy"),
                                html.Th("Execution Time"),
                                html.Th("Relative Efficiency", style={"width": "40%"})
                            ])
                        ]),
                        html.Tbody(table_rows)
                    ], hover=True, responsive=True, className="mb-0")
                ])
            ])
        ], className="execution-summary")
        # Append the execution summary
        if isinstance(output_content, list):
            output_content.append(execution_summary)
        else:
            output_content = [output_content, execution_summary]
    # Return graph components and toast components
    return output_content, toast_components
# --- Client-side Callback for Tab Switching ---
# Runs in the browser when the active tab of "strategy-tabs" changes. The
# JavaScript derives the strategy key from the tab id ("tab-<strategy>") and
# sets the inline `display` style of every element whose id starts with
# "content-": the matching panel is shown, all others are hidden.
# The function manipulates the DOM directly and always returns no_update, so
# the declared Output exists only because a callback needs one.
# Both "strategy-tabs" and "tab-content-container" are created by
# update_graph, i.e. they are not part of the initial layout.
app.clientside_callback(
    """
    function(activeTab) {
        if (!activeTab) return window.dash_clientside.no_update;
        // Extract strategy from tab id (format: "tab-strategy")
        const activeStrategy = activeTab.replace("tab-", "");
        // Get all tab content divs
        const contentDivs = document.querySelectorAll('[id^="content-"]');
        contentDivs.forEach(div => {
            const strategy = div.id.replace("content-", "");
            if (strategy === activeStrategy) {
                div.style.display = "block";
            } else {
                div.style.display = "none";
            }
        });
        return window.dash_clientside.no_update;
    }
    """,
    Output("tab-content-container", "children"),  # Dummy output
    Input("strategy-tabs", "active_tab"),
    prevent_initial_call=True
)
# For Hugging Face Spaces deployment
# Expose the app's underlying server object under a module-level name
# (presumably so an external WSGI server can import it - confirm against the
# deployment configuration).
server = app.server
if __name__ == '__main__':
    # When run as a script: listen on all interfaces, port 7860, debug off.
    # NOTE(review): newer Dash releases favour `app.run` over `app.run_server`;
    # confirm against the pinned Dash version before changing this call.
    app.run_server(debug=False, host='0.0.0.0', port=7860)