Victarry commited on
Commit
65e77cc
·
1 Parent(s): f140d7b

Update visualizer for 1F1B overlap.

Browse files
Files changed (2) hide show
  1. assets/1f1b_overlap.png +2 -2
  2. src/visualizer.py +122 -86
assets/1f1b_overlap.png CHANGED

Git LFS Details

  • SHA256: d28d6ae5675cc9bfb1d24b3dafef358a70c61c48232c6ef942967bb46b28c587
  • Pointer size: 130 Bytes
  • Size of remote file: 67.7 kB

Git LFS Details

  • SHA256: 6ad0be071da07f4d0d7a239de5ed4df8e4b5fd26dcea95a3103909e8c879488d
  • Pointer size: 130 Bytes
  • Size of remote file: 66.7 kB
src/visualizer.py CHANGED
@@ -213,71 +213,14 @@ def create_pipeline_figure(
213
  sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
214
 
215
  for task in sorted_tasks:
216
- # Determine task color and text color
217
- if task["type"] == "forward":
218
- color = get_color(task["type"], task["stage"], num_devices)
219
- text_color = "white"
220
- name = "Forward"
221
- elif task["type"] == "backward":
222
- color = get_color(task["type"], task["stage"], num_devices)
223
- text_color = "black"
224
- name = "Backward"
225
- elif task["type"] == "backward_D":
226
- color = get_color(task["type"], task["stage"], num_devices)
227
- text_color = "black"
228
- name = "Backward (Grad)"
229
- elif task["type"] == "backward_W":
230
- color = get_color(task["type"], task["stage"], num_devices)
231
- text_color = "black"
232
- name = "Backward (Weight)"
233
- elif task["type"].startswith("overlapped_"):
234
- color = get_color(task["type"], task["stage"], num_devices)
235
- text_color = "white"
236
- name = "Overlapped"
237
- # Create a more descriptive name for the hover text
238
- if "is_overlapped" in task and task["is_overlapped"]:
239
- op_types = [op["type"] for op in task["operations"]]
240
- name = f"Overlapped ({', '.join(op_types)})"
241
- else:
242
- color = empty_color
243
- text_color = "black"
244
- name = "Unknown"
245
-
246
- # Add rectangle for the task
247
- start_time = task["start_time"]
248
- duration = task["duration"]
249
-
250
  # Calculate y positions with no gaps
251
  y_pos = device_idx_reversed * y_spacing
252
-
253
- # Create rectangle using shape (batch-add later)
254
- shapes.append(
255
- dict(
256
- type="rect",
257
- x0=start_time,
258
- y0=y_pos - 0.5,
259
- x1=start_time + duration,
260
- y1=y_pos + 0.5,
261
- line=dict(color="black", width=0.5),
262
- fillcolor=color,
263
- layer="above",
264
- )
265
- )
266
-
267
- # Add batch number text (batch-add later)
268
- annotations.append(
269
- dict(
270
- x=start_time + duration / 2,
271
- y=y_pos,
272
- text=f"{task['batch']}" + ("*" if task.get("is_overlapped", False) else ""),
273
- showarrow=False,
274
- font=dict(color=text_color, size=12, family="Arial, bold"),
275
- )
276
- )
277
-
278
- # Prepare hover data (add traces in batches later)
279
- if task.get("is_overlapped", False):
280
- # Enhanced hover text for overlapped operations
281
  op_details = "<br>".join([
282
  f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
283
  for op in task["operations"]
@@ -288,7 +231,113 @@ def create_pipeline_figure(
288
  f"End: {task['start_time'] + task['duration']:.2f}<br>"
289
  f"Duration: {task['duration']:.2f}"
290
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  hover_text = (
293
  f"Batch: {task['batch']}<br>"
294
  f"Stage: {task['stage']}<br>"
@@ -298,17 +347,17 @@ def create_pipeline_figure(
298
  f"Duration: {task['duration']:.2f}"
299
  )
300
 
301
- hover_traces.append(
302
- dict(
303
- x=[start_time + duration / 2],
304
- y=[y_pos],
305
- mode="markers",
306
- marker=dict(opacity=0), # Invisible marker
307
- hoverinfo="text",
308
- text=hover_text,
309
- showlegend=False,
 
310
  )
311
- )
312
 
313
  # Update progress
314
  if show_progress:
@@ -374,15 +423,6 @@ def create_pipeline_figure(
374
  color=get_color("backward_W", vs * num_devices, num_devices),
375
  )
376
  )
377
-
378
- # Add entry for overlapped operations if they exist
379
- if has_overlapped:
380
- legend_items.append(
381
- dict(
382
- name=f"Overlapped (VS {vs})",
383
- color=get_color("overlapped_", vs * num_devices, num_devices),
384
- )
385
- )
386
 
387
  # If no tasks found, add default legend items
388
  if not legend_items:
@@ -397,10 +437,6 @@ def create_pipeline_figure(
397
  name="Backward Weight (VS 0)",
398
  color=get_color("backward_W", 0, num_devices),
399
  ),
400
- dict(
401
- name="Overlapped (VS 0)",
402
- color=get_color("overlapped_", 0, num_devices),
403
- ),
404
  ]
405
 
406
  for i, item in enumerate(legend_items):
 
213
  sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
214
 
215
  for task in sorted_tasks:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # Calculate y positions with no gaps
217
  y_pos = device_idx_reversed * y_spacing
218
+ start_time = task["start_time"]
219
+ duration = task["duration"]
220
+
221
+ # Special handling for overlapped operations
222
+ if task.get("is_overlapped", False) and "operations" in task:
223
+ # Prepare hover text for the entire overlapped operation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  op_details = "<br>".join([
225
  f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
226
  for op in task["operations"]
 
231
  f"End: {task['start_time'] + task['duration']:.2f}<br>"
232
  f"Duration: {task['duration']:.2f}"
233
  )
234
+
235
+ # Add invisible marker for hover info
236
+ hover_traces.append(
237
+ dict(
238
+ x=[start_time + duration / 2],
239
+ y=[y_pos],
240
+ mode="markers",
241
+ marker=dict(opacity=0), # Invisible marker
242
+ hoverinfo="text",
243
+ text=hover_text,
244
+ showlegend=False,
245
+ )
246
+ )
247
+
248
+ # Calculate height of each sub-operation
249
+ sub_height = 1.0 / len(task["operations"])
250
+
251
+ # Add rectangles and annotations for each sub-operation
252
+ for i, sub_op in enumerate(task["operations"]):
253
+ # Determine color for this sub-operation
254
+ color = get_color(sub_op["type"], sub_op["stage"], num_devices)
255
+
256
+ # Calculate y position for this sub-operation
257
+ sub_y_pos_bottom = y_pos - 0.5 + (i * sub_height)
258
+ sub_y_pos_top = sub_y_pos_bottom + sub_height
259
+ sub_y_center = (sub_y_pos_bottom + sub_y_pos_top) / 2
260
+
261
+ # Add rectangle for this sub-operation
262
+ shapes.append(
263
+ dict(
264
+ type="rect",
265
+ x0=start_time,
266
+ y0=sub_y_pos_bottom,
267
+ x1=start_time + duration,
268
+ y1=sub_y_pos_top,
269
+ line=dict(color="black", width=0.5),
270
+ fillcolor=color,
271
+ layer="above",
272
+ )
273
+ )
274
+
275
+ # Add batch number text for this sub-operation
276
+ # Determine text color based on background color
277
+ if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
278
+ text_color = "black"
279
+ else:
280
+ text_color = "white"
281
+
282
+ annotations.append(
283
+ dict(
284
+ x=start_time + duration / 2,
285
+ y=sub_y_center,
286
+ text=f"{sub_op['batch']}",
287
+ showarrow=False,
288
+ font=dict(color=text_color, size=12, family="Arial, bold"),
289
+ )
290
+ )
291
  else:
292
+ # Regular (non-overlapped) operation
293
+ # Determine task color and text color
294
+ if task["type"] == "forward":
295
+ color = get_color(task["type"], task["stage"], num_devices)
296
+ text_color = "white"
297
+ name = "Forward"
298
+ elif task["type"] == "backward":
299
+ color = get_color(task["type"], task["stage"], num_devices)
300
+ text_color = "black"
301
+ name = "Backward"
302
+ elif task["type"] == "backward_D":
303
+ color = get_color(task["type"], task["stage"], num_devices)
304
+ text_color = "black"
305
+ name = "Backward (Grad)"
306
+ elif task["type"] == "backward_W":
307
+ color = get_color(task["type"], task["stage"], num_devices)
308
+ text_color = "black"
309
+ name = "Backward (Weight)"
310
+ else:
311
+ color = empty_color
312
+ text_color = "black"
313
+ name = "Unknown"
314
+
315
+ # Add rectangle for the task
316
+ shapes.append(
317
+ dict(
318
+ type="rect",
319
+ x0=start_time,
320
+ y0=y_pos - 0.5,
321
+ x1=start_time + duration,
322
+ y1=y_pos + 0.5,
323
+ line=dict(color="black", width=0.5),
324
+ fillcolor=color,
325
+ layer="above",
326
+ )
327
+ )
328
+
329
+ # Add batch number text
330
+ annotations.append(
331
+ dict(
332
+ x=start_time + duration / 2,
333
+ y=y_pos,
334
+ text=f"{task['batch']}",
335
+ showarrow=False,
336
+ font=dict(color=text_color, size=12, family="Arial, bold"),
337
+ )
338
+ )
339
+
340
+ # Prepare hover data
341
  hover_text = (
342
  f"Batch: {task['batch']}<br>"
343
  f"Stage: {task['stage']}<br>"
 
347
  f"Duration: {task['duration']:.2f}"
348
  )
349
 
350
+ hover_traces.append(
351
+ dict(
352
+ x=[start_time + duration / 2],
353
+ y=[y_pos],
354
+ mode="markers",
355
+ marker=dict(opacity=0), # Invisible marker
356
+ hoverinfo="text",
357
+ text=hover_text,
358
+ showlegend=False,
359
+ )
360
  )
 
361
 
362
  # Update progress
363
  if show_progress:
 
423
  color=get_color("backward_W", vs * num_devices, num_devices),
424
  )
425
  )
 
 
 
 
 
 
 
 
 
426
 
427
  # If no tasks found, add default legend items
428
  if not legend_items:
 
437
  name="Backward Weight (VS 0)",
438
  color=get_color("backward_W", 0, num_devices),
439
  ),
 
 
 
 
440
  ]
441
 
442
  for i, item in enumerate(legend_items):