Spaces:
Running
Running
Update visualizer for 1F1B overlap.
Browse files
- assets/1f1b_overlap.png +2 -2
- src/visualizer.py +122 -86
assets/1f1b_overlap.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
src/visualizer.py
CHANGED
@@ -213,71 +213,14 @@ def create_pipeline_figure(
|
|
213 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
214 |
|
215 |
for task in sorted_tasks:
|
216 |
-
# Determine task color and text color
|
217 |
-
if task["type"] == "forward":
|
218 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
219 |
-
text_color = "white"
|
220 |
-
name = "Forward"
|
221 |
-
elif task["type"] == "backward":
|
222 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
223 |
-
text_color = "black"
|
224 |
-
name = "Backward"
|
225 |
-
elif task["type"] == "backward_D":
|
226 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
227 |
-
text_color = "black"
|
228 |
-
name = "Backward (Grad)"
|
229 |
-
elif task["type"] == "backward_W":
|
230 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
231 |
-
text_color = "black"
|
232 |
-
name = "Backward (Weight)"
|
233 |
-
elif task["type"].startswith("overlapped_"):
|
234 |
-
color = get_color(task["type"], task["stage"], num_devices)
|
235 |
-
text_color = "white"
|
236 |
-
name = "Overlapped"
|
237 |
-
# Create a more descriptive name for the hover text
|
238 |
-
if "is_overlapped" in task and task["is_overlapped"]:
|
239 |
-
op_types = [op["type"] for op in task["operations"]]
|
240 |
-
name = f"Overlapped ({', '.join(op_types)})"
|
241 |
-
else:
|
242 |
-
color = empty_color
|
243 |
-
text_color = "black"
|
244 |
-
name = "Unknown"
|
245 |
-
|
246 |
-
# Add rectangle for the task
|
247 |
-
start_time = task["start_time"]
|
248 |
-
duration = task["duration"]
|
249 |
-
|
250 |
# Calculate y positions with no gaps
|
251 |
y_pos = device_idx_reversed * y_spacing
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
y0=y_pos - 0.5,
|
259 |
-
x1=start_time + duration,
|
260 |
-
y1=y_pos + 0.5,
|
261 |
-
line=dict(color="black", width=0.5),
|
262 |
-
fillcolor=color,
|
263 |
-
layer="above",
|
264 |
-
)
|
265 |
-
)
|
266 |
-
|
267 |
-
# Add batch number text (batch-add later)
|
268 |
-
annotations.append(
|
269 |
-
dict(
|
270 |
-
x=start_time + duration / 2,
|
271 |
-
y=y_pos,
|
272 |
-
text=f"{task['batch']}" + ("*" if task.get("is_overlapped", False) else ""),
|
273 |
-
showarrow=False,
|
274 |
-
font=dict(color=text_color, size=12, family="Arial, bold"),
|
275 |
-
)
|
276 |
-
)
|
277 |
-
|
278 |
-
# Prepare hover data (add traces in batches later)
|
279 |
-
if task.get("is_overlapped", False):
|
280 |
-
# Enhanced hover text for overlapped operations
|
281 |
op_details = "<br>".join([
|
282 |
f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
|
283 |
for op in task["operations"]
|
@@ -288,7 +231,113 @@ def create_pipeline_figure(
|
|
288 |
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
289 |
f"Duration: {task['duration']:.2f}"
|
290 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
hover_text = (
|
293 |
f"Batch: {task['batch']}<br>"
|
294 |
f"Stage: {task['stage']}<br>"
|
@@ -298,17 +347,17 @@ def create_pipeline_figure(
|
|
298 |
f"Duration: {task['duration']:.2f}"
|
299 |
)
|
300 |
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
|
|
310 |
)
|
311 |
-
)
|
312 |
|
313 |
# Update progress
|
314 |
if show_progress:
|
@@ -374,15 +423,6 @@ def create_pipeline_figure(
|
|
374 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
375 |
)
|
376 |
)
|
377 |
-
|
378 |
-
# Add entry for overlapped operations if they exist
|
379 |
-
if has_overlapped:
|
380 |
-
legend_items.append(
|
381 |
-
dict(
|
382 |
-
name=f"Overlapped (VS {vs})",
|
383 |
-
color=get_color("overlapped_", vs * num_devices, num_devices),
|
384 |
-
)
|
385 |
-
)
|
386 |
|
387 |
# If no tasks found, add default legend items
|
388 |
if not legend_items:
|
@@ -397,10 +437,6 @@ def create_pipeline_figure(
|
|
397 |
name="Backward Weight (VS 0)",
|
398 |
color=get_color("backward_W", 0, num_devices),
|
399 |
),
|
400 |
-
dict(
|
401 |
-
name="Overlapped (VS 0)",
|
402 |
-
color=get_color("overlapped_", 0, num_devices),
|
403 |
-
),
|
404 |
]
|
405 |
|
406 |
for i, item in enumerate(legend_items):
|
|
|
213 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
214 |
|
215 |
for task in sorted_tasks:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
# Calculate y positions with no gaps
|
217 |
y_pos = device_idx_reversed * y_spacing
|
218 |
+
start_time = task["start_time"]
|
219 |
+
duration = task["duration"]
|
220 |
+
|
221 |
+
# Special handling for overlapped operations
|
222 |
+
if task.get("is_overlapped", False) and "operations" in task:
|
223 |
+
# Prepare hover text for the entire overlapped operation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
op_details = "<br>".join([
|
225 |
f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
|
226 |
for op in task["operations"]
|
|
|
231 |
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
232 |
f"Duration: {task['duration']:.2f}"
|
233 |
)
|
234 |
+
|
235 |
+
# Add invisible marker for hover info
|
236 |
+
hover_traces.append(
|
237 |
+
dict(
|
238 |
+
x=[start_time + duration / 2],
|
239 |
+
y=[y_pos],
|
240 |
+
mode="markers",
|
241 |
+
marker=dict(opacity=0), # Invisible marker
|
242 |
+
hoverinfo="text",
|
243 |
+
text=hover_text,
|
244 |
+
showlegend=False,
|
245 |
+
)
|
246 |
+
)
|
247 |
+
|
248 |
+
# Calculate height of each sub-operation
|
249 |
+
sub_height = 1.0 / len(task["operations"])
|
250 |
+
|
251 |
+
# Add rectangles and annotations for each sub-operation
|
252 |
+
for i, sub_op in enumerate(task["operations"]):
|
253 |
+
# Determine color for this sub-operation
|
254 |
+
color = get_color(sub_op["type"], sub_op["stage"], num_devices)
|
255 |
+
|
256 |
+
# Calculate y position for this sub-operation
|
257 |
+
sub_y_pos_bottom = y_pos - 0.5 + (i * sub_height)
|
258 |
+
sub_y_pos_top = sub_y_pos_bottom + sub_height
|
259 |
+
sub_y_center = (sub_y_pos_bottom + sub_y_pos_top) / 2
|
260 |
+
|
261 |
+
# Add rectangle for this sub-operation
|
262 |
+
shapes.append(
|
263 |
+
dict(
|
264 |
+
type="rect",
|
265 |
+
x0=start_time,
|
266 |
+
y0=sub_y_pos_bottom,
|
267 |
+
x1=start_time + duration,
|
268 |
+
y1=sub_y_pos_top,
|
269 |
+
line=dict(color="black", width=0.5),
|
270 |
+
fillcolor=color,
|
271 |
+
layer="above",
|
272 |
+
)
|
273 |
+
)
|
274 |
+
|
275 |
+
# Add batch number text for this sub-operation
|
276 |
+
# Determine text color based on background color
|
277 |
+
if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
|
278 |
+
text_color = "black"
|
279 |
+
else:
|
280 |
+
text_color = "white"
|
281 |
+
|
282 |
+
annotations.append(
|
283 |
+
dict(
|
284 |
+
x=start_time + duration / 2,
|
285 |
+
y=sub_y_center,
|
286 |
+
text=f"{sub_op['batch']}",
|
287 |
+
showarrow=False,
|
288 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
289 |
+
)
|
290 |
+
)
|
291 |
else:
|
292 |
+
# Regular (non-overlapped) operation
|
293 |
+
# Determine task color and text color
|
294 |
+
if task["type"] == "forward":
|
295 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
296 |
+
text_color = "white"
|
297 |
+
name = "Forward"
|
298 |
+
elif task["type"] == "backward":
|
299 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
300 |
+
text_color = "black"
|
301 |
+
name = "Backward"
|
302 |
+
elif task["type"] == "backward_D":
|
303 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
304 |
+
text_color = "black"
|
305 |
+
name = "Backward (Grad)"
|
306 |
+
elif task["type"] == "backward_W":
|
307 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
308 |
+
text_color = "black"
|
309 |
+
name = "Backward (Weight)"
|
310 |
+
else:
|
311 |
+
color = empty_color
|
312 |
+
text_color = "black"
|
313 |
+
name = "Unknown"
|
314 |
+
|
315 |
+
# Add rectangle for the task
|
316 |
+
shapes.append(
|
317 |
+
dict(
|
318 |
+
type="rect",
|
319 |
+
x0=start_time,
|
320 |
+
y0=y_pos - 0.5,
|
321 |
+
x1=start_time + duration,
|
322 |
+
y1=y_pos + 0.5,
|
323 |
+
line=dict(color="black", width=0.5),
|
324 |
+
fillcolor=color,
|
325 |
+
layer="above",
|
326 |
+
)
|
327 |
+
)
|
328 |
+
|
329 |
+
# Add batch number text
|
330 |
+
annotations.append(
|
331 |
+
dict(
|
332 |
+
x=start_time + duration / 2,
|
333 |
+
y=y_pos,
|
334 |
+
text=f"{task['batch']}",
|
335 |
+
showarrow=False,
|
336 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
337 |
+
)
|
338 |
+
)
|
339 |
+
|
340 |
+
# Prepare hover data
|
341 |
hover_text = (
|
342 |
f"Batch: {task['batch']}<br>"
|
343 |
f"Stage: {task['stage']}<br>"
|
|
|
347 |
f"Duration: {task['duration']:.2f}"
|
348 |
)
|
349 |
|
350 |
+
hover_traces.append(
|
351 |
+
dict(
|
352 |
+
x=[start_time + duration / 2],
|
353 |
+
y=[y_pos],
|
354 |
+
mode="markers",
|
355 |
+
marker=dict(opacity=0), # Invisible marker
|
356 |
+
hoverinfo="text",
|
357 |
+
text=hover_text,
|
358 |
+
showlegend=False,
|
359 |
+
)
|
360 |
)
|
|
|
361 |
|
362 |
# Update progress
|
363 |
if show_progress:
|
|
|
423 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
424 |
)
|
425 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
|
427 |
# If no tasks found, add default legend items
|
428 |
if not legend_items:
|
|
|
437 |
name="Backward Weight (VS 0)",
|
438 |
color=get_color("backward_W", 0, num_devices),
|
439 |
),
|
|
|
|
|
|
|
|
|
440 |
]
|
441 |
|
442 |
for i, item in enumerate(legend_items):
|