Victarry commited on
Commit
b220b81
·
1 Parent(s): 912972f

update UI.

Browse files
Files changed (1) hide show
  1. src/server.py +124 -27
src/server.py CHANGED
@@ -36,14 +36,13 @@ default_values = {
36
  "num_devices": 4,
37
  "num_stages": 8,
38
  "num_batches": 16,
39
- "p2p_latency": 0.1,
40
  "op_time_forward": 1.0,
41
  "op_time_backward_d": 1.0,
42
  "op_time_backward_w": 1.0,
43
  "op_time_backward": 2.0,
44
  "strategy": "1f1b_interleave",
45
- "split_backward": False,
46
- "placement_strategy": "interleave"
47
  }
48
 
49
  # Define input groups using dbc components
@@ -77,7 +76,7 @@ scheduling_params_card = dbc.Card(
77
  dbc.Checklist(
78
  id='strategy-checklist',
79
  options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
80
- value=[default_values["strategy"]],
81
  inline=False,
82
  ),
83
  ], className="mb-3"),
@@ -106,6 +105,11 @@ timing_params_card = dbc.Card(
106
  dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
107
  dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
108
  ], className="mb-3"),
 
 
 
 
 
109
  ])
110
  )
111
 
@@ -147,14 +151,22 @@ app.layout = dbc.Container([
147
  State('op_time_backward', 'value'),
148
  State('op_time_backward_d', 'value'),
149
  State('op_time_backward_w', 'value'),
 
150
  State('strategy-checklist', 'value'),
151
  prevent_initial_call=True
152
  )
153
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
154
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
 
155
  selected_strategies):
156
 
 
 
 
157
  output_components = []
 
 
 
158
 
159
  if not selected_strategies:
160
  return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
@@ -164,8 +176,25 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
164
 
165
  for strategy in selected_strategies:
166
  error_message = ""
167
- fig = go.Figure()
168
  placement_strategy = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  split_backward = strategy in ["zb1p", "dualpipe"]
171
 
@@ -177,32 +206,57 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
177
  if not error_message:
178
  if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
179
  placement_strategy = "standard"
180
- if num_devices != num_stages:
181
- error_message = f"Strategy '{strategy}': Requires Number of Stages == Number of Devices."
182
  elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
183
  placement_strategy = "interleave"
184
- if num_stages % num_devices != 0:
185
  error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
186
  elif strategy == "dualpipe":
187
  placement_strategy = "dualpipe"
188
- if num_stages % 2 != 0:
189
  error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
190
- elif num_stages != num_devices:
191
- error_message = f"Strategy '{strategy}' (DualPipe): Requires Number of Stages == Number of Devices."
192
 
 
193
  if not error_message:
194
  try:
195
- op_times = { "forward": float(op_time_forward) }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  if split_backward:
197
- op_times["backward_D"] = float(op_time_backward_d)
198
- op_times["backward_W"] = float(op_time_backward_w)
199
- op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
 
200
  else:
201
- op_times["backward"] = float(op_time_backward)
 
 
 
 
 
 
 
 
 
202
 
203
  config = ScheduleConfig(
204
- num_devices=int(num_devices),
205
- num_stages=int(num_stages),
206
  num_batches=int(num_batches),
207
  p2p_latency=float(p2p_latency),
208
  placement_strategy=placement_strategy,
@@ -217,13 +271,9 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
217
  schedule = schedule_func(config)
218
  schedule.execute()
219
 
 
220
  vis_data = convert_schedule_to_visualization_format(schedule)
221
- fig = create_pipeline_figure(vis_data, show_progress=False)
222
-
223
- output_components.append(html.Div([
224
- html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
225
- dcc.Graph(figure=fig)
226
- ]))
227
 
228
  except (AssertionError, ValueError, TypeError) as e:
229
  error_message = f"Error generating schedule for '{strategy}': {e}"
@@ -235,9 +285,56 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
235
  traceback.print_exc()
236
 
237
  if error_message:
238
- output_components.append(
239
- dbc.Alert(error_message, color="danger", className="mt-3")
240
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  return output_components
243
 
 
36
  "num_devices": 4,
37
  "num_stages": 8,
38
  "num_batches": 16,
39
+ "p2p_latency": 0.0,
40
  "op_time_forward": 1.0,
41
  "op_time_backward_d": 1.0,
42
  "op_time_backward_w": 1.0,
43
  "op_time_backward": 2.0,
44
  "strategy": "1f1b_interleave",
45
+ "op_time_overlapped_fwd_bwd": None,
 
46
  }
47
 
48
  # Define input groups using dbc components
 
76
  dbc.Checklist(
77
  id='strategy-checklist',
78
  options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
79
+ value=list(STRATEGIES.keys()),
80
  inline=False,
81
  ),
82
  ], className="mb-3"),
 
105
  dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
106
  dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
107
  ], className="mb-3"),
108
+ html.Div([
109
+ dbc.Label("Overlapped Forward+Backward:"),
110
+ dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Optional: Defaults to Fwd + Bwd times", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
111
+ dbc.FormText("Specify a custom duration if Forward and Backward ops overlap completely."),
112
+ ], className="mb-3"),
113
  ])
114
  )
115
 
 
151
  State('op_time_backward', 'value'),
152
  State('op_time_backward_d', 'value'),
153
  State('op_time_backward_w', 'value'),
154
+ State('op_time_overlapped_fwd_bwd', 'value'),
155
  State('strategy-checklist', 'value'),
156
  prevent_initial_call=True
157
  )
158
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
159
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
160
+ op_time_overlapped_fwd_bwd,
161
  selected_strategies):
162
 
163
+ # Define the desired display order for strategies
164
+ strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
165
+
166
  output_components = []
167
+ valid_results = [] # Store (strategy_name, schedule, vis_data) for valid schedules
168
+ error_messages = [] # Store (strategy_name, error_message) for errors
169
+ automatic_adjustments = [] # Store messages about automatic parameter adjustments
170
 
171
  if not selected_strategies:
172
  return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
 
176
 
177
  for strategy in selected_strategies:
178
  error_message = ""
 
179
  placement_strategy = ""
180
+
181
+ # Use local copies of params that might be adjusted for this strategy
182
+ current_num_stages = num_stages
183
+ current_num_devices = num_devices
184
+
185
+ # Apply automatic adjustments for dualpipe
186
+ if strategy == "dualpipe" and num_stages != num_devices:
187
+ current_num_stages = num_devices # Force num_stages = num_devices for dualpipe
188
+ automatic_adjustments.append(
189
+ f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
190
+ )
191
+
192
+ # Apply automatic adjustments for strategies that require num_stages == num_devices
193
+ if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
194
+ current_num_stages = num_devices
195
+ automatic_adjustments.append(
196
+ f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
197
+ )
198
 
199
  split_backward = strategy in ["zb1p", "dualpipe"]
200
 
 
206
  if not error_message:
207
  if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
208
  placement_strategy = "standard"
209
+ # No need to check num_stages == num_devices as we've enforced it above
 
210
  elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
211
  placement_strategy = "interleave"
212
+ if current_num_stages % current_num_devices != 0:
213
  error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
214
  elif strategy == "dualpipe":
215
  placement_strategy = "dualpipe"
216
+ if current_num_stages % 2 != 0:
217
  error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
 
 
218
 
219
+ # Create adjusted operation times based on placement strategy
220
  if not error_message:
221
  try:
222
+ # Calculate number of stages per device for time adjustment
223
+ stages_per_device = current_num_stages // current_num_devices
224
+
225
+ # Calculate scaling factor - this normalizes operation time by stages per device
226
+ # For standard placement (1:1 stage:device mapping), this remains 1.0
227
+ # For interleaved, this scales down the time proportionally
228
+ time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
229
+
230
+ if stages_per_device > 1:
231
+ automatic_adjustments.append(
232
+ f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
233
+ )
234
+
235
+ # Apply scaling to operation times
236
+ op_times = {
237
+ "forward": float(op_time_forward) * time_scale_factor
238
+ }
239
+
240
  if split_backward:
241
+ op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
242
+ op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
243
+ # Keep combined for compatibility
244
+ op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
245
  else:
246
+ op_times["backward"] = float(op_time_backward) * time_scale_factor
247
+
248
+ if op_time_overlapped_fwd_bwd is not None:
249
+ try:
250
+ overlapped_val = float(op_time_overlapped_fwd_bwd)
251
+ if overlapped_val > 0:
252
+ # Scale overlapped time too
253
+ op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
254
+ except (ValueError, TypeError):
255
+ pass
256
 
257
  config = ScheduleConfig(
258
+ num_devices=int(current_num_devices),
259
+ num_stages=int(current_num_stages), # Use adjusted value
260
  num_batches=int(num_batches),
261
  p2p_latency=float(p2p_latency),
262
  placement_strategy=placement_strategy,
 
271
  schedule = schedule_func(config)
272
  schedule.execute()
273
 
274
+ # Store valid results instead of creating figure immediately
275
  vis_data = convert_schedule_to_visualization_format(schedule)
276
+ valid_results.append((strategy, schedule, vis_data))
 
 
 
 
 
277
 
278
  except (AssertionError, ValueError, TypeError) as e:
279
  error_message = f"Error generating schedule for '{strategy}': {e}"
 
285
  traceback.print_exc()
286
 
287
  if error_message:
288
+ error_messages.append((strategy, error_message))
289
+
290
+ # Add alerts for any automatic parameter adjustments
291
+ for adjustment in automatic_adjustments:
292
+ output_components.append(
293
+ dbc.Alert(adjustment, color="info", dismissable=True)
294
+ )
295
+
296
+ # If we have valid results, calculate the maximum execution time across all schedules
297
+ if valid_results:
298
+ # Find global maximum execution time
299
+ max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
300
+
301
+ # Sort valid results according to the display order
302
+ sorted_valid_results = []
303
+
304
+ # First add strategies in the predefined order
305
+ for strategy_name in strategy_display_order:
306
+ for result in valid_results:
307
+ if result[0] == strategy_name:
308
+ sorted_valid_results.append(result)
309
+
310
+ # Then add any remaining strategies that might not be in the predefined order
311
+ for result in valid_results:
312
+ if result[0] not in strategy_display_order:
313
+ sorted_valid_results.append(result)
314
+
315
+ # Create figures with aligned x-axis, using the sorted results
316
+ for strategy, _, vis_data in sorted_valid_results:
317
+ fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
318
+
319
+ # Force the x-axis range to be the same for all figures
320
+ # Add a small margin (5%) for better visualization
321
+ margin = max_execution_time * 0.05
322
+ fig.update_layout(
323
+ xaxis=dict(
324
+ range=[0, max_execution_time + margin]
325
+ )
326
+ )
327
+
328
+ output_components.append(html.Div([
329
+ html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
330
+ dcc.Graph(figure=fig)
331
+ ]))
332
+
333
+ # Add error messages to output
334
+ for strategy, msg in error_messages:
335
+ output_components.append(
336
+ dbc.Alert(msg, color="danger", className="mt-3")
337
+ )
338
 
339
  return output_components
340