Victarry committed on
Commit
bb52925
·
1 Parent(s): 1170f1a

Add 1F1B-overlap implementation.

Browse files
Files changed (7) hide show
  1. .gitignore +1 -0
  2. README.md +7 -0
  3. assets/1f1b.png +2 -2
  4. assets/1f1b_overlap.png +3 -0
  5. main.py +43 -10
  6. src/strategies.py +36 -0
  7. src/visualizer.py +225 -165
.gitignore CHANGED
@@ -2,6 +2,7 @@
2
  ./venv
3
  uv.lock
4
  outputs/
 
5
 
6
  # Uncomment below if you want to include these files
7
  # !assets/*.png
 
2
  ./venv
3
  uv.lock
4
  outputs/
5
+ .cursor/*
6
 
7
  # Uncomment below if you want to include these files
8
  # !assets/*.png
README.md CHANGED
@@ -50,6 +50,13 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
50
  ```
51
  ![zb1p](assets/zb1p.png)
52
 
 
 
 
 
 
 
 
53
  ## Configuration
54
 
55
  The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
 
50
  ```
51
  ![zb1p](assets/zb1p.png)
52
 
53
+
54
+ Running the 1F1B-batch-overlap strategy:
55
+ ```bash
56
+ uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
57
+ ```
58
+ ![1f1b_overlap](assets/1f1b_overlap.png)
59
+
60
  ## Configuration
61
 
62
  The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
assets/1f1b.png CHANGED

Git LFS Details

  • SHA256: 1239653e169e1e43a007259fc8712d4c9391ff474d5b8980d5e1eff02b3d79b1
  • Pointer size: 130 Bytes
  • Size of remote file: 64.8 kB

Git LFS Details

  • SHA256: 693988b237729a174a252e277d65c687cf48330ddba575351b5095647507f078
  • Pointer size: 130 Bytes
  • Size of remote file: 67.2 kB
assets/1f1b_overlap.png ADDED

Git LFS Details

  • SHA256: d28d6ae5675cc9bfb1d24b3dafef358a70c61c48232c6ef942967bb46b28c587
  • Pointer size: 130 Bytes
  • Size of remote file: 67.7 kB
main.py CHANGED
@@ -1,5 +1,10 @@
1
  from src.execution_model import ScheduleConfig
2
- from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule, generate_zero_bubble_1p_schedule
 
 
 
 
 
3
  from src.visualizer import visualize_pipeline_parallelism_dash
4
  import hydra
5
  from omegaconf import DictConfig, OmegaConf
@@ -16,6 +21,8 @@ def main(cfg: DictConfig) -> None:
16
  run_interleave(cfg)
17
  elif cfg.strategy == "zb1p":
18
  run_zero_bubble_1p(cfg)
 
 
19
  else:
20
  raise ValueError(f"Unknown strategy: {cfg.strategy}")
21
 
@@ -23,7 +30,9 @@ def main(cfg: DictConfig) -> None:
23
  def run_1f1b(cfg: DictConfig) -> None:
24
  """Run 1F1B pipeline parallelism simulation."""
25
  # Convert OmegaConf to dict for op_times if it exists
26
- op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
 
 
27
 
28
  schedule_config = ScheduleConfig(
29
  num_devices=cfg.num_devices,
@@ -31,7 +40,7 @@ def run_1f1b(cfg: DictConfig) -> None:
31
  num_batches=cfg.num_batches,
32
  p2p_latency=cfg.p2p_latency,
33
  op_times=op_times,
34
- placement_strategy="standard"
35
  )
36
  schedule = generate_1f1b_schedule(schedule_config)
37
  schedule.execute()
@@ -42,15 +51,17 @@ def run_1f1b(cfg: DictConfig) -> None:
42
  def run_interleave(cfg: DictConfig) -> None:
43
  """Run interleaved pipeline parallelism simulation."""
44
  # Convert OmegaConf to dict for op_times if it exists
45
- op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
46
-
 
 
47
  schedule_config = ScheduleConfig(
48
  num_devices=cfg.num_devices,
49
  num_stages=cfg.num_stages,
50
  num_batches=cfg.num_batches,
51
  p2p_latency=cfg.p2p_latency,
52
  placement_strategy="interleave",
53
- op_times=op_times
54
  )
55
  schedule = generate_1f1b_interleave_schedule(schedule_config)
56
  schedule.execute()
@@ -60,20 +71,42 @@ def run_interleave(cfg: DictConfig) -> None:
60
  def run_zero_bubble_1p(cfg: DictConfig) -> None:
61
  """Run zero bubble 1P pipeline parallelism simulation."""
62
  # Convert OmegaConf to dict for op_times if it exists
63
- op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
64
-
 
 
65
  schedule_config = ScheduleConfig(
66
  num_devices=cfg.num_devices,
67
  num_stages=cfg.num_stages,
68
  num_batches=cfg.num_batches,
69
  p2p_latency=cfg.p2p_latency,
70
  op_times=op_times,
71
- split_backward=True
72
  )
73
  schedule = generate_zero_bubble_1p_schedule(schedule_config)
74
  schedule.execute()
75
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if __name__ == "__main__":
79
- main()
 
1
  from src.execution_model import ScheduleConfig
2
+ from src.strategies import (
3
+ generate_1f1b_interleave_schedule,
4
+ generate_1f1b_overlap_schedule,
5
+ generate_1f1b_schedule,
6
+ generate_zero_bubble_1p_schedule,
7
+ )
8
  from src.visualizer import visualize_pipeline_parallelism_dash
9
  import hydra
10
  from omegaconf import DictConfig, OmegaConf
 
21
  run_interleave(cfg)
22
  elif cfg.strategy == "zb1p":
23
  run_zero_bubble_1p(cfg)
24
+ elif cfg.strategy == "1f1b_overlap":
25
+ run_1f1b_overlap(cfg)
26
  else:
27
  raise ValueError(f"Unknown strategy: {cfg.strategy}")
28
 
 
30
  def run_1f1b(cfg: DictConfig) -> None:
31
  """Run 1F1B pipeline parallelism simulation."""
32
  # Convert OmegaConf to dict for op_times if it exists
33
+ op_times = (
34
+ OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
35
+ )
36
 
37
  schedule_config = ScheduleConfig(
38
  num_devices=cfg.num_devices,
 
40
  num_batches=cfg.num_batches,
41
  p2p_latency=cfg.p2p_latency,
42
  op_times=op_times,
43
+ placement_strategy="standard",
44
  )
45
  schedule = generate_1f1b_schedule(schedule_config)
46
  schedule.execute()
 
51
  def run_interleave(cfg: DictConfig) -> None:
52
  """Run interleaved pipeline parallelism simulation."""
53
  # Convert OmegaConf to dict for op_times if it exists
54
+ op_times = (
55
+ OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
56
+ )
57
+
58
  schedule_config = ScheduleConfig(
59
  num_devices=cfg.num_devices,
60
  num_stages=cfg.num_stages,
61
  num_batches=cfg.num_batches,
62
  p2p_latency=cfg.p2p_latency,
63
  placement_strategy="interleave",
64
+ op_times=op_times,
65
  )
66
  schedule = generate_1f1b_interleave_schedule(schedule_config)
67
  schedule.execute()
 
71
  def run_zero_bubble_1p(cfg: DictConfig) -> None:
72
  """Run zero bubble 1P pipeline parallelism simulation."""
73
  # Convert OmegaConf to dict for op_times if it exists
74
+ op_times = (
75
+ OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
76
+ )
77
+
78
  schedule_config = ScheduleConfig(
79
  num_devices=cfg.num_devices,
80
  num_stages=cfg.num_stages,
81
  num_batches=cfg.num_batches,
82
  p2p_latency=cfg.p2p_latency,
83
  op_times=op_times,
84
+ split_backward=True,
85
  )
86
  schedule = generate_zero_bubble_1p_schedule(schedule_config)
87
  schedule.execute()
88
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
89
 
90
 
91
+ def run_1f1b_overlap(cfg: DictConfig) -> None:
92
+ """Run 1F1B overlap pipeline parallelism simulation."""
93
+ # Convert OmegaConf to dict for op_times if it exists
94
+ op_times = (
95
+ OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
96
+ )
97
+
98
+ schedule_config = ScheduleConfig(
99
+ num_devices=cfg.num_devices,
100
+ num_stages=cfg.num_stages,
101
+ num_batches=cfg.num_batches,
102
+ p2p_latency=cfg.p2p_latency,
103
+ op_times=op_times,
104
+ split_backward=False,
105
+ )
106
+ schedule = generate_1f1b_overlap_schedule(schedule_config)
107
+ schedule.execute()
108
+ visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
109
+
110
+
111
  if __name__ == "__main__":
112
+ main()
src/strategies.py CHANGED
@@ -94,6 +94,42 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
94
  return schedule
95
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  # Some codes are copied from Megatron-LM
98
  def generate_1f1b_interleave_schedule(config: ScheduleConfig):
99
  schedule = Schedule(config)
 
94
  return schedule
95
 
96
 
97
+ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
98
+ schedule = Schedule(config)
99
+
100
+ assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
101
+
102
+ for i in range(config.num_devices):
103
+ fwd_batch_id = 0
104
+ bwd_batch_id = 0
105
+ cooldown_batches = warmup_batches = 2 * (config.num_devices - i - 1) + 1
106
+ steady_batches = config.num_batches - warmup_batches
107
+
108
+ for _ in range(warmup_batches):
109
+ schedule.dev_queues[i].add_operation(
110
+ schedule.get_op(fwd_batch_id, i, "forward")
111
+ )
112
+ fwd_batch_id += 1
113
+
114
+ for _ in range(steady_batches):
115
+ schedule.dev_queues[i].add_operation(
116
+ schedule.get_op(fwd_batch_id, i, "forward")
117
+ )
118
+ fwd_batch_id += 1
119
+ schedule.dev_queues[i].add_operation(
120
+ schedule.get_op(bwd_batch_id, i, "backward")
121
+ )
122
+ bwd_batch_id += 1
123
+
124
+ for _ in range(cooldown_batches):
125
+ schedule.dev_queues[i].add_operation(
126
+ schedule.get_op(bwd_batch_id, i, "backward")
127
+ )
128
+ bwd_batch_id += 1
129
+
130
+ return schedule
131
+
132
+
133
  # Some codes are copied from Megatron-LM
134
  def generate_1f1b_interleave_schedule(config: ScheduleConfig):
135
  schedule = Schedule(config)
src/visualizer.py CHANGED
@@ -12,30 +12,34 @@ from src.execution_model import Schedule
12
  def convert_schedule_to_visualization_format(schedule: Schedule):
13
  """
14
  Converts a Schedule object to the format needed for visualization.
15
-
16
  Returns:
17
  Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
18
  """
19
  # Make sure all operations have start and end times
20
  for op in schedule.ops.values():
21
  if op.start_time is None or op.end_time is None:
22
- raise ValueError("Operations must have start and end times. Run ScheduleExecutor.execute() first.")
23
-
 
 
24
  visualization_data = {}
25
-
26
  # Organize operations by device
27
  for device_id, device_queue in enumerate(schedule.dev_queues):
28
  visualization_data[device_id] = []
29
-
30
  for op in device_queue.ops:
31
- visualization_data[device_id].append({
32
- "type": op.op_type,
33
- "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
34
- "stage": op.stage_id,
35
- "start_time": op.start_time,
36
- "duration": op.end_time - op.start_time
37
- })
38
-
 
 
39
  return visualization_data
40
 
41
 
@@ -44,58 +48,58 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
44
  def get_color(op_type: str, stage_id: int, num_devices: int):
45
  # A more harmonious blue palette with better progression for forward operations
46
  forward_colors = [
47
- "#5c88f2", # Periwinkle blue
48
- "#1a53ff", # Deep blue
49
- "#b3c6ff", # Light blue
50
- "#4d79ff", # Strong blue
51
- "#809fff", # Medium blue
52
- "#0039e6", # Rich navy
53
- "#002db3", # Dark navy
54
- "#264db3", # Royal blue
55
- "#7094db", # Steel blue
56
- "#99b3e6" # Pale blue
57
  ]
58
-
59
  # Orange palette for backward operations
60
  backward_colors = [
61
- "#ff9933", # Bright orange
62
- "#ffad5c", # Medium orange
63
- "#ffc285", # Light orange
64
- "#ffd6ad", # Pale orange
65
- "#ff8000", # Deep orange
66
- "#cc6600", # Dark orange
67
- "#ff9933", # Vivid orange
68
- "#ffb366", # Soft orange
69
- "#cc9966", # Muted orange
70
- "#ffd699" # Light amber
71
  ]
72
-
73
  # Improved teal/turquoise palette with better progression for backward_D operations
74
  backward_d_colors = [
75
- "#80ffff", # Light cyan
76
- "#00cccc", # Teal
77
- "#00e6e6", # Bright teal
78
- "#33ffff", # Cyan
79
- "#00b3b3", # Medium teal
80
- "#008080", # Dark teal
81
- "#00e6cc", # Turquoise
82
- "#4ddbbd", # Aqua
83
- "#80d4c8", # Pale teal
84
- "#b3e6e0" # Ice
85
  ]
86
-
87
  # Improved green palette with better progression for backward_W operations
88
  backward_w_colors = [
89
- "#00cc66", # Medium green
90
- "#00e673", # Bright green
91
- "#33ff99", # Mint green
92
- "#80ffbf", # Light green
93
- "#009933", # Forest green
94
- "#006622", # Dark green
95
- "#33cc33", # True green
96
- "#66cc66", # Sage green
97
- "#99cc99", # Pale green
98
- "#c6e6c6" # Pastel green
99
  ]
100
 
101
  virtual_stage = stage_id // num_devices
@@ -115,7 +119,9 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
115
  raise ValueError(f"Invalid operation type: {op_type}")
116
 
117
 
118
- def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
 
 
119
  """
120
  Create a Plotly figure for pipeline parallelism scheduling.
121
 
@@ -126,9 +132,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
126
  """
127
  # Find the number of devices
128
  num_devices = len(schedule_data)
129
-
130
  empty_color = "whitesmoke"
131
-
132
  # Find the maximum time in the schedule if not provided
133
  if max_time is None:
134
  max_time = 0
@@ -146,7 +152,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
146
  tasks_processed = 0
147
 
148
  if show_progress:
149
- progress_bar = tqdm(total=total_tasks + num_devices + 3, desc="Creating visualization")
 
 
150
 
151
  # Create a custom y-axis with no gaps between devices
152
  y_spacing = 1.0 # Use 1.0 for no gaps
@@ -159,7 +167,7 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
159
  # Add rectangles for each task
160
  for device_idx, device in enumerate(schedule_data):
161
  device_idx_reversed = num_devices - device_idx - 1
162
-
163
  # Sort tasks by start time to ensure correct rendering
164
  sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
165
 
@@ -189,44 +197,50 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
189
  # Add rectangle for the task
190
  start_time = task["start_time"]
191
  duration = task["duration"]
192
-
193
  # Calculate y positions with no gaps
194
  y_pos = device_idx_reversed * y_spacing
195
-
196
  # Create rectangle using shape (batch-add later)
197
- shapes.append(dict(
198
- type="rect",
199
- x0=start_time,
200
- y0=y_pos - 0.5,
201
- x1=start_time + duration,
202
- y1=y_pos + 0.5,
203
- line=dict(color="black", width=0.5),
204
- fillcolor=color,
205
- layer="above",
206
- ))
207
-
 
 
208
  # Add batch number text (batch-add later)
209
- annotations.append(dict(
210
- x=start_time + duration / 2,
211
- y=y_pos,
212
- text=f"{task['batch']}",
213
- showarrow=False,
214
- font=dict(color=text_color, size=12, family="Arial, bold"),
215
- ))
216
-
 
 
217
  # Prepare hover data (add traces in batches later)
218
  hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
219
-
220
- hover_traces.append(dict(
221
- x=[start_time + duration / 2],
222
- y=[y_pos],
223
- mode='markers',
224
- marker=dict(opacity=0), # Invisible marker
225
- hoverinfo='text',
226
- text=hover_text,
227
- showlegend=False
228
- ))
229
-
 
 
230
  # Update progress
231
  if show_progress:
232
  tasks_processed += 1
@@ -234,63 +248,83 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
234
 
235
  # Add all shapes at once for better performance
236
  fig.update_layout(shapes=shapes)
237
-
238
  # Add all annotations at once
239
  fig.update_layout(annotations=annotations)
240
-
241
  # Add all hover traces at once
242
  for trace in hover_traces:
243
  fig.add_trace(go.Scatter(**trace))
244
 
245
  # Add custom legend
246
  legend_items = []
247
-
248
  # Find the maximum virtual stage in the data
249
  max_virtual_stage = 0
250
  for device in schedule_data:
251
  for task in schedule_data[device]:
252
  virtual_stage = task["stage"] // num_devices
253
  max_virtual_stage = max(max_virtual_stage, virtual_stage)
254
-
255
  # Add forward and backward items for each virtual stage
256
  for vs in range(max_virtual_stage + 1):
257
- legend_items.append(dict(
258
- name=f"Forward (VS {vs})",
259
- color=get_color("forward", vs * num_devices, num_devices)
260
- ))
261
- legend_items.append(dict(
262
- name=f"Backward (VS {vs})",
263
- color=get_color("backward", vs * num_devices, num_devices)
264
- ))
 
 
 
 
265
  # Add entries for split backward operations if this is a zb1p schedule
266
- if any(task["type"] in ["backward_D", "backward_W"] for device in schedule_data for task in schedule_data[device]):
267
- legend_items.append(dict(
268
- name=f"Backward Grad (VS {vs})",
269
- color=get_color("backward_D", vs * num_devices, num_devices)
270
- ))
271
- legend_items.append(dict(
272
- name=f"Backward Weight (VS {vs})",
273
- color=get_color("backward_W", vs * num_devices, num_devices)
274
- ))
275
-
 
 
 
 
 
 
 
 
276
  # If no tasks found, add default legend items
277
  if not legend_items:
278
  legend_items = [
279
  dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
280
  dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
281
- dict(name="Backward Grad (VS 0)", color=get_color("backward_D", 0, num_devices)),
282
- dict(name="Backward Weight (VS 0)", color=get_color("backward_W", 0, num_devices)),
 
 
 
 
 
 
283
  ]
284
-
285
  for i, item in enumerate(legend_items):
286
- fig.add_trace(go.Scatter(
287
- x=[None],
288
- y=[None],
289
- mode='markers',
290
- marker=dict(size=10, color=item['color']),
291
- name=item['name'],
292
- showlegend=True
293
- ))
 
 
294
  if show_progress and i < len(legend_items) - 1:
295
  progress_bar.update(1)
296
 
@@ -299,11 +333,15 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
299
  # Modify the ordering to put Device 1 at the top, then Device 0, then the rest
300
  if num_devices >= 2:
301
  # Move Device 1 to the top, followed by Device 0
302
- device_labels = [device_labels[1], device_labels[0]] + device_labels[2:] if num_devices > 1 else device_labels
303
-
 
 
 
 
304
  # Calculate tick positions with no gaps
305
  tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
306
-
307
  # Adjust the range to ensure there are no empty spaces at the end
308
  x_end = max_time * 1.05 # Add a small margin
309
 
@@ -323,17 +361,17 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
323
  text=title_text,
324
  x=0.5,
325
  y=0.98, # Move title position closer to the top
326
- font=dict(size=20)
327
  ),
328
  legend=dict(
329
  orientation="v", # Changed from horizontal to vertical
330
  yanchor="top",
331
  y=1.02, # Position at the top
332
  xanchor="right",
333
- x=1.20, # Position further to the right to accommodate more items
334
  title=dict(text="<b>Operation Types:</b>"),
335
  itemsizing="constant",
336
- tracegroupgap=0
337
  ),
338
  width=2000, # Increase width to accommodate the expanded legend
339
  height=400, # Maintain current height
@@ -351,10 +389,13 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
351
  # Cache for storing processed schedule data
352
  _schedule_data_cache = {}
353
 
354
- def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True):
 
 
 
355
  """
356
  Create a Dash app to visualize the pipeline schedule.
357
-
358
  Args:
359
  schedule: Schedule object to visualize
360
  schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
@@ -363,7 +404,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
363
  # Process schedule data only once and cache it
364
  global _schedule_data_cache
365
  cache_key = id(schedule)
366
-
367
  if enable_caching and cache_key in _schedule_data_cache:
368
  schedule_data = _schedule_data_cache[cache_key]
369
  print("Using cached schedule data")
@@ -372,7 +413,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
372
  if enable_caching:
373
  _schedule_data_cache[cache_key] = schedule_data
374
  print("Cached schedule data")
375
-
376
  total_tasks = sum(len(tasks) for tasks in schedule_data.values())
377
  print(f"Total tasks in schedule: {total_tasks}")
378
 
@@ -380,31 +421,48 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
380
  app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
381
 
382
  # Create a more informative layout with data size information
383
- app.layout = html.Div([
384
- html.H1(f"Pipeline Parallelism Visualization - {schedule_type}", style={"textAlign": "center"}),
385
-
386
- html.Div([
387
- html.P(f"Number of devices: {len(schedule_data)}", style={"display": "inline-block", "marginRight": "20px"}),
388
- html.P(f"Total tasks: {total_tasks}", style={"display": "inline-block", "marginRight": "20px"}),
389
- ], style={"marginBottom": "20px"}),
390
-
391
- html.Div(id="graph-container", children=[]),
392
-
393
- dcc.Loading(
394
- id="loading-graph",
395
- type="circle",
396
- children=[
397
- dcc.Graph(
398
- id="pipeline-graph",
399
- config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
400
- ),
401
- ]
402
- ),
403
- ])
404
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  # Cache for storing figure to avoid regenerating it
406
  figure_cache = {}
407
-
408
  @app.callback(
409
  Output("pipeline-graph", "figure"),
410
  Input("graph-container", "children"),
@@ -416,15 +474,15 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
416
  if enable_caching and cache_key in figure_cache:
417
  print("Using cached figure")
418
  return figure_cache[cache_key]
419
-
420
  # Create the figure
421
  figure = create_pipeline_figure(schedule_data, show_progress=True)
422
-
423
  # Cache the figure
424
  if enable_caching:
425
  figure_cache[cache_key] = figure
426
  print("Cached figure")
427
-
428
  return figure
429
 
430
  return app
@@ -435,11 +493,11 @@ def visualize_pipeline_parallelism_dash(
435
  port: int = 8050,
436
  debug: bool = False,
437
  enable_caching: bool = True,
438
- schedule_type="1f1b"
439
  ):
440
  """
441
  Launch a Dash app to visualize the pipeline schedule interactively.
442
-
443
  Args:
444
  schedule: Schedule object to visualize
445
  port: Port to run the Dash app on
@@ -447,6 +505,8 @@ def visualize_pipeline_parallelism_dash(
447
  enable_caching: Whether to cache schedule data and figures
448
  schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
449
  """
450
- app = create_dash_app(schedule, schedule_type=schedule_type, enable_caching=enable_caching)
 
 
451
  print(f"Starting Dash app on http://localhost:{port}/")
452
  app.run_server(debug=debug, port=port)
 
12
  def convert_schedule_to_visualization_format(schedule: Schedule):
13
  """
14
  Converts a Schedule object to the format needed for visualization.
15
+
16
  Returns:
17
  Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
18
  """
19
  # Make sure all operations have start and end times
20
  for op in schedule.ops.values():
21
  if op.start_time is None or op.end_time is None:
22
+ raise ValueError(
23
+ "Operations must have start and end times. Run ScheduleExecutor.execute() first."
24
+ )
25
+
26
  visualization_data = {}
27
+
28
  # Organize operations by device
29
  for device_id, device_queue in enumerate(schedule.dev_queues):
30
  visualization_data[device_id] = []
31
+
32
  for op in device_queue.ops:
33
+ visualization_data[device_id].append(
34
+ {
35
+ "type": op.op_type,
36
+ "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
37
+ "stage": op.stage_id,
38
+ "start_time": op.start_time,
39
+ "duration": op.end_time - op.start_time,
40
+ }
41
+ )
42
+
43
  return visualization_data
44
 
45
 
 
48
  def get_color(op_type: str, stage_id: int, num_devices: int):
49
  # A more harmonious blue palette with better progression for forward operations
50
  forward_colors = [
51
+ "#5c88f2", # Periwinkle blue
52
+ "#1a53ff", # Deep blue
53
+ "#b3c6ff", # Light blue
54
+ "#4d79ff", # Strong blue
55
+ "#809fff", # Medium blue
56
+ "#0039e6", # Rich navy
57
+ "#002db3", # Dark navy
58
+ "#264db3", # Royal blue
59
+ "#7094db", # Steel blue
60
+ "#99b3e6", # Pale blue
61
  ]
62
+
63
  # Orange palette for backward operations
64
  backward_colors = [
65
+ "#ff9933", # Bright orange
66
+ "#ffad5c", # Medium orange
67
+ "#ffc285", # Light orange
68
+ "#ffd6ad", # Pale orange
69
+ "#ff8000", # Deep orange
70
+ "#cc6600", # Dark orange
71
+ "#ff9933", # Vivid orange
72
+ "#ffb366", # Soft orange
73
+ "#cc9966", # Muted orange
74
+ "#ffd699", # Light amber
75
  ]
76
+
77
  # Improved teal/turquoise palette with better progression for backward_D operations
78
  backward_d_colors = [
79
+ "#80ffff", # Light cyan
80
+ "#00cccc", # Teal
81
+ "#00e6e6", # Bright teal
82
+ "#33ffff", # Cyan
83
+ "#00b3b3", # Medium teal
84
+ "#008080", # Dark teal
85
+ "#00e6cc", # Turquoise
86
+ "#4ddbbd", # Aqua
87
+ "#80d4c8", # Pale teal
88
+ "#b3e6e0", # Ice
89
  ]
90
+
91
  # Improved green palette with better progression for backward_W operations
92
  backward_w_colors = [
93
+ "#00cc66", # Medium green
94
+ "#00e673", # Bright green
95
+ "#33ff99", # Mint green
96
+ "#80ffbf", # Light green
97
+ "#009933", # Forest green
98
+ "#006622", # Dark green
99
+ "#33cc33", # True green
100
+ "#66cc66", # Sage green
101
+ "#99cc99", # Pale green
102
+ "#c6e6c6", # Pastel green
103
  ]
104
 
105
  virtual_stage = stage_id // num_devices
 
119
  raise ValueError(f"Invalid operation type: {op_type}")
120
 
121
 
122
+ def create_pipeline_figure(
123
+ schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True
124
+ ):
125
  """
126
  Create a Plotly figure for pipeline parallelism scheduling.
127
 
 
132
  """
133
  # Find the number of devices
134
  num_devices = len(schedule_data)
135
+
136
  empty_color = "whitesmoke"
137
+
138
  # Find the maximum time in the schedule if not provided
139
  if max_time is None:
140
  max_time = 0
 
152
  tasks_processed = 0
153
 
154
  if show_progress:
155
+ progress_bar = tqdm(
156
+ total=total_tasks + num_devices + 3, desc="Creating visualization"
157
+ )
158
 
159
  # Create a custom y-axis with no gaps between devices
160
  y_spacing = 1.0 # Use 1.0 for no gaps
 
167
  # Add rectangles for each task
168
  for device_idx, device in enumerate(schedule_data):
169
  device_idx_reversed = num_devices - device_idx - 1
170
+
171
  # Sort tasks by start time to ensure correct rendering
172
  sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
173
 
 
197
  # Add rectangle for the task
198
  start_time = task["start_time"]
199
  duration = task["duration"]
200
+
201
  # Calculate y positions with no gaps
202
  y_pos = device_idx_reversed * y_spacing
203
+
204
  # Create rectangle using shape (batch-add later)
205
+ shapes.append(
206
+ dict(
207
+ type="rect",
208
+ x0=start_time,
209
+ y0=y_pos - 0.5,
210
+ x1=start_time + duration,
211
+ y1=y_pos + 0.5,
212
+ line=dict(color="black", width=0.5),
213
+ fillcolor=color,
214
+ layer="above",
215
+ )
216
+ )
217
+
218
  # Add batch number text (batch-add later)
219
+ annotations.append(
220
+ dict(
221
+ x=start_time + duration / 2,
222
+ y=y_pos,
223
+ text=f"{task['batch']}",
224
+ showarrow=False,
225
+ font=dict(color=text_color, size=12, family="Arial, bold"),
226
+ )
227
+ )
228
+
229
  # Prepare hover data (add traces in batches later)
230
  hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
231
+
232
+ hover_traces.append(
233
+ dict(
234
+ x=[start_time + duration / 2],
235
+ y=[y_pos],
236
+ mode="markers",
237
+ marker=dict(opacity=0), # Invisible marker
238
+ hoverinfo="text",
239
+ text=hover_text,
240
+ showlegend=False,
241
+ )
242
+ )
243
+
244
  # Update progress
245
  if show_progress:
246
  tasks_processed += 1
 
248
 
249
  # Add all shapes at once for better performance
250
  fig.update_layout(shapes=shapes)
251
+
252
  # Add all annotations at once
253
  fig.update_layout(annotations=annotations)
254
+
255
  # Add all hover traces at once
256
  for trace in hover_traces:
257
  fig.add_trace(go.Scatter(**trace))
258
 
259
  # Add custom legend
260
  legend_items = []
261
+
262
  # Find the maximum virtual stage in the data
263
  max_virtual_stage = 0
264
  for device in schedule_data:
265
  for task in schedule_data[device]:
266
  virtual_stage = task["stage"] // num_devices
267
  max_virtual_stage = max(max_virtual_stage, virtual_stage)
268
+
269
  # Add forward and backward items for each virtual stage
270
  for vs in range(max_virtual_stage + 1):
271
+ legend_items.append(
272
+ dict(
273
+ name=f"Forward (VS {vs})",
274
+ color=get_color("forward", vs * num_devices, num_devices),
275
+ )
276
+ )
277
+ legend_items.append(
278
+ dict(
279
+ name=f"Backward (VS {vs})",
280
+ color=get_color("backward", vs * num_devices, num_devices),
281
+ )
282
+ )
283
  # Add entries for split backward operations if this is a zb1p schedule
284
+ if any(
285
+ task["type"] in ["backward_D", "backward_W"]
286
+ for device in schedule_data
287
+ for task in schedule_data[device]
288
+ ):
289
+ legend_items.append(
290
+ dict(
291
+ name=f"Backward Grad (VS {vs})",
292
+ color=get_color("backward_D", vs * num_devices, num_devices),
293
+ )
294
+ )
295
+ legend_items.append(
296
+ dict(
297
+ name=f"Backward Weight (VS {vs})",
298
+ color=get_color("backward_W", vs * num_devices, num_devices),
299
+ )
300
+ )
301
+
302
  # If no tasks found, add default legend items
303
  if not legend_items:
304
  legend_items = [
305
  dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
306
  dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
307
+ dict(
308
+ name="Backward Grad (VS 0)",
309
+ color=get_color("backward_D", 0, num_devices),
310
+ ),
311
+ dict(
312
+ name="Backward Weight (VS 0)",
313
+ color=get_color("backward_W", 0, num_devices),
314
+ ),
315
  ]
316
+
317
  for i, item in enumerate(legend_items):
318
+ fig.add_trace(
319
+ go.Scatter(
320
+ x=[None],
321
+ y=[None],
322
+ mode="markers",
323
+ marker=dict(size=10, color=item["color"]),
324
+ name=item["name"],
325
+ showlegend=True,
326
+ )
327
+ )
328
  if show_progress and i < len(legend_items) - 1:
329
  progress_bar.update(1)
330
 
 
333
  # Modify the ordering to put Device 1 at the top, then Device 0, then the rest
334
  if num_devices >= 2:
335
  # Move Device 1 to the top, followed by Device 0
336
+ device_labels = (
337
+ [device_labels[1], device_labels[0]] + device_labels[2:]
338
+ if num_devices > 1
339
+ else device_labels
340
+ )
341
+
342
  # Calculate tick positions with no gaps
343
  tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
344
+
345
  # Adjust the range to ensure there are no empty spaces at the end
346
  x_end = max_time * 1.05 # Add a small margin
347
 
 
361
  text=title_text,
362
  x=0.5,
363
  y=0.98, # Move title position closer to the top
364
+ font=dict(size=20),
365
  ),
366
  legend=dict(
367
  orientation="v", # Changed from horizontal to vertical
368
  yanchor="top",
369
  y=1.02, # Position at the top
370
  xanchor="right",
371
+ x=1.20, # Position further to the right to accommodate more items
372
  title=dict(text="<b>Operation Types:</b>"),
373
  itemsizing="constant",
374
+ tracegroupgap=0,
375
  ),
376
  width=2000, # Increase width to accommodate the expanded legend
377
  height=400, # Maintain current height
 
389
  # Cache for storing processed schedule data
390
  _schedule_data_cache = {}
391
 
392
+
393
+ def create_dash_app(
394
+ schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True
395
+ ):
396
  """
397
  Create a Dash app to visualize the pipeline schedule.
398
+
399
  Args:
400
  schedule: Schedule object to visualize
401
  schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
 
404
  # Process schedule data only once and cache it
405
  global _schedule_data_cache
406
  cache_key = id(schedule)
407
+
408
  if enable_caching and cache_key in _schedule_data_cache:
409
  schedule_data = _schedule_data_cache[cache_key]
410
  print("Using cached schedule data")
 
413
  if enable_caching:
414
  _schedule_data_cache[cache_key] = schedule_data
415
  print("Cached schedule data")
416
+
417
  total_tasks = sum(len(tasks) for tasks in schedule_data.values())
418
  print(f"Total tasks in schedule: {total_tasks}")
419
 
 
421
  app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
422
 
423
  # Create a more informative layout with data size information
424
+ app.layout = html.Div(
425
+ [
426
+ html.H1(
427
+ f"Pipeline Parallelism Visualization - {schedule_type}",
428
+ style={"textAlign": "center"},
429
+ ),
430
+ html.Div(
431
+ [
432
+ html.P(
433
+ f"Number of devices: {len(schedule_data)}",
434
+ style={"display": "inline-block", "marginRight": "20px"},
435
+ ),
436
+ html.P(
437
+ f"Total tasks: {total_tasks}",
438
+ style={"display": "inline-block", "marginRight": "20px"},
439
+ ),
440
+ ],
441
+ style={"marginBottom": "20px"},
442
+ ),
443
+ html.Div(id="graph-container", children=[]),
444
+ dcc.Loading(
445
+ id="loading-graph",
446
+ type="circle",
447
+ children=[
448
+ dcc.Graph(
449
+ id="pipeline-graph",
450
+ config={
451
+ "displayModeBar": True,
452
+ "toImageButtonOptions": {
453
+ "format": "png",
454
+ "filename": "pipeline_visualization",
455
+ },
456
+ },
457
+ ),
458
+ ],
459
+ ),
460
+ ]
461
+ )
462
+
463
  # Cache for storing figure to avoid regenerating it
464
  figure_cache = {}
465
+
466
  @app.callback(
467
  Output("pipeline-graph", "figure"),
468
  Input("graph-container", "children"),
 
474
  if enable_caching and cache_key in figure_cache:
475
  print("Using cached figure")
476
  return figure_cache[cache_key]
477
+
478
  # Create the figure
479
  figure = create_pipeline_figure(schedule_data, show_progress=True)
480
+
481
  # Cache the figure
482
  if enable_caching:
483
  figure_cache[cache_key] = figure
484
  print("Cached figure")
485
+
486
  return figure
487
 
488
  return app
 
493
  port: int = 8050,
494
  debug: bool = False,
495
  enable_caching: bool = True,
496
+ schedule_type="1f1b",
497
  ):
498
  """
499
  Launch a Dash app to visualize the pipeline schedule interactively.
500
+
501
  Args:
502
  schedule: Schedule object to visualize
503
  port: Port to run the Dash app on
 
505
  enable_caching: Whether to cache schedule data and figures
506
  schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
507
  """
508
+ app = create_dash_app(
509
+ schedule, schedule_type=schedule_type, enable_caching=enable_caching
510
+ )
511
  print(f"Starting Dash app on http://localhost:{port}/")
512
  app.run_server(debug=debug, port=port)