Spaces:
Running
Running
Improve UI components.
Browse files- pyproject.toml +2 -1
- src/server.py +165 -141
pyproject.toml
CHANGED
@@ -24,6 +24,7 @@ dependencies = [
|
|
24 |
"pandas>=2.1.0",
|
25 |
"numpy>=1.26.0",
|
26 |
"tqdm>=4.67.0",
|
|
|
27 |
]
|
28 |
|
29 |
[project.optional-dependencies]
|
@@ -64,4 +65,4 @@ disallow_incomplete_defs = true
|
|
64 |
|
65 |
[tool.pytest]
|
66 |
testpaths = ["tests"]
|
67 |
-
pythonpath = ["."]
|
|
|
24 |
"pandas>=2.1.0",
|
25 |
"numpy>=1.26.0",
|
26 |
"tqdm>=4.67.0",
|
27 |
+
"dash-bootstrap-components>=1.7.1",
|
28 |
]
|
29 |
|
30 |
[project.optional-dependencies]
|
|
|
65 |
|
66 |
[tool.pytest]
|
67 |
testpaths = ["tests"]
|
68 |
+
pythonpath = ["."]
|
src/server.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import dash
|
|
|
2 |
from dash import dcc, html, Input, Output, State, callback_context
|
3 |
import plotly.graph_objects as go
|
4 |
import webbrowser
|
@@ -27,8 +28,8 @@ STRATEGIES = {
|
|
27 |
"dualpipe": generate_dualpipe_schedule,
|
28 |
}
|
29 |
|
30 |
-
app = dash.Dash(__name__, suppress_callback_exceptions=True)
|
31 |
-
app.title = "Pipeline Parallelism Visualizer"
|
32 |
|
33 |
# Initial default values
|
34 |
default_values = {
|
@@ -45,91 +46,98 @@ default_values = {
|
|
45 |
"placement_strategy": "interleave"
|
46 |
}
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
html.Div([
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
|
|
67 |
html.Div([
|
68 |
-
|
69 |
-
|
70 |
-
id='strategy',
|
71 |
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
72 |
-
value=default_values["strategy"],
|
73 |
-
|
74 |
-
style={'width': '100%'}
|
75 |
-
),
|
76 |
-
|
77 |
-
html.Label("Placement Strategy:"),
|
78 |
-
dcc.Dropdown(
|
79 |
-
id='placement_strategy',
|
80 |
-
options=[
|
81 |
-
{'label': 'Standard', 'value': 'standard'},
|
82 |
-
{'label': 'Interleave', 'value': 'interleave'},
|
83 |
-
{'label': 'DualPipe', 'value': 'dualpipe'}
|
84 |
-
],
|
85 |
-
value=default_values["placement_strategy"],
|
86 |
-
clearable=False,
|
87 |
-
style={'width': '100%'}
|
88 |
),
|
|
|
|
|
|
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
options=[{'label': ' Split Backward Pass (for ZB-1P, DualPipe)', 'value': 'True'}],
|
94 |
-
value=['True'] if default_values["split_backward"] else [],
|
95 |
-
style={'display': 'inline-block'}
|
96 |
-
),
|
97 |
-
], style={'marginTop': '20px'}),
|
98 |
-
|
99 |
-
], style={'padding': 10, 'flex': 1}),
|
100 |
-
|
101 |
html.Div([
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
|
|
|
|
|
116 |
|
117 |
-
|
118 |
-
|
|
|
|
|
119 |
]),
|
120 |
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
129 |
|
130 |
@app.callback(
|
131 |
-
Output('
|
132 |
-
Output('error-message', 'children'),
|
133 |
Input('generate-button', 'n_clicks'),
|
134 |
State('num_devices', 'value'),
|
135 |
State('num_stages', 'value'),
|
@@ -139,83 +147,99 @@ app.layout = html.Div([
|
|
139 |
State('op_time_backward', 'value'),
|
140 |
State('op_time_backward_d', 'value'),
|
141 |
State('op_time_backward_w', 'value'),
|
142 |
-
State('strategy', 'value'),
|
143 |
-
State('split_backward', 'value'),
|
144 |
-
State('placement_strategy', 'value'),
|
145 |
prevent_initial_call=True
|
146 |
)
|
147 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
148 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
149 |
-
|
150 |
|
151 |
-
|
152 |
-
fig = go.Figure()
|
153 |
|
154 |
-
|
|
|
155 |
|
156 |
-
# Basic Validations
|
157 |
if not all([num_devices, num_stages, num_batches, op_time_forward]):
|
158 |
-
|
159 |
-
if split_backward and not all([op_time_backward_d, op_time_backward_w]):
|
160 |
-
return fig, "Backward D and Backward W times are required when 'Split Backward' is checked."
|
161 |
-
if not split_backward and not op_time_backward:
|
162 |
-
return fig, "Backward time is required when 'Split Backward' is unchecked."
|
163 |
-
if num_stages % num_devices != 0 and placement_strategy != 'dualpipe':
|
164 |
-
return fig, "Number of Stages must be divisible by Number of Devices for standard/interleave placement."
|
165 |
-
if placement_strategy == 'dualpipe' and num_stages % 2 != 0:
|
166 |
-
return fig, "DualPipe requires an even number of stages."
|
167 |
-
if placement_strategy == 'dualpipe' and num_stages != num_devices:
|
168 |
-
return fig, "DualPipe requires Number of Stages to be equal to Number of Devices."
|
169 |
-
if strategy == 'dualpipe' and not split_backward:
|
170 |
-
return fig, "DualPipe strategy currently requires 'Split Backward' to be checked."
|
171 |
-
if strategy == 'dualpipe' and placement_strategy != 'dualpipe':
|
172 |
-
return fig, "DualPipe strategy requires 'DualPipe' placement strategy."
|
173 |
-
if strategy == 'zb1p' and not split_backward:
|
174 |
-
return fig, "ZB-1P strategy requires 'Split Backward' to be checked."
|
175 |
-
|
176 |
-
try:
|
177 |
-
op_times = {
|
178 |
-
"forward": float(op_time_forward),
|
179 |
-
}
|
180 |
-
if split_backward:
|
181 |
-
op_times["backward_D"] = float(op_time_backward_d)
|
182 |
-
op_times["backward_W"] = float(op_time_backward_w)
|
183 |
-
# Add combined backward time for compatibility if needed by some visualization or calculation
|
184 |
-
op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
|
185 |
-
else:
|
186 |
-
op_times["backward"] = float(op_time_backward)
|
187 |
-
|
188 |
-
config = ScheduleConfig(
|
189 |
-
num_devices=int(num_devices),
|
190 |
-
num_stages=int(num_stages),
|
191 |
-
num_batches=int(num_batches),
|
192 |
-
p2p_latency=float(p2p_latency),
|
193 |
-
placement_strategy=placement_strategy,
|
194 |
-
split_backward=split_backward,
|
195 |
-
op_times=op_times,
|
196 |
-
)
|
197 |
-
|
198 |
-
schedule_func = STRATEGIES.get(strategy)
|
199 |
-
if not schedule_func:
|
200 |
-
raise ValueError(f"Invalid strategy selected: {strategy}")
|
201 |
-
|
202 |
-
schedule = schedule_func(config)
|
203 |
-
schedule.execute() # Calculate start/end times
|
204 |
-
|
205 |
-
vis_data = convert_schedule_to_visualization_format(schedule)
|
206 |
-
fig = create_pipeline_figure(vis_data, show_progress=False) # Disable progress bar in server mode
|
207 |
-
|
208 |
-
except AssertionError as e:
|
209 |
-
error_message = f"Configuration Error: {e}"
|
210 |
-
fig = go.Figure() # Return empty figure on error
|
211 |
-
except ValueError as e:
|
212 |
-
error_message = f"Input Error: {e}"
|
213 |
-
fig = go.Figure()
|
214 |
-
except Exception as e:
|
215 |
-
error_message = f"An unexpected error occurred: {e}"
|
216 |
-
fig = go.Figure()
|
217 |
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
|
220 |
if __name__ == '__main__':
|
221 |
port = 8050
|
|
|
1 |
import dash
|
2 |
+
import dash_bootstrap_components as dbc
|
3 |
from dash import dcc, html, Input, Output, State, callback_context
|
4 |
import plotly.graph_objects as go
|
5 |
import webbrowser
|
|
|
28 |
"dualpipe": generate_dualpipe_schedule,
|
29 |
}
|
30 |
|
31 |
+
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
|
32 |
+
app.title = "Pipeline Parallelism Schedule Visualizer"
|
33 |
|
34 |
# Initial default values
|
35 |
default_values = {
|
|
|
46 |
"placement_strategy": "interleave"
|
47 |
}
|
48 |
|
49 |
+
# Define input groups using dbc components
|
50 |
+
basic_params_card = dbc.Card(
|
51 |
+
dbc.CardBody([
|
52 |
+
html.H5("Basic Parameters", className="card-title"),
|
53 |
html.Div([
|
54 |
+
dbc.Label("Number of Devices (GPUs):"),
|
55 |
+
dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1),
|
56 |
+
], className="mb-3"),
|
57 |
+
html.Div([
|
58 |
+
dbc.Label("Number of Stages (Model Chunks):"),
|
59 |
+
dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1),
|
60 |
+
], className="mb-3"),
|
61 |
+
html.Div([
|
62 |
+
dbc.Label("Number of Microbatches:"),
|
63 |
+
dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1),
|
64 |
+
], className="mb-3"),
|
65 |
+
html.Div([
|
66 |
+
dbc.Label("P2P Latency (ms):"),
|
67 |
+
dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01),
|
68 |
+
], className="mb-3"),
|
69 |
+
])
|
70 |
+
)
|
71 |
|
72 |
+
scheduling_params_card = dbc.Card(
|
73 |
+
dbc.CardBody([
|
74 |
+
html.H5("Scheduling Parameters", className="card-title"),
|
75 |
html.Div([
|
76 |
+
dbc.Label("Scheduling Strategies:"),
|
77 |
+
dbc.Checklist(
|
78 |
+
id='strategy-checklist',
|
79 |
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
80 |
+
value=[default_values["strategy"]],
|
81 |
+
inline=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
),
|
83 |
+
], className="mb-3"),
|
84 |
+
])
|
85 |
+
)
|
86 |
|
87 |
+
timing_params_card = dbc.Card(
|
88 |
+
dbc.CardBody([
|
89 |
+
html.H5("Operation Timing (ms)", className="card-title"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
html.Div([
|
91 |
+
dbc.Label("Forward:"),
|
92 |
+
dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01),
|
93 |
+
], className="mb-3"),
|
94 |
+
html.Div([
|
95 |
+
dbc.Label("Backward (Combined):"),
|
96 |
+
dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
|
97 |
+
dbc.FormText("Used when strategy does NOT require split backward."),
|
98 |
+
], className="mb-3"),
|
99 |
+
html.Div([
|
100 |
+
dbc.Label("Backward D (Data Grad):"),
|
101 |
+
dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
|
102 |
+
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
103 |
+
], className="mb-3"),
|
104 |
+
html.Div([
|
105 |
+
dbc.Label("Backward W (Weight Grad):"),
|
106 |
+
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
|
107 |
+
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
108 |
+
], className="mb-3"),
|
109 |
+
])
|
110 |
+
)
|
111 |
|
112 |
+
# Updated app layout using dbc components and structure
|
113 |
+
app.layout = dbc.Container([
|
114 |
+
html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
|
115 |
|
116 |
+
dbc.Row([
|
117 |
+
dbc.Col(basic_params_card, md=4),
|
118 |
+
dbc.Col(scheduling_params_card, md=4),
|
119 |
+
dbc.Col(timing_params_card, md=4),
|
120 |
]),
|
121 |
|
122 |
+
dbc.Row([
|
123 |
+
dbc.Col([
|
124 |
+
dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4"),
|
125 |
+
], className="text-center")
|
126 |
+
]),
|
127 |
|
128 |
+
dbc.Row([
|
129 |
+
dbc.Col([
|
130 |
+
dcc.Loading(
|
131 |
+
id="loading-graph-area",
|
132 |
+
type="circle",
|
133 |
+
children=html.Div(id='graph-output-container', className="mt-4")
|
134 |
+
)
|
135 |
+
])
|
136 |
+
])
|
137 |
+
], fluid=True)
|
138 |
|
139 |
@app.callback(
|
140 |
+
Output('graph-output-container', 'children'),
|
|
|
141 |
Input('generate-button', 'n_clicks'),
|
142 |
State('num_devices', 'value'),
|
143 |
State('num_stages', 'value'),
|
|
|
147 |
State('op_time_backward', 'value'),
|
148 |
State('op_time_backward_d', 'value'),
|
149 |
State('op_time_backward_w', 'value'),
|
150 |
+
State('strategy-checklist', 'value'),
|
|
|
|
|
151 |
prevent_initial_call=True
|
152 |
)
|
153 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
154 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
155 |
+
selected_strategies):
|
156 |
|
157 |
+
output_components = []
|
|
|
158 |
|
159 |
+
if not selected_strategies:
|
160 |
+
return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
|
161 |
|
|
|
162 |
if not all([num_devices, num_stages, num_batches, op_time_forward]):
|
163 |
+
return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
+
for strategy in selected_strategies:
|
166 |
+
error_message = ""
|
167 |
+
fig = go.Figure()
|
168 |
+
placement_strategy = ""
|
169 |
+
|
170 |
+
split_backward = strategy in ["zb1p", "dualpipe"]
|
171 |
+
|
172 |
+
if split_backward and not all([op_time_backward_d, op_time_backward_w]):
|
173 |
+
error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
|
174 |
+
elif not split_backward and not op_time_backward:
|
175 |
+
error_message = f"Strategy '{strategy}': Combined Backward time is required."
|
176 |
+
|
177 |
+
if not error_message:
|
178 |
+
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
179 |
+
placement_strategy = "standard"
|
180 |
+
if num_devices != num_stages:
|
181 |
+
error_message = f"Strategy '{strategy}': Requires Number of Stages == Number of Devices."
|
182 |
+
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
183 |
+
placement_strategy = "interleave"
|
184 |
+
if num_stages % num_devices != 0:
|
185 |
+
error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
|
186 |
+
elif strategy == "dualpipe":
|
187 |
+
placement_strategy = "dualpipe"
|
188 |
+
if num_stages % 2 != 0:
|
189 |
+
error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
|
190 |
+
elif num_stages != num_devices:
|
191 |
+
error_message = f"Strategy '{strategy}' (DualPipe): Requires Number of Stages == Number of Devices."
|
192 |
+
|
193 |
+
if not error_message:
|
194 |
+
try:
|
195 |
+
op_times = { "forward": float(op_time_forward) }
|
196 |
+
if split_backward:
|
197 |
+
op_times["backward_D"] = float(op_time_backward_d)
|
198 |
+
op_times["backward_W"] = float(op_time_backward_w)
|
199 |
+
op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
|
200 |
+
else:
|
201 |
+
op_times["backward"] = float(op_time_backward)
|
202 |
+
|
203 |
+
config = ScheduleConfig(
|
204 |
+
num_devices=int(num_devices),
|
205 |
+
num_stages=int(num_stages),
|
206 |
+
num_batches=int(num_batches),
|
207 |
+
p2p_latency=float(p2p_latency),
|
208 |
+
placement_strategy=placement_strategy,
|
209 |
+
split_backward=split_backward,
|
210 |
+
op_times=op_times,
|
211 |
+
)
|
212 |
+
|
213 |
+
schedule_func = STRATEGIES.get(strategy)
|
214 |
+
if not schedule_func:
|
215 |
+
raise ValueError(f"Invalid strategy function for: {strategy}")
|
216 |
+
|
217 |
+
schedule = schedule_func(config)
|
218 |
+
schedule.execute()
|
219 |
+
|
220 |
+
vis_data = convert_schedule_to_visualization_format(schedule)
|
221 |
+
fig = create_pipeline_figure(vis_data, show_progress=False)
|
222 |
+
|
223 |
+
output_components.append(html.Div([
|
224 |
+
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
225 |
+
dcc.Graph(figure=fig)
|
226 |
+
]))
|
227 |
+
|
228 |
+
except (AssertionError, ValueError, TypeError) as e:
|
229 |
+
error_message = f"Error generating schedule for '{strategy}': {e}"
|
230 |
+
import traceback
|
231 |
+
traceback.print_exc()
|
232 |
+
except Exception as e:
|
233 |
+
error_message = f"An unexpected error occurred for '{strategy}': {e}"
|
234 |
+
import traceback
|
235 |
+
traceback.print_exc()
|
236 |
+
|
237 |
+
if error_message:
|
238 |
+
output_components.append(
|
239 |
+
dbc.Alert(error_message, color="danger", className="mt-3")
|
240 |
+
)
|
241 |
+
|
242 |
+
return output_components
|
243 |
|
244 |
if __name__ == '__main__':
|
245 |
port = 8050
|