Victarry committed on
Commit
6b14326
·
1 Parent(s): e3ec3bb

Add microbatch_group_size_per_vp_stage to app config button.

Browse files
Files changed (1) hide show
  1. app.py +23 -2
app.py CHANGED
@@ -38,6 +38,7 @@ default_values = {
38
  "op_time_backward": 2.0,
39
  "strategy": ["1f1b_interleave"],
40
  "op_time_overlapped_fwd_bwd": None,
 
41
  }
42
 
43
  # Define input groups using dbc components
@@ -186,6 +187,20 @@ timing_params_card = dbc.Card(
186
  placement="right"
187
  )
188
  ], className="mb-3"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  ]
190
  )
191
  ]),
@@ -249,6 +264,7 @@ app.layout = dbc.Container([
249
  Output('op_time_backward_d', 'invalid'),
250
  Output('op_time_backward_w', 'invalid'),
251
  Output('op_time_overlapped_fwd_bwd', 'invalid'),
 
252
  # Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
253
  # We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
254
  # Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
@@ -263,12 +279,13 @@ app.layout = dbc.Container([
263
  Input('op_time_backward_d', 'value'),
264
  Input('op_time_backward_w', 'value'),
265
  Input('op_time_overlapped_fwd_bwd', 'value'),
 
266
  Input('selected-strategies-store', 'data'), # Validate strategy selection
267
  prevent_initial_call=True # Prevent callback running on page load before user interaction
268
  )
269
  def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
270
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
271
- op_time_overlapped_fwd_bwd, selected_strategies):
272
  is_invalid = {
273
  "num_devices": num_devices is None or num_devices < 1,
274
  "num_stages": num_stages is None or num_stages < 1,
@@ -279,6 +296,7 @@ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
279
  "op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
280
  "op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
281
  "op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
 
282
  }
283
 
284
  # Validate strategy selection
@@ -318,6 +336,7 @@ def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
318
  is_invalid["op_time_backward_d"],
319
  is_invalid["op_time_backward_w"],
320
  is_invalid["op_time_overlapped_fwd_bwd"],
 
321
  strategy_feedback # Update strategy feedback based on validation
322
  )
323
 
@@ -361,12 +380,13 @@ app.clientside_callback(
361
  State('op_time_backward_d', 'value'),
362
  State('op_time_backward_w', 'value'),
363
  State('op_time_overlapped_fwd_bwd', 'value'),
 
364
  State('selected-strategies-store', 'data'),
365
  prevent_initial_call=True
366
  )
367
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
368
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
369
- op_time_overlapped_fwd_bwd,
370
  selected_strategies):
371
 
372
  strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
@@ -480,6 +500,7 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
480
  placement_strategy=placement_strategy,
481
  split_backward=split_backward,
482
  op_times=op_times,
 
483
  )
484
 
485
  schedule_func = STRATEGIES.get(strategy)
 
38
  "op_time_backward": 2.0,
39
  "strategy": ["1f1b_interleave"],
40
  "op_time_overlapped_fwd_bwd": None,
41
+ "microbatch_group_size_per_vp_stage": None,
42
  }
43
 
44
  # Define input groups using dbc components
 
187
  placement="right"
188
  )
189
  ], className="mb-3"),
190
+ html.Div([
191
+ html.Div([
192
+ dbc.Label("Microbatch Group Size per VP Stage", html_for='microbatch_group_size_per_vp_stage', className="form-label d-inline-block me-1"),
193
+ html.I(className="bi bi-info-circle", id="tooltip-target-microbatch-group", style={"cursor": "pointer"})
194
+ ]),
195
+ dbc.Input(id='microbatch_group_size_per_vp_stage', type='number', placeholder=f"Defaults to num_devices", min=1, step=1, value=default_values["microbatch_group_size_per_vp_stage"]),
196
+ dbc.FormText("Used for interleave strategies (1f1b_interleave, 1f1b_interleave_overlap)."),
197
+ dbc.FormFeedback("Microbatch group size must be a positive integer if specified.", type="invalid", id="feedback-microbatch_group_size_per_vp_stage"),
198
+ dbc.Tooltip(
199
+ "Number of microbatches to process per virtual pipeline stage before switching to the next stage. Used primarily with interleave scheduling strategies. Defaults to the number of devices.",
200
+ target="tooltip-target-microbatch-group",
201
+ placement="right"
202
+ )
203
+ ], className="mb-3"),
204
  ]
205
  )
206
  ]),
 
264
  Output('op_time_backward_d', 'invalid'),
265
  Output('op_time_backward_w', 'invalid'),
266
  Output('op_time_overlapped_fwd_bwd', 'invalid'),
267
+ Output('microbatch_group_size_per_vp_stage', 'invalid'),
268
  # Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
269
  # We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
270
  # Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
 
279
  Input('op_time_backward_d', 'value'),
280
  Input('op_time_backward_w', 'value'),
281
  Input('op_time_overlapped_fwd_bwd', 'value'),
282
+ Input('microbatch_group_size_per_vp_stage', 'value'),
283
  Input('selected-strategies-store', 'data'), # Validate strategy selection
284
  prevent_initial_call=True # Prevent callback running on page load before user interaction
285
  )
286
  def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
287
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
288
+ op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage, selected_strategies):
289
  is_invalid = {
290
  "num_devices": num_devices is None or num_devices < 1,
291
  "num_stages": num_stages is None or num_stages < 1,
 
296
  "op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
297
  "op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
298
  "op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
299
+ "microbatch_group_size_per_vp_stage": microbatch_group_size_per_vp_stage is not None and (microbatch_group_size_per_vp_stage < 1 or microbatch_group_size_per_vp_stage % 1 != 0),
300
  }
301
 
302
  # Validate strategy selection
 
336
  is_invalid["op_time_backward_d"],
337
  is_invalid["op_time_backward_w"],
338
  is_invalid["op_time_overlapped_fwd_bwd"],
339
+ is_invalid["microbatch_group_size_per_vp_stage"],
340
  strategy_feedback # Update strategy feedback based on validation
341
  )
342
 
 
380
  State('op_time_backward_d', 'value'),
381
  State('op_time_backward_w', 'value'),
382
  State('op_time_overlapped_fwd_bwd', 'value'),
383
+ State('microbatch_group_size_per_vp_stage', 'value'),
384
  State('selected-strategies-store', 'data'),
385
  prevent_initial_call=True
386
  )
387
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
388
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
389
+ op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
390
  selected_strategies):
391
 
392
  strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
 
500
  placement_strategy=placement_strategy,
501
  split_backward=split_backward,
502
  op_times=op_times,
503
+ microbatch_group_size_per_vp_stage=int(microbatch_group_size_per_vp_stage) if microbatch_group_size_per_vp_stage is not None else None,
504
  )
505
 
506
  schedule_func = STRATEGIES.get(strategy)