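# Spaces: Running
# NOTE(review): the line above looks like page-status residue from the site this
# file was copied from; it is not part of the program.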
import traceback
import webbrowser
from threading import Timer

import dash
import dash_bootstrap_components as dbc
import plotly.graph_objects as go
from dash import dcc, html, Input, Output, State, callback_context

from src.execution_model import ScheduleConfig, Schedule
from src.strategies import (
    generate_1f1b_schedule,
    generate_zero_bubble_1p_schedule,
    generate_1f1b_overlap_schedule,
    generate_1f1b_interleave_schedule,
    generate_1f1b_interleave_overlap_schedule,
    generate_dualpipe_schedule
)
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
def open_browser(port):
    """Open the locally served app in a new browser window.

    Args:
        port: Port on 127.0.0.1 that the server is listening on; it is
            interpolated into the URL as-is.
    """
    webbrowser.open_new(f"http://127.0.0.1:{port}")
# Maps each strategy key to the function that builds its schedule from a
# ScheduleConfig. The keys double as the labels/values of the strategy
# checklist, and dict insertion order is the order the options are listed in.
STRATEGIES = {
    "1f1b": generate_1f1b_schedule,
    "zb1p": generate_zero_bubble_1p_schedule,
    "1f1b_overlap": generate_1f1b_overlap_schedule,
    "1f1b_interleave": generate_1f1b_interleave_schedule,
    "1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
    "dualpipe": generate_dualpipe_schedule,
}
# Dash application instance, styled with the stock Bootstrap theme.
# suppress_callback_exceptions=True turns off Dash's check that every
# component id referenced by a callback exists in the initial layout.
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    suppress_callback_exceptions=True,
)
app.title = "Pipeline Parallelism Schedule Visualizer"
# Initial values used to pre-populate the form controls.
# NOTE(review): "split_backward" and "placement_strategy" are not read anywhere
# in the visible code; the callback derives both from the selected strategy.
default_values = {
    "num_devices": 4,
    "num_stages": 8,
    "num_batches": 16,
    "p2p_latency": 0.1,
    "op_time_forward": 1.0,
    "op_time_backward_d": 1.0,
    "op_time_backward_w": 1.0,
    "op_time_backward": 2.0,
    "strategy": "1f1b_interleave",
    "split_backward": False,
    "placement_strategy": "interleave",
}
# Card with the basic numeric parameters. Each row of the spec below is
# (label text, component id, min, step); the component id is also the key used
# to look up the initial value in default_values.
basic_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Basic Parameters", className="card-title"),
        *[
            html.Div([
                dbc.Label(label_text),
                dbc.Input(
                    id=field_id,
                    type='number',
                    value=default_values[field_id],
                    min=minimum,
                    step=step,
                ),
            ], className="mb-3")
            for label_text, field_id, minimum, step in (
                ("Number of Devices (GPUs):", 'num_devices', 1, 1),
                ("Number of Stages (Model Chunks):", 'num_stages', 1, 1),
                ("Number of Microbatches:", 'num_batches', 1, 1),
                ("P2P Latency (ms):", 'p2p_latency', 0, 0.01),
            )
        ],
    ])
)
# Card with the strategy multi-select. One checkbox per STRATEGIES key, with
# the default strategy pre-ticked; several strategies may be selected at once.
scheduling_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Scheduling Parameters", className="card-title"),
        html.Div([
            dbc.Label("Scheduling Strategies:"),
            dbc.Checklist(
                id='strategy-checklist',
                options=[{'label': name, 'value': name} for name in STRATEGIES],
                value=[default_values["strategy"]],
                inline=False,
            ),
        ], className="mb-3"),
    ])
)
# Card with the per-operation durations. Each row of the spec below is
# (label text, component id, optional help text); the component id is also the
# default_values key. All inputs share min=0.01 and step=0.01.
timing_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Operation Timing (ms)", className="card-title"),
        *[
            html.Div(
                [
                    dbc.Label(label_text),
                    dbc.Input(
                        id=field_id,
                        type='number',
                        value=default_values[field_id],
                        min=0.01,
                        step=0.01,
                    ),
                ]
                # The help line is appended only for fields that define one.
                + ([dbc.FormText(help_text)] if help_text is not None else []),
                className="mb-3",
            )
            for label_text, field_id, help_text in (
                ("Forward:", 'op_time_forward', None),
                (
                    "Backward (Combined):",
                    'op_time_backward',
                    "Used when strategy does NOT require split backward.",
                ),
                (
                    "Backward D (Data Grad):",
                    'op_time_backward_d',
                    "Used when strategy requires split backward (e.g., ZB-1P, DualPipe).",
                ),
                (
                    "Backward W (Weight Grad):",
                    'op_time_backward_w',
                    "Used when strategy requires split backward (e.g., ZB-1P, DualPipe).",
                ),
            )
        ],
    ])
)
# Page layout: title, one row with the three parameter cards side by side,
# the generate button, and a loading-wrapped container that receives the
# generated figures/alerts.
app.layout = dbc.Container([
    html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
    # Parameter cards, one third of the row each at "md" width and above.
    dbc.Row([
        dbc.Col(basic_params_card, md=4),
        dbc.Col(scheduling_params_card, md=4),
        dbc.Col(timing_params_card, md=4),
    ]),
    dbc.Row([
        dbc.Col([
            dbc.Button(
                'Generate Schedule',
                id='generate-button',
                n_clicks=0,
                color="primary",
                className="mt-4",
            ),
        ], className="text-center")
    ]),
    dbc.Row([
        dbc.Col([
            # Spinner shown while the children of the output container update.
            dcc.Loading(
                id="loading-graph-area",
                type="circle",
                children=html.Div(id='graph-output-container', className="mt-4")
            )
        ])
    ])
], fluid=True)
# Register the function with Dash: without this decorator the button click is
# never routed to update_graph. The form fields are State (read on click), the
# button is the only Input (trigger). Dash also invokes callbacks once on page
# load by default, so the default configuration is rendered initially.
@app.callback(
    Output('graph-output-container', 'children'),
    Input('generate-button', 'n_clicks'),
    State('num_devices', 'value'),
    State('num_stages', 'value'),
    State('num_batches', 'value'),
    State('p2p_latency', 'value'),
    State('op_time_forward', 'value'),
    State('op_time_backward', 'value'),
    State('op_time_backward_d', 'value'),
    State('op_time_backward_w', 'value'),
    State('strategy-checklist', 'value'),
)
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                 selected_strategies):
    """Build one schedule figure, or one error alert, per selected strategy.

    Args:
        n_clicks: Click count of the generate button; only acts as the trigger
            and is not read.
        num_devices: Number of devices entered in the form.
        num_stages: Number of stages entered in the form.
        num_batches: Number of microbatches entered in the form.
        p2p_latency: P2P latency entered in the form (0 is a valid value).
        op_time_forward: Forward operation time.
        op_time_backward: Combined backward time, used by strategies that do
            not split the backward pass.
        op_time_backward_d: Backward-D time, used by split-backward strategies.
        op_time_backward_w: Backward-W time, used by split-backward strategies.
        selected_strategies: List of STRATEGIES keys ticked in the checklist.

    Returns:
        A list of components for the output container: a single alert when the
        global inputs are incomplete, otherwise one titled graph or one danger
        alert per selected strategy, in selection order.
    """
    if not selected_strategies:
        return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
    if not all([num_devices, num_stages, num_batches, op_time_forward]):
        return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
    # A cleared number field arrives as None. 0 is a legitimate latency, so test
    # for None explicitly instead of truthiness, and report it up front rather
    # than letting float(None) fail separately for every strategy.
    if p2p_latency is None:
        return [dbc.Alert("Missing required basic input value (P2P Latency).", color="danger")]

    output_components = []
    for strategy in selected_strategies:
        error_message = ""
        placement_strategy = ""
        split_backward = strategy in ("zb1p", "dualpipe")

        # The backward timings a strategy needs depend on whether it splits
        # the backward pass into D and W parts.
        if split_backward and not all([op_time_backward_d, op_time_backward_w]):
            error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
        elif not split_backward and not op_time_backward:
            error_message = f"Strategy '{strategy}': Combined Backward time is required."

        # Pick the placement strategy and check the stage/device constraints.
        if not error_message:
            if strategy in ("1f1b", "1f1b_overlap", "zb1p"):
                placement_strategy = "standard"
                if num_devices != num_stages:
                    error_message = f"Strategy '{strategy}': Requires Number of Stages == Number of Devices."
            elif strategy in ("1f1b_interleave", "1f1b_interleave_overlap"):
                placement_strategy = "interleave"
                if num_stages % num_devices != 0:
                    error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
            elif strategy == "dualpipe":
                placement_strategy = "dualpipe"
                if num_stages % 2 != 0:
                    error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
                elif num_stages != num_devices:
                    error_message = f"Strategy '{strategy}' (DualPipe): Requires Number of Stages == Number of Devices."

        if not error_message:
            try:
                op_times = {"forward": float(op_time_forward)}
                if split_backward:
                    op_times["backward_D"] = float(op_time_backward_d)
                    op_times["backward_W"] = float(op_time_backward_w)
                    # The combined entry is also supplied, as the sum of both parts.
                    op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
                else:
                    op_times["backward"] = float(op_time_backward)

                config = ScheduleConfig(
                    num_devices=int(num_devices),
                    num_stages=int(num_stages),
                    num_batches=int(num_batches),
                    p2p_latency=float(p2p_latency),
                    placement_strategy=placement_strategy,
                    split_backward=split_backward,
                    op_times=op_times,
                )

                schedule_func = STRATEGIES.get(strategy)
                if not schedule_func:
                    raise ValueError(f"Invalid strategy function for: {strategy}")

                schedule = schedule_func(config)
                schedule.execute()

                vis_data = convert_schedule_to_visualization_format(schedule)
                fig = create_pipeline_figure(vis_data, show_progress=False)

                output_components.append(html.Div([
                    html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
                    dcc.Graph(figure=fig)
                ]))
            except (AssertionError, ValueError, TypeError) as e:
                error_message = f"Error generating schedule for '{strategy}': {e}"
                traceback.print_exc()
            except Exception as e:
                # Boundary handler: one failing strategy must not prevent the
                # remaining strategies from being rendered.
                error_message = f"An unexpected error occurred for '{strategy}': {e}"
                traceback.print_exc()

        if error_message:
            output_components.append(
                dbc.Alert(error_message, color="danger", className="mt-3")
            )

    return output_components
if __name__ == '__main__':
    port = 8050
    # Timer(1, open_browser, args=(port,)).start()  # Optional: automatically open browser
    print(f"Dash server running on http://127.0.0.1:{port}")
    # Newer Dash releases replace ``run_server`` with ``run`` (Dash 3 rejects the
    # old name); fall back to ``run_server`` only when ``run`` does not exist.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(debug=True, port=port)