Spaces:
Running
Running
update UI.
Browse files- src/server.py +124 -27
src/server.py
CHANGED
@@ -36,14 +36,13 @@ default_values = {
|
|
36 |
"num_devices": 4,
|
37 |
"num_stages": 8,
|
38 |
"num_batches": 16,
|
39 |
-
"p2p_latency": 0.
|
40 |
"op_time_forward": 1.0,
|
41 |
"op_time_backward_d": 1.0,
|
42 |
"op_time_backward_w": 1.0,
|
43 |
"op_time_backward": 2.0,
|
44 |
"strategy": "1f1b_interleave",
|
45 |
-
"
|
46 |
-
"placement_strategy": "interleave"
|
47 |
}
|
48 |
|
49 |
# Define input groups using dbc components
|
@@ -77,7 +76,7 @@ scheduling_params_card = dbc.Card(
|
|
77 |
dbc.Checklist(
|
78 |
id='strategy-checklist',
|
79 |
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
80 |
-
value=
|
81 |
inline=False,
|
82 |
),
|
83 |
], className="mb-3"),
|
@@ -106,6 +105,11 @@ timing_params_card = dbc.Card(
|
|
106 |
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
|
107 |
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
108 |
], className="mb-3"),
|
|
|
|
|
|
|
|
|
|
|
109 |
])
|
110 |
)
|
111 |
|
@@ -147,14 +151,22 @@ app.layout = dbc.Container([
|
|
147 |
State('op_time_backward', 'value'),
|
148 |
State('op_time_backward_d', 'value'),
|
149 |
State('op_time_backward_w', 'value'),
|
|
|
150 |
State('strategy-checklist', 'value'),
|
151 |
prevent_initial_call=True
|
152 |
)
|
153 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
154 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
|
|
155 |
selected_strategies):
|
156 |
|
|
|
|
|
|
|
157 |
output_components = []
|
|
|
|
|
|
|
158 |
|
159 |
if not selected_strategies:
|
160 |
return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
|
@@ -164,8 +176,25 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
164 |
|
165 |
for strategy in selected_strategies:
|
166 |
error_message = ""
|
167 |
-
fig = go.Figure()
|
168 |
placement_strategy = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
split_backward = strategy in ["zb1p", "dualpipe"]
|
171 |
|
@@ -177,32 +206,57 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
177 |
if not error_message:
|
178 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
179 |
placement_strategy = "standard"
|
180 |
-
|
181 |
-
error_message = f"Strategy '{strategy}': Requires Number of Stages == Number of Devices."
|
182 |
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
183 |
placement_strategy = "interleave"
|
184 |
-
if
|
185 |
error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
|
186 |
elif strategy == "dualpipe":
|
187 |
placement_strategy = "dualpipe"
|
188 |
-
if
|
189 |
error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
|
190 |
-
elif num_stages != num_devices:
|
191 |
-
error_message = f"Strategy '{strategy}' (DualPipe): Requires Number of Stages == Number of Devices."
|
192 |
|
|
|
193 |
if not error_message:
|
194 |
try:
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
if split_backward:
|
197 |
-
op_times["backward_D"] = float(op_time_backward_d)
|
198 |
-
op_times["backward_W"] = float(op_time_backward_w)
|
199 |
-
|
|
|
200 |
else:
|
201 |
-
op_times["backward"] = float(op_time_backward)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
config = ScheduleConfig(
|
204 |
-
num_devices=int(
|
205 |
-
num_stages=int(
|
206 |
num_batches=int(num_batches),
|
207 |
p2p_latency=float(p2p_latency),
|
208 |
placement_strategy=placement_strategy,
|
@@ -217,13 +271,9 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
217 |
schedule = schedule_func(config)
|
218 |
schedule.execute()
|
219 |
|
|
|
220 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
221 |
-
|
222 |
-
|
223 |
-
output_components.append(html.Div([
|
224 |
-
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
225 |
-
dcc.Graph(figure=fig)
|
226 |
-
]))
|
227 |
|
228 |
except (AssertionError, ValueError, TypeError) as e:
|
229 |
error_message = f"Error generating schedule for '{strategy}': {e}"
|
@@ -235,9 +285,56 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
235 |
traceback.print_exc()
|
236 |
|
237 |
if error_message:
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
return output_components
|
243 |
|
|
|
36 |
"num_devices": 4,
|
37 |
"num_stages": 8,
|
38 |
"num_batches": 16,
|
39 |
+
"p2p_latency": 0.0,
|
40 |
"op_time_forward": 1.0,
|
41 |
"op_time_backward_d": 1.0,
|
42 |
"op_time_backward_w": 1.0,
|
43 |
"op_time_backward": 2.0,
|
44 |
"strategy": "1f1b_interleave",
|
45 |
+
"op_time_overlapped_fwd_bwd": None,
|
|
|
46 |
}
|
47 |
|
48 |
# Define input groups using dbc components
|
|
|
76 |
dbc.Checklist(
|
77 |
id='strategy-checklist',
|
78 |
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
79 |
+
value=list(STRATEGIES.keys()),
|
80 |
inline=False,
|
81 |
),
|
82 |
], className="mb-3"),
|
|
|
105 |
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
|
106 |
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
107 |
], className="mb-3"),
|
108 |
+
html.Div([
|
109 |
+
dbc.Label("Overlapped Forward+Backward:"),
|
110 |
+
dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Optional: Defaults to Fwd + Bwd times", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
|
111 |
+
dbc.FormText("Specify a custom duration if Forward and Backward ops overlap completely."),
|
112 |
+
], className="mb-3"),
|
113 |
])
|
114 |
)
|
115 |
|
|
|
151 |
State('op_time_backward', 'value'),
|
152 |
State('op_time_backward_d', 'value'),
|
153 |
State('op_time_backward_w', 'value'),
|
154 |
+
State('op_time_overlapped_fwd_bwd', 'value'),
|
155 |
State('strategy-checklist', 'value'),
|
156 |
prevent_initial_call=True
|
157 |
)
|
158 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
159 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
160 |
+
op_time_overlapped_fwd_bwd,
|
161 |
selected_strategies):
|
162 |
|
163 |
+
# Define the desired display order for strategies
|
164 |
+
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
165 |
+
|
166 |
output_components = []
|
167 |
+
valid_results = [] # Store (strategy_name, schedule, vis_data) for valid schedules
|
168 |
+
error_messages = [] # Store (strategy_name, error_message) for errors
|
169 |
+
automatic_adjustments = [] # Store messages about automatic parameter adjustments
|
170 |
|
171 |
if not selected_strategies:
|
172 |
return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
|
|
|
176 |
|
177 |
for strategy in selected_strategies:
|
178 |
error_message = ""
|
|
|
179 |
placement_strategy = ""
|
180 |
+
|
181 |
+
# Use local copies of params that might be adjusted for this strategy
|
182 |
+
current_num_stages = num_stages
|
183 |
+
current_num_devices = num_devices
|
184 |
+
|
185 |
+
# Apply automatic adjustments for dualpipe
|
186 |
+
if strategy == "dualpipe" and num_stages != num_devices:
|
187 |
+
current_num_stages = num_devices # Force num_stages = num_devices for dualpipe
|
188 |
+
automatic_adjustments.append(
|
189 |
+
f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
|
190 |
+
)
|
191 |
+
|
192 |
+
# Apply automatic adjustments for strategies that require num_stages == num_devices
|
193 |
+
if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
|
194 |
+
current_num_stages = num_devices
|
195 |
+
automatic_adjustments.append(
|
196 |
+
f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
|
197 |
+
)
|
198 |
|
199 |
split_backward = strategy in ["zb1p", "dualpipe"]
|
200 |
|
|
|
206 |
if not error_message:
|
207 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
208 |
placement_strategy = "standard"
|
209 |
+
# No need to check num_stages == num_devices as we've enforced it above
|
|
|
210 |
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
211 |
placement_strategy = "interleave"
|
212 |
+
if current_num_stages % current_num_devices != 0:
|
213 |
error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
|
214 |
elif strategy == "dualpipe":
|
215 |
placement_strategy = "dualpipe"
|
216 |
+
if current_num_stages % 2 != 0:
|
217 |
error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
|
|
|
|
|
218 |
|
219 |
+
# Create adjusted operation times based on placement strategy
|
220 |
if not error_message:
|
221 |
try:
|
222 |
+
# Calculate number of stages per device for time adjustment
|
223 |
+
stages_per_device = current_num_stages // current_num_devices
|
224 |
+
|
225 |
+
# Calculate scaling factor - this normalizes operation time by stages per device
|
226 |
+
# For standard placement (1:1 stage:device mapping), this remains 1.0
|
227 |
+
# For interleaved, this scales down the time proportionally
|
228 |
+
time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
|
229 |
+
|
230 |
+
if stages_per_device > 1:
|
231 |
+
automatic_adjustments.append(
|
232 |
+
f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
|
233 |
+
)
|
234 |
+
|
235 |
+
# Apply scaling to operation times
|
236 |
+
op_times = {
|
237 |
+
"forward": float(op_time_forward) * time_scale_factor
|
238 |
+
}
|
239 |
+
|
240 |
if split_backward:
|
241 |
+
op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
|
242 |
+
op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
|
243 |
+
# Keep combined for compatibility
|
244 |
+
op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
|
245 |
else:
|
246 |
+
op_times["backward"] = float(op_time_backward) * time_scale_factor
|
247 |
+
|
248 |
+
if op_time_overlapped_fwd_bwd is not None:
|
249 |
+
try:
|
250 |
+
overlapped_val = float(op_time_overlapped_fwd_bwd)
|
251 |
+
if overlapped_val > 0:
|
252 |
+
# Scale overlapped time too
|
253 |
+
op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
|
254 |
+
except (ValueError, TypeError):
|
255 |
+
pass
|
256 |
|
257 |
config = ScheduleConfig(
|
258 |
+
num_devices=int(current_num_devices),
|
259 |
+
num_stages=int(current_num_stages), # Use adjusted value
|
260 |
num_batches=int(num_batches),
|
261 |
p2p_latency=float(p2p_latency),
|
262 |
placement_strategy=placement_strategy,
|
|
|
271 |
schedule = schedule_func(config)
|
272 |
schedule.execute()
|
273 |
|
274 |
+
# Store valid results instead of creating figure immediately
|
275 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
276 |
+
valid_results.append((strategy, schedule, vis_data))
|
|
|
|
|
|
|
|
|
|
|
277 |
|
278 |
except (AssertionError, ValueError, TypeError) as e:
|
279 |
error_message = f"Error generating schedule for '{strategy}': {e}"
|
|
|
285 |
traceback.print_exc()
|
286 |
|
287 |
if error_message:
|
288 |
+
error_messages.append((strategy, error_message))
|
289 |
+
|
290 |
+
# Add alerts for any automatic parameter adjustments
|
291 |
+
for adjustment in automatic_adjustments:
|
292 |
+
output_components.append(
|
293 |
+
dbc.Alert(adjustment, color="info", dismissable=True)
|
294 |
+
)
|
295 |
+
|
296 |
+
# If we have valid results, calculate the maximum execution time across all schedules
|
297 |
+
if valid_results:
|
298 |
+
# Find global maximum execution time
|
299 |
+
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
300 |
+
|
301 |
+
# Sort valid results according to the display order
|
302 |
+
sorted_valid_results = []
|
303 |
+
|
304 |
+
# First add strategies in the predefined order
|
305 |
+
for strategy_name in strategy_display_order:
|
306 |
+
for result in valid_results:
|
307 |
+
if result[0] == strategy_name:
|
308 |
+
sorted_valid_results.append(result)
|
309 |
+
|
310 |
+
# Then add any remaining strategies that might not be in the predefined order
|
311 |
+
for result in valid_results:
|
312 |
+
if result[0] not in strategy_display_order:
|
313 |
+
sorted_valid_results.append(result)
|
314 |
+
|
315 |
+
# Create figures with aligned x-axis, using the sorted results
|
316 |
+
for strategy, _, vis_data in sorted_valid_results:
|
317 |
+
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
318 |
+
|
319 |
+
# Force the x-axis range to be the same for all figures
|
320 |
+
# Add a small margin (5%) for better visualization
|
321 |
+
margin = max_execution_time * 0.05
|
322 |
+
fig.update_layout(
|
323 |
+
xaxis=dict(
|
324 |
+
range=[0, max_execution_time + margin]
|
325 |
+
)
|
326 |
+
)
|
327 |
+
|
328 |
+
output_components.append(html.Div([
|
329 |
+
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
330 |
+
dcc.Graph(figure=fig)
|
331 |
+
]))
|
332 |
+
|
333 |
+
# Add error messages to output
|
334 |
+
for strategy, msg in error_messages:
|
335 |
+
output_components.append(
|
336 |
+
dbc.Alert(msg, color="danger", className="mt-3")
|
337 |
+
)
|
338 |
|
339 |
return output_components
|
340 |
|