Victarry committed on
Commit
99a6901
·
1 Parent(s): a9586a0

Add support for DualPipe.

Browse files
Files changed (8) hide show
  1. .gitignore +1 -0
  2. README.md +21 -6
  3. assets/dualpipe.png +3 -0
  4. conf/config.yaml +3 -0
  5. main.py +23 -0
  6. src/execution_model.py +81 -19
  7. src/strategies.py +227 -2
  8. src/visualizer.py +2 -12
.gitignore CHANGED
@@ -3,6 +3,7 @@
3
  uv.lock
4
  outputs/
5
  .cursor/*
 
6
 
7
  # Uncomment below if you want to include these files
8
  # !assets/*.png
 
3
  uv.lock
4
  outputs/
5
  .cursor/*
6
+ *.json
7
 
8
  # Uncomment below if you want to include these files
9
  # !assets/*.png
README.md CHANGED
@@ -18,6 +18,7 @@ Pipeline parallelism is a technique used to train large models by partitioning t
18
  - Zero-Bubble 1F1B (ZB-1P)
19
  - 1F1B with computation-communication overlap
20
  - Interleaved 1F1B with computation-communication overlap
 
21
 
22
  - **Visualization**:
23
  - Interactive visualization dashboard using Plotly/Dash
@@ -56,6 +57,12 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
56
  ```
57
  ![zb1p](assets/zb1p.png)
58
 
 
 
 
 
 
 
59
  ### Running for 1F1B-batch-overlap strategy:
60
  ```bash
61
  uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
@@ -68,10 +75,24 @@ uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=
68
  ```
69
  ![1f1b_interleave_overlap](assets/1f1b_interleave_overlap.png)
70
 
 
71
  ## Configuration
72
 
73
  The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  ### Using Different Configuration Files
76
 
77
  You can use different configuration files with Hydra in several ways:
@@ -90,12 +111,6 @@ You can use different configuration files with Hydra in several ways:
90
  uv run python main.py --config-name=model_A
91
  ```
92
 
93
- #### Override Specific Parameters
94
-
95
- You can also override specific parameters at runtime:
96
- ```bash
97
- uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
98
- ```
99
 
100
  ## Project Structure
101
 
 
18
  - Zero-Bubble 1F1B (ZB-1P)
19
  - 1F1B with computation-communication overlap
20
  - Interleaved 1F1B with computation-communication overlap
21
+ - DualPipe (Bidirectional pipeline parallelism with full forward-backward overlap)
22
 
23
  - **Visualization**:
24
  - Interactive visualization dashboard using Plotly/Dash
 
57
  ```
58
  ![zb1p](assets/zb1p.png)
59
 
60
+ ### Running for DualPipe strategy:
61
+ ```bash
62
+ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=20
63
+ ```
64
+ ![dualpipe](assets/dualpipe.png)
65
+
66
  ### Running for 1F1B-batch-overlap strategy:
67
  ```bash
68
  uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
 
75
  ```
76
  ![1f1b_interleave_overlap](assets/1f1b_interleave_overlap.png)
77
 
78
+
79
  ## Configuration
80
 
81
  The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
82
 
83
+ #### Override Specific Parameters
84
+
85
+ You can override specific parameters at runtime:
86
+ ```bash
87
+ uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
88
+ ```
89
+
90
+ Using DualPipe as an example, you can manually set different times for forward/backward/backward_D/backward_W/overlapped_forward_backward:
91
+ ```bash
92
+ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=32 op_times.forward=1.0 op_times.backward=2.0 op_times.backward_D=1.0 op_times.backward_W=1.0 op_times.overlapped_forward_backward=2.5
93
+ ```
94
+
95
+
96
  ### Using Different Configuration Files
97
 
98
  You can use different configuration files with Hydra in several ways:
 
111
  uv run python main.py --config-name=model_A
112
  ```
113
 
 
 
 
 
 
 
114
 
115
  ## Project Structure
116
 
assets/dualpipe.png ADDED

Git LFS Details

  • SHA256: 880f2d4aeed62479216a9e8bc480b22ed2d21b469f661d42c9bc67b1ca6bec2f
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
conf/config.yaml CHANGED
@@ -11,6 +11,9 @@ op_times:
11
  # Option 1: Simple configuration (same time for all stages)
12
  forward: 1.0
13
  backward: 2.0
 
 
 
14
 
15
  # Option 2: Commented example of stage-specific configuration
16
  # forward:
 
11
  # Option 1: Simple configuration (same time for all stages)
12
  forward: 1.0
13
  backward: 2.0
14
+ backward_D: 1.0
15
+ backward_W: 1.0
16
+ overlapped_forward_backward: 2.0
17
 
18
  # Option 2: Commented example of stage-specific configuration
19
  # forward:
main.py CHANGED
@@ -5,6 +5,7 @@ from src.strategies import (
5
  generate_1f1b_overlap_schedule,
6
  generate_1f1b_schedule,
7
  generate_zero_bubble_1p_schedule,
 
8
  )
9
  from src.visualizer import visualize_pipeline_parallelism_dash
10
  import hydra
@@ -26,6 +27,8 @@ def main(cfg: DictConfig) -> None:
26
  run_1f1b_overlap(cfg)
27
  elif cfg.strategy == "1f1b_interleave_overlap":
28
  run_1f1b_interleave_overlap(cfg)
 
 
29
  else:
30
  raise ValueError(f"Unknown strategy: {cfg.strategy}")
31
 
@@ -129,5 +132,25 @@ def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
129
  schedule.execute()
130
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if __name__ == "__main__":
133
  main()
 
5
  generate_1f1b_overlap_schedule,
6
  generate_1f1b_schedule,
7
  generate_zero_bubble_1p_schedule,
8
+ generate_dualpipe_schedule,
9
  )
10
  from src.visualizer import visualize_pipeline_parallelism_dash
11
  import hydra
 
27
  run_1f1b_overlap(cfg)
28
  elif cfg.strategy == "1f1b_interleave_overlap":
29
  run_1f1b_interleave_overlap(cfg)
30
+ elif cfg.strategy == "dualpipe":
31
+ run_dualpipe(cfg)
32
  else:
33
  raise ValueError(f"Unknown strategy: {cfg.strategy}")
34
 
 
132
  schedule.execute()
133
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
134
 
135
+ def run_dualpipe(cfg: DictConfig) -> None:
136
+ """Run DualPipe pipeline parallelism simulation."""
137
+ # Convert OmegaConf to dict for op_times if it exists
138
+ op_times = (
139
+ OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
140
+ )
141
+
142
+ schedule_config = ScheduleConfig(
143
+ num_devices=cfg.num_devices,
144
+ num_stages=cfg.num_stages,
145
+ num_batches=cfg.num_batches,
146
+ p2p_latency=cfg.p2p_latency,
147
+ op_times=op_times,
148
+ split_backward=True,
149
+ placement_strategy="dualpipe",
150
+ )
151
+ schedule = generate_dualpipe_schedule(schedule_config)
152
+ schedule.execute()
153
+ visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
154
+
155
  if __name__ == "__main__":
156
  main()
src/execution_model.py CHANGED
@@ -69,7 +69,7 @@ class DeviceQueue:
69
  def add_operation(self, op: Operation):
70
  assert op.stage_id in self.stages
71
  self.ops.append(op)
72
- assert op.device_id is None
73
  op.device_id = self.device_id
74
 
75
 
@@ -97,6 +97,7 @@ class ScheduleConfig:
97
  "forward": 1.0,
98
  "backward_D": 1.0,
99
  "backward_W": 1.0,
 
100
  }
101
  else:
102
  self.op_times = {
@@ -128,9 +129,14 @@ class ScheduleConfig:
128
  self.num_stages_per_device = num_stages // num_devices
129
 
130
  self.init_device_to_stages()
131
- assert (
132
- sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
133
- )
 
 
 
 
 
134
 
135
  def init_device_to_stages(self):
136
  if self.placement_strategy == "standard":
@@ -145,14 +151,27 @@ class ScheduleConfig:
145
  for i in range(self.num_stages):
146
  device_to_put = i % self.num_devices
147
  self.device_to_stages[device_to_put].append(i)
 
 
 
 
 
 
 
148
  else:
149
  raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
150
 
151
  def get_op_time(self, op_type: str, stage_id: int):
152
  # For overlapped operations, extract the original operation types
153
  if op_type.startswith("overlapped_"):
154
- if op_type in self.op_times and self.op_times[op_type][stage_id]:
155
- return self.op_times[op_type][stage_id]
 
 
 
 
 
 
156
  else:
157
  op_parts = op_type.split("_")[1:]
158
  if len(op_parts) >= 2:
@@ -173,20 +192,25 @@ class ScheduleConfig:
173
 
174
 
175
  class Schedule:
176
- def __init__(self, config: ScheduleConfig):
177
  self.ops = {} # (batch_id, stage_id, op_type) -> Operation
178
  self.device_queues: List[DeviceQueue] = []
179
  for dev_id in range(config.num_devices):
180
  self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
181
  self.config = config
182
 
183
- self.init_operations()
 
184
  self.op_to_overlapped = {}
185
 
186
  def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
187
  for op in overlapped_op.operations:
188
  self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
189
  self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
 
 
 
 
190
 
191
  def init_operations(self):
192
  op_types = ["forward", "backward"]
@@ -199,9 +223,12 @@ class Schedule:
199
  batch_id, stage_id, op_type
200
  )
201
 
202
- def get_op(self, batch_id: int, stage_id: int, op_type: str):
203
  if (batch_id, stage_id, op_type) in self.op_to_overlapped:
204
  return self.op_to_overlapped[(batch_id, stage_id, op_type)]
 
 
 
205
  return self.ops[(batch_id, stage_id, op_type)]
206
 
207
  def get_dependencies(self, op: Operation, include_device_dependency=True):
@@ -226,20 +253,55 @@ class Schedule:
226
  if self.config.split_backward:
227
  if op.op_type == "backward_D":
228
  if op.stage_id < self.config.num_stages - 1:
229
- deps.append(
230
- (
231
- self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
232
- self.config.p2p_latency,
 
 
 
 
 
 
 
 
 
 
233
  )
234
- )
235
  elif op.op_type == "backward_W":
236
  if op.stage_id < self.config.num_stages - 1:
237
- deps.append(
238
- (
239
- self.get_op(op.batch_id, op.stage_id, "backward_D"),
240
- self.config.p2p_latency,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  )
242
- )
243
  else:
244
  if op.op_type == "backward":
245
  if op.stage_id < self.config.num_stages - 1:
 
69
  def add_operation(self, op: Operation):
70
  assert op.stage_id in self.stages
71
  self.ops.append(op)
72
+ assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
73
  op.device_id = self.device_id
74
 
75
 
 
97
  "forward": 1.0,
98
  "backward_D": 1.0,
99
  "backward_W": 1.0,
100
+ "backward": 2.0,
101
  }
102
  else:
103
  self.op_times = {
 
129
  self.num_stages_per_device = num_stages // num_devices
130
 
131
  self.init_device_to_stages()
132
+ if self.placement_strategy == "dualpipe":
133
+ assert (
134
+ sum(len(stages) for stages in self.device_to_stages.values()) == num_stages * 2
135
+ )
136
+ else:
137
+ assert (
138
+ sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
139
+ )
140
 
141
  def init_device_to_stages(self):
142
  if self.placement_strategy == "standard":
 
151
  for i in range(self.num_stages):
152
  device_to_put = i % self.num_devices
153
  self.device_to_stages[device_to_put].append(i)
154
+ elif self.placement_strategy == "dualpipe":
155
+ # For DualPipe, each device has two stages
156
+ assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
157
+ assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
158
+ self.device_to_stages = defaultdict(list)
159
+ for i in range(self.num_stages):
160
+ self.device_to_stages[i] = [i, self.num_stages - i - 1]
161
  else:
162
  raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
163
 
164
  def get_op_time(self, op_type: str, stage_id: int):
165
  # For overlapped operations, extract the original operation types
166
  if op_type.startswith("overlapped_"):
167
+ if op_type in self.op_times:
168
+ if isinstance(self.op_times[op_type], dict):
169
+ if stage_id in self.op_times[op_type]:
170
+ return self.op_times[op_type][stage_id]
171
+ else:
172
+ raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
173
+ else:
174
+ return self.op_times[op_type]
175
  else:
176
  op_parts = op_type.split("_")[1:]
177
  if len(op_parts) >= 2:
 
192
 
193
 
194
  class Schedule:
195
+ def __init__(self, config: ScheduleConfig, init_ops: bool = True):
196
  self.ops = {} # (batch_id, stage_id, op_type) -> Operation
197
  self.device_queues: List[DeviceQueue] = []
198
  for dev_id in range(config.num_devices):
199
  self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
200
  self.config = config
201
 
202
+ if init_ops:
203
+ self.init_operations()
204
  self.op_to_overlapped = {}
205
 
206
  def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
207
  for op in overlapped_op.operations:
208
  self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
209
  self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
210
+
211
+ def register_operation(self, op: Operation):
212
+ assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
213
+ self.ops[(op.batch_id, op.stage_id, op.op_type)] = op
214
 
215
  def init_operations(self):
216
  op_types = ["forward", "backward"]
 
223
  batch_id, stage_id, op_type
224
  )
225
 
226
+ def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
227
  if (batch_id, stage_id, op_type) in self.op_to_overlapped:
228
  return self.op_to_overlapped[(batch_id, stage_id, op_type)]
229
+ if allow_none:
230
+ if (batch_id, stage_id, op_type) not in self.ops:
231
+ return None
232
  return self.ops[(batch_id, stage_id, op_type)]
233
 
234
  def get_dependencies(self, op: Operation, include_device_dependency=True):
 
253
  if self.config.split_backward:
254
  if op.op_type == "backward_D":
255
  if op.stage_id < self.config.num_stages - 1:
256
+ op_bwd_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
257
+ if op_bwd_d is not None:
258
+ deps.append(
259
+ (
260
+ op_bwd_d,
261
+ self.config.p2p_latency,
262
+ )
263
+ )
264
+ else:
265
+ deps.append(
266
+ (
267
+ self.get_op(op.batch_id, op.stage_id + 1, "backward"),
268
+ self.config.p2p_latency,
269
+ )
270
  )
 
271
  elif op.op_type == "backward_W":
272
  if op.stage_id < self.config.num_stages - 1:
273
+ op_bwd_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
274
+ if op_bwd_d is not None:
275
+ deps.append(
276
+ (
277
+ op_bwd_d,
278
+ self.config.p2p_latency,
279
+ )
280
+ )
281
+ else:
282
+ deps.append(
283
+ (
284
+ self.get_op(op.batch_id, op.stage_id, "backward"),
285
+ self.config.p2p_latency,
286
+ )
287
+ )
288
+ elif op.op_type == "backward":
289
+ if op.stage_id < self.config.num_stages - 1:
290
+ op_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
291
+ if op_bwd is not None:
292
+ deps.append(
293
+ (
294
+ op_bwd,
295
+ self.config.p2p_latency,
296
+ )
297
+ )
298
+ else:
299
+ deps.append(
300
+ (
301
+ self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
302
+ self.config.p2p_latency,
303
+ )
304
  )
 
305
  else:
306
  if op.op_type == "backward":
307
  if op.stage_id < self.config.num_stages - 1:
src/strategies.py CHANGED
@@ -1,5 +1,5 @@
1
- from collections import defaultdict
2
- from src.execution_model import OverlappedOperation, Schedule, ScheduleConfig
3
 
4
 
5
  def generate_1f1b_schedule(config: ScheduleConfig):
@@ -43,6 +43,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
43
  schedule = Schedule(config)
44
  total_batches = config.num_batches
45
  assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
 
46
 
47
  for i in range(config.num_devices):
48
  fwd_batch_id = 0
@@ -354,3 +355,227 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
354
 
355
 
356
  return schedule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict, deque
2
+ from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
3
 
4
 
5
  def generate_1f1b_schedule(config: ScheduleConfig):
 
43
  schedule = Schedule(config)
44
  total_batches = config.num_batches
45
  assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
46
+ assert config.split_backward, "ZB-1P requires split_backward=True"
47
 
48
  for i in range(config.num_devices):
49
  fwd_batch_id = 0
 
355
 
356
 
357
  return schedule
358
+
359
+
360
+ def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
361
+ """
362
+ Helper function to create overlapped operations correctly.
363
+ This handles the underlying operation creation and registration to avoid device_id issues.
364
+ """
365
+ # Get the operations from the schedule
366
+ op1 = schedule.ops[(batch_id1, stage_id, type1)]
367
+ op2 = schedule.ops[(batch_id2, stage_id, type2)]
368
+
369
+ # Create the overlapped operation
370
+ overlapped_op = OverlappedOperation([op1, op2])
371
+
372
+ # Register in the schedule to ensure proper tracking
373
+ schedule.register_overlapped_operation(overlapped_op)
374
+
375
+ return overlapped_op
376
+
377
+
378
+ def generate_dualpipe_schedule(config: ScheduleConfig):
379
+ """
380
+ Implements the DualPipe scheduling strategy.
381
+
382
+ DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
383
+ and backward computation-communication phases and reduces pipeline bubbles.
384
+
385
+ The DualPipe strategy has the following characteristics:
386
+ 1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
387
+ 2. Each device handles both a forward stage and a reverse stage
388
+ 3. Overlaps forward and backward operations to reduce bubble size
389
+ 4. Assumes config.num_batches corresponds to half the total microbatches in the original paper (M).
390
+ 5. Currently only supports split_backward=True.
391
+
392
+ Args:
393
+ config: The scheduling configuration
394
+
395
+ Returns:
396
+ A Schedule object with the DualPipe scheduling
397
+ """
398
+ # Ensure placement strategy is set for Schedule initialization
399
+ assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
400
+ # Assertions based on DualPipe requirements
401
+ assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
402
+ assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
403
+ assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
404
+ # Assertion based on original implementation: num_chunks >= num_ranks * 2
405
+ # Here, M (config.num_batches) corresponds to half_num_chunks
406
+ assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
407
+ assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
408
+
409
+ schedule = Schedule(config, init_ops=False)
410
+
411
+ num_stages = config.num_stages
412
+ num_devices = config.num_devices
413
+ # config.num_batches is M in the original paper, which corresponds to half_num_chunks
414
+ half_num_chunks = config.num_batches // 2
415
+ num_half_ranks = num_devices // 2
416
+
417
+ fwd_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
418
+ bwd_d_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
419
+
420
+ waited_weight_grad = [deque() for _ in range(num_devices)] # (device_id, ) -> List[(stage_id, batch_id)]
421
+
422
+ for device_id in range(num_devices):
423
+ is_in_second_half = device_id >= num_half_ranks
424
+ if is_in_second_half:
425
+ fwd_batch_ids[device_id, 1] = 0
426
+ fwd_batch_ids[device_id, 0] = config.num_batches // 2
427
+ bwd_d_batch_ids[device_id, 1] = 0
428
+ bwd_d_batch_ids[device_id, 0] = config.num_batches // 2
429
+ else:
430
+ fwd_batch_ids[device_id, 0] = 0
431
+ fwd_batch_ids[device_id, 1] = config.num_batches // 2
432
+ bwd_d_batch_ids[device_id, 0] = 0
433
+ bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
434
+ def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
435
+ stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
436
+ stage_rev_dir = num_stages - 1 - device_id # Stage handled when moving backward (N-1 to 0)
437
+ if not is_in_second_half:
438
+ # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
439
+ return stage_fwd_dir if phase == 0 else stage_rev_dir
440
+ else:
441
+ # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
442
+ return stage_rev_dir if phase == 0 else stage_fwd_dir
443
+
444
+
445
+ def add_op_to_queue(device_id, stage_id, op_type, batch_id):
446
+ # Retrieve the correct pre-initialized Operation object
447
+ op = Operation(batch_id, stage_id, op_type)
448
+ schedule.register_operation(op)
449
+ # Add to the device queue
450
+ schedule.device_queues[device_id].add_operation(op)
451
+
452
+ def _schedule_forward_chunk(device_id, phase, is_in_second_half):
453
+ """Schedules a forward compute operation."""
454
+ stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
455
+ batch_id = fwd_batch_ids[device_id, phase]
456
+ add_op_to_queue(device_id, stage_id, "forward", batch_id)
457
+ fwd_batch_ids[device_id, phase] += 1
458
+
459
+ def _schedule_backward_chunk(device_id, phase, is_in_second_half):
460
+ """Schedules a backward_D with backward_W compute operation."""
461
+ stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
462
+ batch_id = bwd_d_batch_ids[device_id, phase]
463
+ add_op_to_queue(device_id, stage_id, "backward", batch_id)
464
+ bwd_d_batch_ids[device_id, phase] += 1
465
+
466
+ def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
467
+ """Schedules a backward_D compute operation."""
468
+ stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
469
+ batch_id = bwd_d_batch_ids[device_id, phase]
470
+ add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
471
+ bwd_d_batch_ids[device_id, phase] += 1
472
+ waited_weight_grad[device_id].append((stage_id, batch_id))
473
+
474
+ def _schedule_backward_weight_chunk(device_id):
475
+ """Schedules a backward_W compute operation."""
476
+ stage_id, batch_id = waited_weight_grad[device_id].popleft()
477
+ add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
478
+
479
+ def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
480
+ """Schedules an overlapped forward and backward_D compute operation."""
481
+ fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
482
+ bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
483
+
484
+ fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
485
+
486
+ fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
487
+ schedule.register_operation(fwd_op)
488
+ fwd_batch_ids[device_id, fwd_phase] += 1
489
+
490
+ bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
491
+ bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
492
+ schedule.register_operation(bwd_op)
493
+ bwd_d_batch_ids[device_id, bwd_phase] += 1
494
+
495
+ # Create and register the overlapped operation
496
+ overlapped_op = OverlappedOperation([fwd_op, bwd_op])
497
+ schedule.register_overlapped_operation(overlapped_op)
498
+
499
+ # Add the overlapped operation to the queue
500
+ schedule.device_queues[device_id].add_operation(overlapped_op)
501
+
502
+
503
+ # Process each device (rank in original code)
504
+ for device_id in range(num_devices):
505
+ half_rank = min(device_id, num_devices - 1 - device_id)
506
+ is_in_second_half = device_id >= num_half_ranks
507
+ is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
508
+
509
+ # Map original steps to operation additions
510
+ # Step 1: nF0
511
+ step_1_count = (num_half_ranks - half_rank - 1) * 2
512
+ for _ in range(step_1_count):
513
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
514
+
515
+ # Step 2: nF0F1
516
+ step_2_count = half_rank + 1
517
+ for i in range(step_2_count):
518
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
519
+ _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
520
+
521
+ # Step 3: nB1W1F1
522
+ step_3_count = num_half_ranks - half_rank - 1
523
+ for _ in range(step_3_count):
524
+ _schedule_backward_input_chunk(device_id, 1, is_in_second_half) # B1_D
525
+ _schedule_backward_weight_chunk(device_id,) # W1
526
+ _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
527
+
528
+ # Step 4 (Main step): nF0B1F1B0
529
+ step_4_count = half_num_chunks - num_devices + half_rank + 1
530
+ for i in range(step_4_count):
531
+ # if i == 0 and is_middle_rank:
532
+ # Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
533
+ # _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
534
+ # _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
535
+ # _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
536
+ # else:
537
+ # Overlap F0 and B1_D, then schedule W1
538
+ _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half) # F0+B1
539
+
540
+ # Overlap F1 and B0_D, then schedule W0
541
+ _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
542
+
543
+ # Step 5: nB1F1B0
544
+ step_5_count = num_half_ranks - half_rank - 1
545
+ for _ in range(step_5_count):
546
+ _schedule_backward_chunk(device_id, 1, is_in_second_half) # B1_D + B1_W
547
+ _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
548
+
549
+ # Step 6: nB1B0
550
+ step_6_count = half_rank + 1
551
+ enable_zb = False
552
+ for i in range(step_6_count):
553
+ if i == step_6_count // 2 and half_rank % 2 == 1:
554
+ enable_zb = True
555
+ if enable_zb:
556
+ _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
557
+ else:
558
+ _schedule_backward_chunk(device_id, 1, is_in_second_half)
559
+ if i == step_6_count // 2 and half_rank % 2 == 0:
560
+ enable_zb = True
561
+ if enable_zb:
562
+ _schedule_backward_input_chunk(device_id, 0, is_in_second_half)
563
+ else:
564
+ _schedule_backward_chunk(device_id, 0, is_in_second_half)
565
+
566
+ # Step 7: nWB0
567
+ step_7_count = num_half_ranks - half_rank - 1
568
+ for _ in range(step_7_count):
569
+ _schedule_backward_weight_chunk(device_id) # W1 (use gradient from B1_D scheduled previously)
570
+ _schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
571
+
572
+ # Step 8: nW
573
+ step_8_count = half_rank + 1
574
+ for _ in range(step_8_count):
575
+ # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
576
+ # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
577
+ # The last W0 gradients correspond to B0_D from step 6 or 7.
578
+ _schedule_backward_weight_chunk(device_id) # W0 (use gradient from B0_D scheduled previously)
579
+
580
+ return schedule
581
+
src/visualizer.py CHANGED
@@ -89,11 +89,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
89
 
90
  # Improved teal/turquoise palette with low saturation and high brightness
91
  backward_d_colors = [
92
- "#ccffff", # Very light cyan
93
- "#b3ffff", # Pale cyan
94
- "#99ffff", # Light cyan
95
- "#80ffff", # Cyan
96
- "#66e6e6", # Soft teal
97
  "#4dcccc", # Light teal
98
  "#33b3b3", # Teal
99
  "#009999", # Medium teal
@@ -102,12 +97,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
102
 
103
  # Improved green palette with low saturation and high brightness
104
  backward_w_colors = [
105
- "#ccffe6", # Very light mint
106
- "#b3ffd9", # Pale mint
107
- "#99ffcc", # Light mint
108
- "#80ffbf", # Mint green
109
- "#66e6a6", # Soft green
110
- "#4dcc8c", # Light green
111
  "#33b373", # Medium green
112
  "#009959", # Forest green
113
  "#008040", # Dark green
@@ -162,7 +151,8 @@ def create_pipeline_figure(
162
  max_batch = max(max_batch, task["batch"])
163
 
164
  # Flag to determine whether to show text labels
165
- show_text_labels = max_batch <= 16
 
166
 
167
  # Create a figure
168
  fig = go.Figure()
 
89
 
90
  # Improved teal/turquoise palette with low saturation and high brightness
91
  backward_d_colors = [
 
 
 
 
 
92
  "#4dcccc", # Light teal
93
  "#33b3b3", # Teal
94
  "#009999", # Medium teal
 
97
 
98
  # Improved green palette with low saturation and high brightness
99
  backward_w_colors = [
 
 
 
 
 
 
100
  "#33b373", # Medium green
101
  "#009959", # Forest green
102
  "#008040", # Dark green
 
151
  max_batch = max(max_batch, task["batch"])
152
 
153
  # Flag to determine whether to show text labels
154
+ num_operations_per_device = len(schedule_data[0])
155
+ show_text_labels = num_operations_per_device <= 64
156
 
157
  # Create a figure
158
  fig = go.Figure()