Spaces:
Running
Running
Add microbatch_group_size_per_vp_stage to app config button.
Browse files
app.py
CHANGED
@@ -38,6 +38,7 @@ default_values = {
|
|
38 |
"op_time_backward": 2.0,
|
39 |
"strategy": ["1f1b_interleave"],
|
40 |
"op_time_overlapped_fwd_bwd": None,
|
|
|
41 |
}
|
42 |
|
43 |
# Define input groups using dbc components
|
@@ -186,6 +187,20 @@ timing_params_card = dbc.Card(
|
|
186 |
placement="right"
|
187 |
)
|
188 |
], className="mb-3"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
]
|
190 |
)
|
191 |
]),
|
@@ -249,6 +264,7 @@ app.layout = dbc.Container([
|
|
249 |
Output('op_time_backward_d', 'invalid'),
|
250 |
Output('op_time_backward_w', 'invalid'),
|
251 |
Output('op_time_overlapped_fwd_bwd', 'invalid'),
|
|
|
252 |
# Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
|
253 |
# We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
|
254 |
# Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
|
@@ -263,12 +279,13 @@ app.layout = dbc.Container([
|
|
263 |
Input('op_time_backward_d', 'value'),
|
264 |
Input('op_time_backward_w', 'value'),
|
265 |
Input('op_time_overlapped_fwd_bwd', 'value'),
|
|
|
266 |
Input('selected-strategies-store', 'data'), # Validate strategy selection
|
267 |
prevent_initial_call=True # Prevent callback running on page load before user interaction
|
268 |
)
|
269 |
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
270 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
271 |
-
op_time_overlapped_fwd_bwd, selected_strategies):
|
272 |
is_invalid = {
|
273 |
"num_devices": num_devices is None or num_devices < 1,
|
274 |
"num_stages": num_stages is None or num_stages < 1,
|
@@ -279,6 +296,7 @@ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
|
279 |
"op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
|
280 |
"op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
|
281 |
"op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
|
|
|
282 |
}
|
283 |
|
284 |
# Validate strategy selection
|
@@ -318,6 +336,7 @@ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
|
318 |
is_invalid["op_time_backward_d"],
|
319 |
is_invalid["op_time_backward_w"],
|
320 |
is_invalid["op_time_overlapped_fwd_bwd"],
|
|
|
321 |
strategy_feedback # Update strategy feedback based on validation
|
322 |
)
|
323 |
|
@@ -361,12 +380,13 @@ app.clientside_callback(
|
|
361 |
State('op_time_backward_d', 'value'),
|
362 |
State('op_time_backward_w', 'value'),
|
363 |
State('op_time_overlapped_fwd_bwd', 'value'),
|
|
|
364 |
State('selected-strategies-store', 'data'),
|
365 |
prevent_initial_call=True
|
366 |
)
|
367 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
368 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
369 |
-
op_time_overlapped_fwd_bwd,
|
370 |
selected_strategies):
|
371 |
|
372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
@@ -480,6 +500,7 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
480 |
placement_strategy=placement_strategy,
|
481 |
split_backward=split_backward,
|
482 |
op_times=op_times,
|
|
|
483 |
)
|
484 |
|
485 |
schedule_func = STRATEGIES.get(strategy)
|
|
|
38 |
"op_time_backward": 2.0,
|
39 |
"strategy": ["1f1b_interleave"],
|
40 |
"op_time_overlapped_fwd_bwd": None,
|
41 |
+
"microbatch_group_size_per_vp_stage": None,
|
42 |
}
|
43 |
|
44 |
# Define input groups using dbc components
|
|
|
187 |
placement="right"
|
188 |
)
|
189 |
], className="mb-3"),
|
190 |
+
html.Div([
|
191 |
+
html.Div([
|
192 |
+
dbc.Label("Microbatch Group Size per VP Stage", html_for='microbatch_group_size_per_vp_stage', className="form-label d-inline-block me-1"),
|
193 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-microbatch-group", style={"cursor": "pointer"})
|
194 |
+
]),
|
195 |
+
dbc.Input(id='microbatch_group_size_per_vp_stage', type='number', placeholder=f"Defaults to num_devices", min=1, step=1, value=default_values["microbatch_group_size_per_vp_stage"]),
|
196 |
+
dbc.FormText("Used for interleave strategies (1f1b_interleave, 1f1b_interleave_overlap)."),
|
197 |
+
dbc.FormFeedback("Microbatch group size must be a positive integer if specified.", type="invalid", id="feedback-microbatch_group_size_per_vp_stage"),
|
198 |
+
dbc.Tooltip(
|
199 |
+
"Number of microbatches to process per virtual pipeline stage before switching to the next stage. Used primarily with interleave scheduling strategies. Defaults to the number of devices.",
|
200 |
+
target="tooltip-target-microbatch-group",
|
201 |
+
placement="right"
|
202 |
+
)
|
203 |
+
], className="mb-3"),
|
204 |
]
|
205 |
)
|
206 |
]),
|
|
|
264 |
Output('op_time_backward_d', 'invalid'),
|
265 |
Output('op_time_backward_w', 'invalid'),
|
266 |
Output('op_time_overlapped_fwd_bwd', 'invalid'),
|
267 |
+
Output('microbatch_group_size_per_vp_stage', 'invalid'),
|
268 |
# Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
|
269 |
# We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
|
270 |
# Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
|
|
|
279 |
Input('op_time_backward_d', 'value'),
|
280 |
Input('op_time_backward_w', 'value'),
|
281 |
Input('op_time_overlapped_fwd_bwd', 'value'),
|
282 |
+
Input('microbatch_group_size_per_vp_stage', 'value'),
|
283 |
Input('selected-strategies-store', 'data'), # Validate strategy selection
|
284 |
prevent_initial_call=True # Prevent callback running on page load before user interaction
|
285 |
)
|
286 |
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
287 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
288 |
+
op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage, selected_strategies):
|
289 |
is_invalid = {
|
290 |
"num_devices": num_devices is None or num_devices < 1,
|
291 |
"num_stages": num_stages is None or num_stages < 1,
|
|
|
296 |
"op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
|
297 |
"op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
|
298 |
"op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
|
299 |
+
"microbatch_group_size_per_vp_stage": microbatch_group_size_per_vp_stage is not None and (microbatch_group_size_per_vp_stage < 1 or microbatch_group_size_per_vp_stage % 1 != 0),
|
300 |
}
|
301 |
|
302 |
# Validate strategy selection
|
|
|
336 |
is_invalid["op_time_backward_d"],
|
337 |
is_invalid["op_time_backward_w"],
|
338 |
is_invalid["op_time_overlapped_fwd_bwd"],
|
339 |
+
is_invalid["microbatch_group_size_per_vp_stage"],
|
340 |
strategy_feedback # Update strategy feedback based on validation
|
341 |
)
|
342 |
|
|
|
380 |
State('op_time_backward_d', 'value'),
|
381 |
State('op_time_backward_w', 'value'),
|
382 |
State('op_time_overlapped_fwd_bwd', 'value'),
|
383 |
+
State('microbatch_group_size_per_vp_stage', 'value'),
|
384 |
State('selected-strategies-store', 'data'),
|
385 |
prevent_initial_call=True
|
386 |
)
|
387 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
388 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
389 |
+
op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
|
390 |
selected_strategies):
|
391 |
|
392 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
|
|
500 |
placement_strategy=placement_strategy,
|
501 |
split_backward=split_backward,
|
502 |
op_times=op_times,
|
503 |
+
microbatch_group_size_per_vp_stage=int(microbatch_group_size_per_vp_stage) if microbatch_group_size_per_vp_stage is not None else None,
|
504 |
)
|
505 |
|
506 |
schedule_func = STRATEGIES.get(strategy)
|