Victarry committed
Commit dc262c1 · 1 Parent(s): 9d44c5d

Add DualPipe-V support.

Files changed (5)
  1. README.md +6 -0
  2. assets/dualpipe_v.png +3 -0
  3. main.py +22 -0
  4. src/execution_model.py +7 -0
  5. src/strategies.py +340 -78
README.md CHANGED
@@ -84,6 +84,12 @@ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=2
```
![dualpipe](assets/dualpipe.png)

+### Running for DualPipe-V strategy
+```bash
+uv run python main.py strategy=dualpipe_v num_devices=4 num_stages=8 num_batches=10
+```
+![dualpipe_v](assets/dualpipe_v.png)
+
### Running for 1F1B-batch-overlap strategy:
```bash
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
assets/dualpipe_v.png ADDED

Git LFS Details

  • SHA256: 4bf2f73cfadb14e403166ecf53f6629fd509c32a6096fb53a0f774d79937c65f
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
main.py CHANGED
@@ -4,6 +4,7 @@ from src.strategies import (
    generate_1f1b_interleave_schedule,
    generate_1f1b_overlap_schedule,
    generate_1f1b_schedule,
+    generate_dualpipe_v_schedule,
    generate_zero_bubble_1p_schedule,
    generate_dualpipe_schedule,
)
@@ -29,6 +30,8 @@ def main(cfg: DictConfig) -> None:
        run_1f1b_interleave_overlap(cfg)
    elif cfg.strategy == "dualpipe":
        run_dualpipe(cfg)
+    elif cfg.strategy == "dualpipe_v":
+        run_dualpipe_v(cfg)
    else:
        raise ValueError(f"Unknown strategy: {cfg.strategy}")

@@ -152,5 +155,24 @@ def run_dualpipe(cfg: DictConfig) -> None:
    schedule.execute()
    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)

+def run_dualpipe_v(cfg: DictConfig) -> None:
+    """Run DualPipeV pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=True,
+        placement_strategy="dualpipe_v",
+    )
+    schedule = generate_dualpipe_v_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
if __name__ == "__main__":
    main()
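For quick experiments it can help to see the same wiring without Hydra. A minimal sketch, assuming ScheduleConfig accepts exactly the keyword arguments used in run_dualpipe_v above; the values mirror the README command, and p2p_latency=0.0 / op_times=None are illustrative assumptions rather than values taken from the repository's config files:

```python
from src.execution_model import ScheduleConfig
from src.strategies import generate_dualpipe_v_schedule

# Assumed defaults: p2p_latency=0.0 and op_times=None are illustrative only.
schedule_config = ScheduleConfig(
    num_devices=4,
    num_stages=8,
    num_batches=10,
    p2p_latency=0.0,
    op_times=None,
    split_backward=True,
    placement_strategy="dualpipe_v",
)
schedule = generate_dualpipe_v_schedule(schedule_config)
schedule.execute()
```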
src/execution_model.py CHANGED
@@ -158,6 +158,13 @@ class ScheduleConfig:
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i] = [i, self.num_stages - i - 1]
+        elif self.placement_strategy == "dualpipe_v":
+            assert self.num_devices % 2 == 0, "DualPipe-V requires an even number of devices"
+            assert self.num_stages == self.num_devices * 2, "DualPipe-V requires num_stages == num_devices * 2"
+            assert self.split_backward, "DualPipe-V requires split_backward=True"
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_devices):
+                self.device_to_stages[i] = [i, self.num_stages - i - 1]
        else:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")

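The placement added above pairs each device with one stage from each end of the pipeline. A small standalone sketch of the resulting mapping for the README example (num_devices=4, num_stages=8), using the same loop as the new dualpipe_v branch:

```python
from collections import defaultdict

num_devices, num_stages = 4, 8  # values from the README example
device_to_stages = defaultdict(list)
for i in range(num_devices):
    # Each device owns stage i (forward direction) and its mirror stage
    device_to_stages[i] = [i, num_stages - i - 1]

print(dict(device_to_stages))
# {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
```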
 
src/strategies.py CHANGED
@@ -5,7 +5,9 @@ from src.execution_model import OverlappedOperation, Operation, Schedule, Schedu
5
  def generate_1f1b_schedule(config: ScheduleConfig):
6
  schedule = Schedule(config)
7
 
8
- assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
 
 
9
 
10
  for i in range(config.num_devices):
11
  fwd_batch_id = 0
@@ -42,7 +44,9 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
42
  # Create a new schedule with split_backward=True to support backward_D and backward_W operations
43
  schedule = Schedule(config)
44
  total_batches = config.num_batches
45
- assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
 
 
46
  assert config.split_backward, "ZB-1P requires split_backward=True"
47
 
48
  for i in range(config.num_devices):
@@ -73,7 +77,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
73
  bwd_w_batch_id += 1
74
  bwd_d_batch_id += 1
75
  fwd_batch_id += 1
76
-
77
  for _ in range(cooldown_batches):
78
  schedule.device_queues[i].add_operation(
79
  schedule.get_op(bwd_d_batch_id, i, "backward_D")
@@ -85,7 +89,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
85
 
86
  bwd_w_batch_id += 1
87
  bwd_d_batch_id += 1
88
-
89
  while bwd_w_batch_id < total_batches:
90
  schedule.device_queues[i].add_operation(
91
  schedule.get_op(bwd_w_batch_id, i, "backward_W")
@@ -98,7 +102,9 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
98
  def generate_1f1b_overlap_schedule(config: ScheduleConfig):
99
  schedule = Schedule(config)
100
 
101
- assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
 
 
102
 
103
  for i in range(config.num_devices):
104
  fwd_batch_id = 0
@@ -132,11 +138,11 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
132
 
133
 
134
  def _get_pp_rank_microbatches(
135
- num_microbatches,
136
  num_devices,
137
  device_id,
138
- num_stages_per_device,
139
- microbatch_group_size_per_vp_stage,
140
  ):
141
  """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
142
  total_num_microbatches = num_microbatches * num_stages_per_device
@@ -147,7 +153,9 @@ def _get_pp_rank_microbatches(
147
  # stage ID (more forward passes for earlier stages, later stages can
148
  # immediately start with 1F1B).
149
  num_warmup_microbatches = (num_devices - device_id - 1) * 2
150
- num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
 
 
151
  else:
152
  # forward_backward_no_pipelining
153
  num_warmup_microbatches = 1
@@ -158,27 +166,34 @@ def _get_pp_rank_microbatches(
158
  return num_warmup_microbatches
159
 
160
 
161
- def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
 
 
162
  """Get the schedule table for PP scheduling.
163
 
164
  Create a tunable schedule lookup table.
165
- The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
166
  For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
167
  virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
168
  microbatch_id | 0 1 2 0 1 2 3 4 3 4
169
- model_chunk_id | 0 0 0 1 1 1 0 0 1 1
170
  """
171
  schedule_table = []
172
  for min_microbatch_id_in_group in range(
173
  0, num_microbatches, microbatch_group_size_per_vp_stage
174
  ):
175
- if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
 
 
 
176
  # Construct schedule for the last microbatch group
177
  schedule_table.extend(
178
  [
179
  (microbatch_id, model_chunk_id)
180
  for model_chunk_id in range(num_model_chunks)
181
- for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
 
 
182
  ]
183
  )
184
  else:
@@ -196,7 +211,9 @@ def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_siz
196
  return schedule_table
197
 
198
 
199
- def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
 
 
200
  """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
201
  order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
202
  virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
@@ -225,7 +242,7 @@ def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks,
225
  # Some codes are copied from Megatron-LM
226
  def generate_1f1b_interleave_schedule(config: ScheduleConfig):
227
  schedule = Schedule(config)
228
-
229
  for device_id in range(config.num_devices):
230
  microbatch_group_size_per_vp_stage = config.num_devices
231
  num_warmup_microbatches = _get_pp_rank_microbatches(
@@ -244,25 +261,29 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
244
 
245
  order = _convert_schedule_table_to_order(
246
  num_warmup_microbatches,
247
- num_model_chunks=config.num_stages_per_device,
248
  schedule_table=schedule_table,
249
  )
250
 
251
  cur_stage_microbatch_id = {}
252
- for i in range(1, config.num_stages_per_device+1):
253
  cur_stage_microbatch_id[i] = 0
254
  cur_stage_microbatch_id[-i] = 0
255
  for order_item in order:
256
- stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
257
 
258
  if order_item > 0:
259
  op_type = "forward"
260
  micro_batch_id = cur_stage_microbatch_id[order_item]
261
- cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
 
 
262
  elif order_item < 0:
263
  op_type = "backward"
264
  micro_batch_id = cur_stage_microbatch_id[order_item]
265
- cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
 
 
266
  else:
267
  raise ValueError(f"Invalid order item: {order_item}")
268
  schedule.device_queues[device_id].add_operation(
@@ -270,6 +291,7 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
270
  )
271
  return schedule
272
 
 
273
  def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
274
  schedule = Schedule(config)
275
 
@@ -290,15 +312,15 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
290
  )
291
 
292
  # NOTE: Add one more warmup microbatch for overlapped operations!
293
- num_warmup_microbatches += 1
294
  order = _convert_schedule_table_to_order(
295
  num_warmup_microbatches,
296
- num_model_chunks=config.num_stages_per_device,
297
  schedule_table=schedule_table,
298
  )
299
 
300
  cur_stage_microbatch_id = {}
301
- for i in range(1, config.num_stages_per_device+1):
302
  cur_stage_microbatch_id[i] = 0
303
  cur_stage_microbatch_id[-i] = 0
304
  i = 0
@@ -310,27 +332,40 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
310
  assert order_item > 0
311
  op_type = "forward"
312
  micro_batch_id = cur_stage_microbatch_id[order_item]
313
- cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
 
 
314
 
315
- stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
316
  schedule.device_queues[device_id].add_operation(
317
  schedule.get_op(micro_batch_id, stage_id, op_type)
318
  )
319
  i += 1
320
- elif i >= num_warmup_microbatches and i < num_warmup_microbatches + num_overlapped_batches - 1:
 
 
 
321
  order_item_a = order[i]
322
- order_item_b = order[i+1]
323
 
324
  op_type_a = "forward" if order_item_a > 0 else "backward"
325
  micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
326
- cur_stage_microbatch_id[order_item_a] = cur_stage_microbatch_id[order_item_a] + 1
 
 
327
 
328
  op_type_b = "forward" if order_item_b > 0 else "backward"
329
  micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
330
- cur_stage_microbatch_id[order_item_b] = cur_stage_microbatch_id[order_item_b] + 1
 
 
331
 
332
- stage_id_a = schedule.device_queues[device_id].stages[abs(order_item_a)-1]
333
- stage_id_b = schedule.device_queues[device_id].stages[abs(order_item_b)-1]
 
 
 
 
334
 
335
  op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
336
  op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
@@ -345,14 +380,15 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
345
  assert order_item < 0
346
  op_type = "backward"
347
  micro_batch_id = cur_stage_microbatch_id[order_item]
348
- cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
 
 
349
 
350
- stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
351
  schedule.device_queues[device_id].add_operation(
352
  schedule.get_op(micro_batch_id, stage_id, op_type)
353
  )
354
  i += 1
355
-
356
 
357
  return schedule
358
 
@@ -365,23 +401,23 @@ def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2
365
  # Get the operations from the schedule
366
  op1 = schedule.ops[(batch_id1, stage_id, type1)]
367
  op2 = schedule.ops[(batch_id2, stage_id, type2)]
368
-
369
  # Create the overlapped operation
370
  overlapped_op = OverlappedOperation([op1, op2])
371
-
372
  # Register in the schedule to ensure proper tracking
373
  schedule.register_overlapped_operation(overlapped_op)
374
-
375
  return overlapped_op
376
 
377
 
378
  def generate_dualpipe_schedule(config: ScheduleConfig):
379
  """
380
  Implements the DualPipe scheduling strategy.
381
-
382
  DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
383
  and backward computation-communication phases and reduces pipeline bubbles.
384
-
385
  The DualPipe strategy has the following characteristics:
386
  1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
387
  2. Each device handles both a forward stage and a reverse stage
@@ -396,15 +432,27 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
396
  A Schedule object with the DualPipe scheduling
397
  """
398
  # Ensure placement strategy is set for Schedule initialization
399
- assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
 
 
400
  # Assertions based on DualPipe requirements
401
- assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
402
- assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
403
- assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
 
 
 
 
 
 
404
  # Assertion based on original implementation: num_chunks >= num_ranks * 2
405
  # Here, M (config.num_batches) corresponds to half_num_chunks
406
- assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
407
- assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
 
 
 
 
408
 
409
  schedule = Schedule(config, init_ops=False)
410
 
@@ -414,10 +462,12 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
414
  half_num_chunks = config.num_batches // 2
415
  num_half_ranks = num_devices // 2
416
 
417
- fwd_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
418
- bwd_d_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
419
 
420
- waited_weight_grad = [deque() for _ in range(num_devices)] # (device_id, ) -> List[(stage_id, batch_id)]
 
 
421
 
422
  for device_id in range(num_devices):
423
  is_in_second_half = device_id >= num_half_ranks
@@ -431,16 +481,18 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
431
  fwd_batch_ids[device_id, 1] = config.num_batches // 2
432
  bwd_d_batch_ids[device_id, 0] = 0
433
  bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
 
434
  def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
435
- stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
436
- stage_rev_dir = num_stages - 1 - device_id # Stage handled when moving backward (N-1 to 0)
 
 
437
  if not is_in_second_half:
438
  # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
439
  return stage_fwd_dir if phase == 0 else stage_rev_dir
440
  else:
441
  # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
442
  return stage_rev_dir if phase == 0 else stage_fwd_dir
443
-
444
 
445
  def add_op_to_queue(device_id, stage_id, op_type, batch_id):
446
  # Retrieve the correct pre-initialized Operation object
@@ -462,7 +514,7 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
462
  batch_id = bwd_d_batch_ids[device_id, phase]
463
  add_op_to_queue(device_id, stage_id, "backward", batch_id)
464
  bwd_d_batch_ids[device_id, phase] += 1
465
-
466
  def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
467
  """Schedules a backward_D compute operation."""
468
  stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
@@ -476,11 +528,17 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
476
  stage_id, batch_id = waited_weight_grad[device_id].popleft()
477
  add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
478
 
479
- def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
 
 
480
  """Schedules an overlapped forward and backward_D compute operation."""
481
- fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
482
- bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
483
-
 
 
 
 
484
  fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
485
 
486
  fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
@@ -493,58 +551,67 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
493
  bwd_d_batch_ids[device_id, bwd_phase] += 1
494
 
495
  # Create and register the overlapped operation
496
- overlapped_op = OverlappedOperation([fwd_op, bwd_op])
497
  schedule.register_overlapped_operation(overlapped_op)
498
-
499
  # Add the overlapped operation to the queue
500
  schedule.device_queues[device_id].add_operation(overlapped_op)
501
 
502
-
503
  # Process each device (rank in original code)
504
  for device_id in range(num_devices):
505
  half_rank = min(device_id, num_devices - 1 - device_id)
506
  is_in_second_half = device_id >= num_half_ranks
507
- is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
 
 
508
 
509
  # Map original steps to operation additions
510
  # Step 1: nF0
511
  step_1_count = (num_half_ranks - half_rank - 1) * 2
512
  for _ in range(step_1_count):
513
- _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
514
 
515
  # Step 2: nF0F1
516
  step_2_count = half_rank + 1
517
  for i in range(step_2_count):
518
- _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
519
- _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
520
 
521
  # Step 3: nB1W1F1
522
  step_3_count = num_half_ranks - half_rank - 1
523
  for _ in range(step_3_count):
524
- _schedule_backward_input_chunk(device_id, 1, is_in_second_half) # B1_D
525
- _schedule_backward_weight_chunk(device_id,) # W1
 
 
526
  _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
527
 
528
  # Step 4 (Main step): nF0B1F1B0
529
  step_4_count = half_num_chunks - num_devices + half_rank + 1
530
  for i in range(step_4_count):
531
  # if i == 0 and is_middle_rank:
532
- # Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
533
- # _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
534
- # _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
535
- # _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
536
  # else:
537
  # Overlap F0 and B1_D, then schedule W1
538
- _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half) # F0+B1
539
-
 
 
540
  # Overlap F1 and B0_D, then schedule W0
541
- _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
 
 
542
 
543
  # Step 5: nB1F1B0
544
  step_5_count = num_half_ranks - half_rank - 1
545
  for _ in range(step_5_count):
546
- _schedule_backward_chunk(device_id, 1, is_in_second_half) # B1_D + B1_W
547
- _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
 
 
548
 
549
  # Step 6: nB1B0
550
  step_6_count = half_rank + 1
@@ -566,8 +633,10 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
566
  # Step 7: nWB0
567
  step_7_count = num_half_ranks - half_rank - 1
568
  for _ in range(step_7_count):
569
- _schedule_backward_weight_chunk(device_id) # W1 (use gradient from B1_D scheduled previously)
570
- _schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
 
 
571
 
572
  # Step 8: nW
573
  step_8_count = half_rank + 1
@@ -575,7 +644,200 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
575
  # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
576
  # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
577
  # The last W0 gradients correspond to B0_D from step 6 or 7.
578
- _schedule_backward_weight_chunk(device_id) # W0 (use gradient from B0_D scheduled previously)
 
 
579
 
580
  return schedule
581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def generate_1f1b_schedule(config: ScheduleConfig):
6
  schedule = Schedule(config)
7
 
8
+ assert (
9
+ config.num_devices == config.num_stages
10
+ ), "num_devices must be equal to num_stages for 1F1B"
11
 
12
  for i in range(config.num_devices):
13
  fwd_batch_id = 0
 
44
  # Create a new schedule with split_backward=True to support backward_D and backward_W operations
45
  schedule = Schedule(config)
46
  total_batches = config.num_batches
47
+ assert (
48
+ config.num_devices == config.num_stages
49
+ ), "num_devices must be equal to num_stages for ZB-1P"
50
  assert config.split_backward, "ZB-1P requires split_backward=True"
51
 
52
  for i in range(config.num_devices):
 
77
  bwd_w_batch_id += 1
78
  bwd_d_batch_id += 1
79
  fwd_batch_id += 1
80
+
81
  for _ in range(cooldown_batches):
82
  schedule.device_queues[i].add_operation(
83
  schedule.get_op(bwd_d_batch_id, i, "backward_D")
 
89
 
90
  bwd_w_batch_id += 1
91
  bwd_d_batch_id += 1
92
+
93
  while bwd_w_batch_id < total_batches:
94
  schedule.device_queues[i].add_operation(
95
  schedule.get_op(bwd_w_batch_id, i, "backward_W")
 
102
  def generate_1f1b_overlap_schedule(config: ScheduleConfig):
103
  schedule = Schedule(config)
104
 
105
+ assert (
106
+ config.num_devices == config.num_stages
107
+ ), "num_devices must be equal to num_stages for 1F1B"
108
 
109
  for i in range(config.num_devices):
110
  fwd_batch_id = 0
 
138
 
139
 
140
  def _get_pp_rank_microbatches(
141
+ num_microbatches,
142
  num_devices,
143
  device_id,
144
+ num_stages_per_device,
145
+ microbatch_group_size_per_vp_stage,
146
  ):
147
  """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
148
  total_num_microbatches = num_microbatches * num_stages_per_device
 
153
  # stage ID (more forward passes for earlier stages, later stages can
154
  # immediately start with 1F1B).
155
  num_warmup_microbatches = (num_devices - device_id - 1) * 2
156
+ num_warmup_microbatches += (
157
+ num_stages_per_device - 1
158
+ ) * microbatch_group_size_per_vp_stage
159
  else:
160
  # forward_backward_no_pipelining
161
  num_warmup_microbatches = 1
 
166
  return num_warmup_microbatches
167
 
168
 
169
+ def _get_schedule_table(
170
+ num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage
171
+ ):
172
  """Get the schedule table for PP scheduling.
173
 
174
  Create a tunable schedule lookup table.
175
+ The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
176
  For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
177
  virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
178
  microbatch_id | 0 1 2 0 1 2 3 4 3 4
179
+ model_chunk_id | 0 0 0 1 1 1 0 0 1 1
180
  """
181
  schedule_table = []
182
  for min_microbatch_id_in_group in range(
183
  0, num_microbatches, microbatch_group_size_per_vp_stage
184
  ):
185
+ if (
186
+ min_microbatch_id_in_group + microbatch_group_size_per_vp_stage
187
+ >= num_microbatches
188
+ ):
189
  # Construct schedule for the last microbatch group
190
  schedule_table.extend(
191
  [
192
  (microbatch_id, model_chunk_id)
193
  for model_chunk_id in range(num_model_chunks)
194
+ for microbatch_id in range(
195
+ min_microbatch_id_in_group, num_microbatches
196
+ )
197
  ]
198
  )
199
  else:
 
211
  return schedule_table
212
 
213
 
214
+ def _convert_schedule_table_to_order(
215
+ num_warmup_microbatches, num_model_chunks, schedule_table
216
+ ):
217
  """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
218
  order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
219
  virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
 
242
  # Some codes are copied from Megatron-LM
243
  def generate_1f1b_interleave_schedule(config: ScheduleConfig):
244
  schedule = Schedule(config)
245
+
246
  for device_id in range(config.num_devices):
247
  microbatch_group_size_per_vp_stage = config.num_devices
248
  num_warmup_microbatches = _get_pp_rank_microbatches(
 
261
 
262
  order = _convert_schedule_table_to_order(
263
  num_warmup_microbatches,
264
+ num_model_chunks=config.num_stages_per_device,
265
  schedule_table=schedule_table,
266
  )
267
 
268
  cur_stage_microbatch_id = {}
269
+ for i in range(1, config.num_stages_per_device + 1):
270
  cur_stage_microbatch_id[i] = 0
271
  cur_stage_microbatch_id[-i] = 0
272
  for order_item in order:
273
+ stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
274
 
275
  if order_item > 0:
276
  op_type = "forward"
277
  micro_batch_id = cur_stage_microbatch_id[order_item]
278
+ cur_stage_microbatch_id[order_item] = (
279
+ cur_stage_microbatch_id[order_item] + 1
280
+ )
281
  elif order_item < 0:
282
  op_type = "backward"
283
  micro_batch_id = cur_stage_microbatch_id[order_item]
284
+ cur_stage_microbatch_id[order_item] = (
285
+ cur_stage_microbatch_id[order_item] + 1
286
+ )
287
  else:
288
  raise ValueError(f"Invalid order item: {order_item}")
289
  schedule.device_queues[device_id].add_operation(
 
291
  )
292
  return schedule
293
 
294
+
295
  def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
296
  schedule = Schedule(config)
297
 
 
312
  )
313
 
314
  # NOTE: Add one more warmup microbatch for overlapped operations!
315
+ num_warmup_microbatches += 1
316
  order = _convert_schedule_table_to_order(
317
  num_warmup_microbatches,
318
+ num_model_chunks=config.num_stages_per_device,
319
  schedule_table=schedule_table,
320
  )
321
 
322
  cur_stage_microbatch_id = {}
323
+ for i in range(1, config.num_stages_per_device + 1):
324
  cur_stage_microbatch_id[i] = 0
325
  cur_stage_microbatch_id[-i] = 0
326
  i = 0
 
332
  assert order_item > 0
333
  op_type = "forward"
334
  micro_batch_id = cur_stage_microbatch_id[order_item]
335
+ cur_stage_microbatch_id[order_item] = (
336
+ cur_stage_microbatch_id[order_item] + 1
337
+ )
338
 
339
+ stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
340
  schedule.device_queues[device_id].add_operation(
341
  schedule.get_op(micro_batch_id, stage_id, op_type)
342
  )
343
  i += 1
344
+ elif (
345
+ i >= num_warmup_microbatches
346
+ and i < num_warmup_microbatches + num_overlapped_batches - 1
347
+ ):
348
  order_item_a = order[i]
349
+ order_item_b = order[i + 1]
350
 
351
  op_type_a = "forward" if order_item_a > 0 else "backward"
352
  micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
353
+ cur_stage_microbatch_id[order_item_a] = (
354
+ cur_stage_microbatch_id[order_item_a] + 1
355
+ )
356
 
357
  op_type_b = "forward" if order_item_b > 0 else "backward"
358
  micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
359
+ cur_stage_microbatch_id[order_item_b] = (
360
+ cur_stage_microbatch_id[order_item_b] + 1
361
+ )
362
 
363
+ stage_id_a = schedule.device_queues[device_id].stages[
364
+ abs(order_item_a) - 1
365
+ ]
366
+ stage_id_b = schedule.device_queues[device_id].stages[
367
+ abs(order_item_b) - 1
368
+ ]
369
 
370
  op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
371
  op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
 
380
  assert order_item < 0
381
  op_type = "backward"
382
  micro_batch_id = cur_stage_microbatch_id[order_item]
383
+ cur_stage_microbatch_id[order_item] = (
384
+ cur_stage_microbatch_id[order_item] + 1
385
+ )
386
 
387
+ stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
388
  schedule.device_queues[device_id].add_operation(
389
  schedule.get_op(micro_batch_id, stage_id, op_type)
390
  )
391
  i += 1
 
392
 
393
  return schedule
394
 
 
401
  # Get the operations from the schedule
402
  op1 = schedule.ops[(batch_id1, stage_id, type1)]
403
  op2 = schedule.ops[(batch_id2, stage_id, type2)]
404
+
405
  # Create the overlapped operation
406
  overlapped_op = OverlappedOperation([op1, op2])
407
+
408
  # Register in the schedule to ensure proper tracking
409
  schedule.register_overlapped_operation(overlapped_op)
410
+
411
  return overlapped_op
412
 
413
 
414
  def generate_dualpipe_schedule(config: ScheduleConfig):
415
  """
416
  Implements the DualPipe scheduling strategy.
417
+
418
  DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
419
  and backward computation-communication phases and reduces pipeline bubbles.
420
+
421
  The DualPipe strategy has the following characteristics:
422
  1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
423
  2. Each device handles both a forward stage and a reverse stage
 
432
  A Schedule object with the DualPipe scheduling
433
  """
434
  # Ensure placement strategy is set for Schedule initialization
435
+ assert (
436
+ config.placement_strategy == "dualpipe"
437
+ ), "DualPipe schedule currently only supports placement_strategy='dualpipe'"
438
  # Assertions based on DualPipe requirements
439
+ assert (
440
+ config.num_stages % 2 == 0
441
+ ), "DualPipe requires an even number of stages (and devices)"
442
+ assert (
443
+ config.num_devices == config.num_stages
444
+ ), "DualPipe requires num_devices == num_stages"
445
+ assert (
446
+ config.num_batches % 2 == 0
447
+ ), "DualPipe requires an even number of microbatches (config.num_batches)"
448
  # Assertion based on original implementation: num_chunks >= num_ranks * 2
449
  # Here, M (config.num_batches) corresponds to half_num_chunks
450
+ assert (
451
+ config.num_batches >= config.num_devices
452
+ ), "DualPipe requires config.num_batches >= config.num_devices"
453
+ assert (
454
+ config.split_backward
455
+ ), "DualPipe schedule currently only supports split_backward=True"
456
 
457
  schedule = Schedule(config, init_ops=False)
458
 
 
462
  half_num_chunks = config.num_batches // 2
463
  num_half_ranks = num_devices // 2
464
 
465
+ fwd_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
466
+ bwd_d_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
467
 
468
+ waited_weight_grad = [
469
+ deque() for _ in range(num_devices)
470
+ ] # (device_id, ) -> List[(stage_id, batch_id)]
471
 
472
  for device_id in range(num_devices):
473
  is_in_second_half = device_id >= num_half_ranks
 
481
  fwd_batch_ids[device_id, 1] = config.num_batches // 2
482
  bwd_d_batch_ids[device_id, 0] = 0
483
  bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
484
+
485
  def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
486
+ stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
487
+ stage_rev_dir = (
488
+ num_stages - 1 - device_id
489
+ ) # Stage handled when moving backward (N-1 to 0)
490
  if not is_in_second_half:
491
  # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
492
  return stage_fwd_dir if phase == 0 else stage_rev_dir
493
  else:
494
  # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
495
  return stage_rev_dir if phase == 0 else stage_fwd_dir
 
496
 
497
  def add_op_to_queue(device_id, stage_id, op_type, batch_id):
498
  # Retrieve the correct pre-initialized Operation object
 
514
  batch_id = bwd_d_batch_ids[device_id, phase]
515
  add_op_to_queue(device_id, stage_id, "backward", batch_id)
516
  bwd_d_batch_ids[device_id, phase] += 1
517
+
518
  def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
519
  """Schedules a backward_D compute operation."""
520
  stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
 
528
  stage_id, batch_id = waited_weight_grad[device_id].popleft()
529
  add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
530
 
531
+ def _schedule_forward_backward_chunk(
532
+ device_id, fwd_phase, bwd_phase, is_in_second_half
533
+ ):
534
  """Schedules an overlapped forward and backward_D compute operation."""
535
+ fwd_stage_id = get_stage_for_phase(
536
+ device_id, fwd_phase, num_stages, is_in_second_half
537
+ )
538
+ bwd_stage_id = get_stage_for_phase(
539
+ device_id, bwd_phase, num_stages, is_in_second_half
540
+ )
541
+
542
  fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
543
 
544
  fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
 
551
  bwd_d_batch_ids[device_id, bwd_phase] += 1
552
 
553
  # Create and register the overlapped operation
554
+ overlapped_op = OverlappedOperation([fwd_op, bwd_op])
555
  schedule.register_overlapped_operation(overlapped_op)
556
+
557
  # Add the overlapped operation to the queue
558
  schedule.device_queues[device_id].add_operation(overlapped_op)
559
 
 
560
  # Process each device (rank in original code)
561
  for device_id in range(num_devices):
562
  half_rank = min(device_id, num_devices - 1 - device_id)
563
  is_in_second_half = device_id >= num_half_ranks
564
+ is_middle_rank = (device_id == num_half_ranks - 1) or (
565
+ device_id == num_half_ranks
566
+ )
567
 
568
  # Map original steps to operation additions
569
  # Step 1: nF0
570
  step_1_count = (num_half_ranks - half_rank - 1) * 2
571
  for _ in range(step_1_count):
572
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
573
 
574
  # Step 2: nF0F1
575
  step_2_count = half_rank + 1
576
  for i in range(step_2_count):
577
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
578
+ _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
579
 
580
  # Step 3: nB1W1F1
581
  step_3_count = num_half_ranks - half_rank - 1
582
  for _ in range(step_3_count):
583
+ _schedule_backward_input_chunk(device_id, 1, is_in_second_half) # B1_D
584
+ _schedule_backward_weight_chunk(
585
+ device_id,
586
+ ) # W1
587
  _schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
588
 
589
  # Step 4 (Main step): nF0B1F1B0
590
  step_4_count = half_num_chunks - num_devices + half_rank + 1
591
  for i in range(step_4_count):
592
  # if i == 0 and is_middle_rank:
593
+ # Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
594
+ # _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
595
+ # _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
596
+ # _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
597
  # else:
598
  # Overlap F0 and B1_D, then schedule W1
599
+ _schedule_forward_backward_chunk(
600
+ device_id, 0, 1, is_in_second_half
601
+ ) # F0+B1
602
+
603
  # Overlap F1 and B0_D, then schedule W0
604
+ _schedule_forward_backward_chunk(
605
+ device_id, 1, 0, is_in_second_half
606
+ ) # F1+B0
607
 
608
  # Step 5: nB1F1B0
609
  step_5_count = num_half_ranks - half_rank - 1
610
  for _ in range(step_5_count):
611
+ _schedule_backward_chunk(device_id, 1, is_in_second_half) # B1_D + B1_W
612
+ _schedule_forward_backward_chunk(
613
+ device_id, 1, 0, is_in_second_half
614
+ ) # F1+B0
615
 
616
  # Step 6: nB1B0
617
  step_6_count = half_rank + 1
 
633
  # Step 7: nWB0
634
  step_7_count = num_half_ranks - half_rank - 1
635
  for _ in range(step_7_count):
636
+ _schedule_backward_weight_chunk(
637
+ device_id
638
+ ) # W1 (use gradient from B1_D scheduled previously)
639
+ _schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
640
 
641
  # Step 8: nW
642
  step_8_count = half_rank + 1
 
644
  # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
645
  # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
646
  # The last W0 gradients correspond to B0_D from step 6 or 7.
647
+ _schedule_backward_weight_chunk(
648
+ device_id
649
+ ) # W0 (use gradient from B0_D scheduled previously)
650
 
651
  return schedule
652
 
653
+
654
+ def generate_dualpipe_v_schedule(config: ScheduleConfig):
655
+ """
656
+ Implements the DualPipe-V scheduling strategy based on dualpipe_v.py.
657
+
658
+ DualPipe-V aims to improve upon DualPipe by utilizing Zero Bubble (ZB)
659
+ techniques, further reducing pipeline bubbles by overlapping gradient
660
+ computation (backward_D) and weight updates (backward_W).
661
+
662
+ Key characteristics:
663
+ 1. Requires placement_strategy="dualpipe".
664
+ 2. Each device handles a forward stage and a reverse stage.
665
+ 3. Requires split_backward=True.
666
+ 4. Overlaps forward (F) and backward_D (B_D) operations.
667
+ 5. Schedules backward_W (W) operations separately.
668
+ 6. Uses Zero Bubble logic in later steps to delay W operations.
669
+ 7. Assumes config.num_batches corresponds to the total number of microbatches (`num_chunks` in dualpipe_v.py).
670
+
671
+ Args:
672
+ config: The scheduling configuration.
673
+
674
+ Returns:
675
+ A Schedule object with the DualPipe-V scheduling.
676
+ """
677
+ schedule = Schedule(config, init_ops=False)
678
+
679
+ assert config.num_stages == config.num_devices * 2, "num_stages must be equal to num_devices * 2 for DualPipe-V"
680
+ assert config.split_backward, "DualPipe-V requires split_backward=True"
681
+
682
+ num_stages = config.num_stages
683
+ num_devices = config.num_devices
684
+
685
+ fwd_batch_ids = defaultdict(int) # (device_id, chunk_id) -> batch_id
686
+ bwd_d_batch_ids = defaultdict(int) # (device_id, chunk_id) -> batch_id
687
+
688
+ waited_weight_grad = [
689
+ deque() for _ in range(num_devices)
690
+ ] # (device_id, ) -> List[(stage_id, batch_id)]
691
+
692
+ for device_id in range(num_devices):
693
+ fwd_batch_ids[device_id, 0] = 0
694
+ fwd_batch_ids[device_id, 1] = 0
695
+ bwd_d_batch_ids[device_id, 0] = 0
696
+ bwd_d_batch_ids[device_id, 1] = 0
697
+
698
+
699
+ def add_op_to_queue(device_id, stage_id, op_type, batch_id):
700
+ # Retrieve the correct pre-initialized Operation object
701
+ op = Operation(batch_id, stage_id, op_type)
702
+ schedule.register_operation(op)
703
+ # Add to the device queue
704
+ schedule.device_queues[device_id].add_operation(op)
705
+
706
+ def get_stage_for_chunk(device_id, chunk_id):
707
+ if chunk_id == 0:
708
+ # Forward direction stage for this device
709
+ return device_id
710
+ else:
711
+ # Reverse direction stage for this device
712
+ return num_stages - 1 - device_id
713
+
714
+ def _schedule_forward_chunk(device_id, chunk_id):
715
+ """Schedules a forward compute operation."""
716
+ stage_id = get_stage_for_chunk(device_id, chunk_id)
717
+ batch_id = fwd_batch_ids[device_id, chunk_id]
718
+ add_op_to_queue(device_id, stage_id, "forward", batch_id)
719
+ fwd_batch_ids[device_id, chunk_id] += 1
720
+
721
+ def _schedule_backward_chunk(device_id, chunk_id, enable_zb=False):
722
+ """Schedules a backward_D compute operation."""
723
+ stage_id = get_stage_for_chunk(device_id, chunk_id)
724
+ batch_id = bwd_d_batch_ids[device_id, chunk_id]
725
+ if enable_zb:
726
+ add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
727
+ waited_weight_grad[device_id].append((stage_id, batch_id))
728
+ else:
729
+ add_op_to_queue(device_id, stage_id, "backward", batch_id)
730
+ bwd_d_batch_ids[device_id, chunk_id] += 1
731
+
732
+ def _schedule_backward_weight_chunk(device_id):
733
+ """Schedules a backward_W compute operation."""
734
+ assert waited_weight_grad[device_id], f"Device {device_id} has no waited weight grads to schedule"
735
+ stage_id, batch_id = waited_weight_grad[device_id].popleft()
736
+ add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
737
+
738
+ def _schedule_forward_backward_chunk(
739
+ device_id, fwd_chunk_id, bwd_chunk_id
740
+ ):
741
+ """Schedules an overlapped forward and backward_D compute operation."""
742
+ fwd_stage_id = get_stage_for_chunk(device_id, fwd_chunk_id)
743
+ bwd_stage_id = get_stage_for_chunk(device_id, bwd_chunk_id)
744
+
745
+ fwd_batch_id = fwd_batch_ids[device_id, fwd_chunk_id]
746
+ fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
747
+ schedule.register_operation(fwd_op)
748
+ fwd_batch_ids[device_id, fwd_chunk_id] += 1
749
+
750
+ bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_chunk_id]
751
+ # Schedule backward_D
752
+ bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
753
+ schedule.register_operation(bwd_op)
754
+ bwd_d_batch_ids[device_id, bwd_chunk_id] += 1
755
+
756
+ # Create and register the overlapped operation
757
+ overlapped_op = OverlappedOperation([fwd_op, bwd_op])
758
+ schedule.register_overlapped_operation(overlapped_op)
759
+
760
+ # Add the overlapped operation to the queue
761
+ schedule.device_queues[device_id].add_operation(overlapped_op)
762
+
763
+ # Process each device (rank in original code)
764
+ for device_id in range(num_devices):
765
+ # Step 1: nF0
766
+ step_1_count = (num_devices - device_id - 1) * 2
767
+ for _ in range(step_1_count):
768
+ _schedule_forward_chunk(device_id, 0) # F0
769
+
770
+ # Step 2: nF0F1
771
+ step_2_count = device_id + 1
772
+ for i in range(step_2_count):
773
+ _schedule_forward_chunk(device_id, 0) # F0
774
+ _schedule_forward_chunk(device_id, 1) # F1
775
+
776
+ # Step 3: nB1W1F1 (Use zero bubble for B1)
777
+ step_3_count = num_devices - device_id - 1
778
+ for _ in range(step_3_count):
779
+ _schedule_backward_chunk(device_id, 1, enable_zb=True) # B1_D (ZB enabled)
780
+ _schedule_backward_weight_chunk(device_id) # W1
781
+ _schedule_forward_chunk(device_id, 1) # F1
782
+
783
+ # Step 4 (Main step): nF0B1F1B0 (Overlapped F and B_D)
784
+ num_batches = config.num_batches
785
+ step_4_count = num_batches - num_devices * 2 + device_id + 1
786
+ is_last_rank = (device_id == num_devices - 1) # Check if it's the last rank
787
+
788
+ for i in range(step_4_count):
789
+ if i == 0:
790
+ if is_last_rank:
791
+ # Special handling for the first iteration on the last rank
792
+ # Schedule F0, B1, W1 sequentially
793
+ _schedule_forward_chunk(device_id, 0) # F0
794
+ _schedule_backward_chunk(device_id, 1, enable_zb=False) # B1_D
795
+ else:
796
+ # Overlap F0 and B1
797
+ _schedule_forward_backward_chunk(device_id, 0, 1) # F0 + B1_D
798
+ else:
799
+ # Overlap F1 and B0_D
800
+ _schedule_forward_backward_chunk(device_id, 0, 1) # F0B1
801
+ _schedule_forward_backward_chunk(device_id, 1, 0) #
802
+
803
+
804
+ # Step 5: nB1F1B0
805
+ step_5_count = num_devices - device_id - 1
806
+ for _ in range(step_5_count):
807
+ # Schedule B1 (B1_D + B1_W) sequentially
808
+ _schedule_backward_chunk(device_id, 1, enable_zb=False) # B1_D + W1
809
+
810
+ # Overlap F1 and B0
811
+ _schedule_forward_backward_chunk(device_id, 1, 0) # F1 + B0
812
+
813
+ # Step 6: nB1B0 (The second half of the chunks use zero bubble)
814
+ step_6_count = device_id + 1
815
+ enable_zb = False
816
+ for i in range(step_6_count):
817
+ # Determine if ZB should be enabled for B1
818
+ if i == step_6_count // 2 and device_id % 2 == 1:
819
+ enable_zb = True
820
+ _schedule_backward_chunk(device_id, 1, enable_zb=enable_zb) # B1_D
821
+
822
+ # Determine if ZB should be enabled for B0
823
+ # ZB is enabled after the midpoint check for B0
824
+ if i == step_6_count // 2 and device_id % 2 == 0:
825
+ enable_zb = True # Enable ZB for the rest, including B0
826
+ _schedule_backward_chunk(device_id, 0, enable_zb=enable_zb) # B0_D
827
+
828
+ # Step 7: nWB0 (Use zero bubble for B0)
829
+ step_7_count = num_devices - device_id - 1
830
+ for _ in range(step_7_count):
831
+ _schedule_backward_weight_chunk(device_id) # W1 (from ZB B1_D in Step 6 or Step 3)
832
+ _schedule_backward_chunk(device_id, 0, enable_zb=True) # B0_D
833
+
834
+ # Step 8: nW
835
+ step_8_count = device_id + 1
836
+ for _ in range(step_8_count):
837
+ _schedule_backward_weight_chunk(device_id) # W0 (from ZB B0_D in Step 6 or 7) or W1 (from ZB B1_D in Step 6)
838
+
839
+ # Final check: Ensure all waited gradients are processed
840
+ assert not waited_weight_grad[device_id], f"Device {device_id} has remaining waited weight grads: {waited_weight_grad[device_id]}"
841
+
842
+
843
+ return schedule