Add DualPipe-V support.

Files changed:
- README.md +6 -0
- assets/dualpipe_v.png +3 -0
- main.py +22 -0
- src/execution_model.py +7 -0
- src/strategies.py +340 -78
README.md
CHANGED
@@ -84,6 +84,12 @@ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=2
 ```
 
 ![dualpipe](assets/dualpipe.png)
+### Running for DualPipe-V strategy
+```bash
+uv run python main.py strategy=dualpipe_v num_devices=4 num_stages=8 num_batches=10
+```
+
+![dualpipe_v](assets/dualpipe_v.png)
 ### Running for 1F1B-batch-overlap strategy:
 ```bash
 uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
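Note the argument ratio in the new command: DualPipe-V places two pipeline stages on every device (the pipeline folds back on itself), so `num_stages` must be exactly twice `num_devices`. A minimal sketch of the implied placement, not repo code:

```python
# Each device d hosts stage d and stage (num_stages - 1 - d), mirroring the
# dualpipe_v placement branch added to src/execution_model.py below.
num_devices, num_stages = 4, 8
placement = {d: [d, num_stages - 1 - d] for d in range(num_devices)}
print(placement)  # {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
```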
assets/dualpipe_v.png
ADDED
Binary image (tracked with Git LFS).
main.py
CHANGED
@@ -4,6 +4,7 @@ from src.strategies import (
     generate_1f1b_interleave_schedule,
     generate_1f1b_overlap_schedule,
     generate_1f1b_schedule,
+    generate_dualpipe_v_schedule,
     generate_zero_bubble_1p_schedule,
     generate_dualpipe_schedule,
 )
@@ -29,6 +30,8 @@ def main(cfg: DictConfig) -> None:
         run_1f1b_interleave_overlap(cfg)
     elif cfg.strategy == "dualpipe":
         run_dualpipe(cfg)
+    elif cfg.strategy == "dualpipe_v":
+        run_dualpipe_v(cfg)
     else:
         raise ValueError(f"Unknown strategy: {cfg.strategy}")
 
@@ -152,5 +155,24 @@ def run_dualpipe(cfg: DictConfig) -> None:
     schedule.execute()
     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
+def run_dualpipe_v(cfg: DictConfig) -> None:
+    """Run DualPipeV pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=True,
+        placement_strategy="dualpipe_v",
+    )
+    schedule = generate_dualpipe_v_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
 if __name__ == "__main__":
     main()
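For reference, the new code path can also be driven without the Hydra CLI. The sketch below mirrors `run_dualpipe_v` above; the `p2p_latency` value is an assumed placeholder (the CLI reads it from `cfg.p2p_latency`) and the Dash visualization step is omitted:

```python
from src.execution_model import ScheduleConfig
from src.strategies import generate_dualpipe_v_schedule

# Mirrors run_dualpipe_v: DualPipe-V requires split_backward=True and the
# "dualpipe_v" placement strategy.
config = ScheduleConfig(
    num_devices=4,
    num_stages=8,
    num_batches=10,
    p2p_latency=0.0,  # assumed value for this sketch
    op_times=None,
    split_backward=True,
    placement_strategy="dualpipe_v",
)
schedule = generate_dualpipe_v_schedule(config)
schedule.execute()  # resolve timing for every scheduled operation
```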
src/execution_model.py
CHANGED
@@ -158,6 +158,13 @@ class ScheduleConfig:
             self.device_to_stages = defaultdict(list)
             for i in range(self.num_stages):
                 self.device_to_stages[i] = [i, self.num_stages - i - 1]
+        elif self.placement_strategy == "dualpipe_v":
+            assert self.num_devices % 2 == 0, "DualPipe-V requires an even number of devices"
+            assert self.num_stages == self.num_devices * 2, "DualPipe-V requires num_stages == num_devices * 2"
+            assert self.split_backward, "DualPipe-V requires split_backward=True"
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_devices):
+                self.device_to_stages[i] = [i, self.num_stages - i - 1]
         else:
             raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
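The three asserts fail fast on inconsistent configurations. A standalone restatement of what they accept and reject, for illustration only:

```python
def check_dualpipe_v_config(num_devices: int, num_stages: int, split_backward: bool) -> None:
    # Same three conditions as the new placement branch above.
    assert num_devices % 2 == 0, "DualPipe-V requires an even number of devices"
    assert num_stages == num_devices * 2, "DualPipe-V requires num_stages == num_devices * 2"
    assert split_backward, "DualPipe-V requires split_backward=True"

check_dualpipe_v_config(4, 8, True)    # OK: the README example
# check_dualpipe_v_config(4, 4, True)  # AssertionError: num_stages != num_devices * 2
# check_dualpipe_v_config(4, 8, False) # AssertionError: split_backward must be True
```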
src/strategies.py
CHANGED
@@ -5,7 +5,9 @@ from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
 def generate_1f1b_schedule(config: ScheduleConfig):
     schedule = Schedule(config)
 
-    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
+    assert (
+        config.num_devices == config.num_stages
+    ), "num_devices must be equal to num_stages for 1F1B"
 
     for i in range(config.num_devices):
         fwd_batch_id = 0
@@ -42,7 +44,9 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
     # Create a new schedule with split_backward=True to support backward_D and backward_W operations
     schedule = Schedule(config)
     total_batches = config.num_batches
-    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
+    assert (
+        config.num_devices == config.num_stages
+    ), "num_devices must be equal to num_stages for ZB-1P"
     assert config.split_backward, "ZB-1P requires split_backward=True"
 
     for i in range(config.num_devices):
@@ -73,7 +77,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
             bwd_w_batch_id += 1
             bwd_d_batch_id += 1
             fwd_batch_id += 1
-
+
         for _ in range(cooldown_batches):
             schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_d_batch_id, i, "backward_D")
@@ -85,7 +89,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
 
             bwd_w_batch_id += 1
             bwd_d_batch_id += 1
-
+
         while bwd_w_batch_id < total_batches:
             schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_w_batch_id, i, "backward_W")
@@ -98,7 +102,9 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
 def generate_1f1b_overlap_schedule(config: ScheduleConfig):
     schedule = Schedule(config)
 
-    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
+    assert (
+        config.num_devices == config.num_stages
+    ), "num_devices must be equal to num_stages for 1F1B"
 
     for i in range(config.num_devices):
         fwd_batch_id = 0
@@ -132,11 +138,11 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
 
 
 def _get_pp_rank_microbatches(
-    num_microbatches,
+    num_microbatches,
     num_devices,
     device_id,
-    num_stages_per_device,
-    microbatch_group_size_per_vp_stage,
+    num_stages_per_device,
+    microbatch_group_size_per_vp_stage,
 ):
     """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
     total_num_microbatches = num_microbatches * num_stages_per_device
@@ -147,7 +153,9 @@ def _get_pp_rank_microbatches(
         # stage ID (more forward passes for earlier stages, later stages can
         # immediately start with 1F1B).
         num_warmup_microbatches = (num_devices - device_id - 1) * 2
-        num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
+        num_warmup_microbatches += (
+            num_stages_per_device - 1
+        ) * microbatch_group_size_per_vp_stage
     else:
         # forward_backward_no_pipelining
         num_warmup_microbatches = 1
@@ -158,27 +166,34 @@ def _get_pp_rank_microbatches(
     return num_warmup_microbatches
 
 
-def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
+def _get_schedule_table(
+    num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage
+):
     """Get the schedule table for PP scheduling.
 
     Create a tunable schedule lookup table.
-    The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
+    The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
     For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
     virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
     microbatch_id         | 0 1 2 0 1 2 3 4 3 4
-    model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+    model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
     """
     schedule_table = []
     for min_microbatch_id_in_group in range(
         0, num_microbatches, microbatch_group_size_per_vp_stage
     ):
-        if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
+        if (
+            min_microbatch_id_in_group + microbatch_group_size_per_vp_stage
+            >= num_microbatches
+        ):
             # Construct schedule for the last microbatch group
             schedule_table.extend(
                 [
                     (microbatch_id, model_chunk_id)
                     for model_chunk_id in range(num_model_chunks)
-                    for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
+                    for microbatch_id in range(
+                        min_microbatch_id_in_group, num_microbatches
+                    )
                 ]
             )
         else:
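Aside: the lookup-table construction documented in the docstring above can be reproduced standalone. The hunk cuts off before the `else:` branch, so the non-final-group case below is an assumption (a full group of `microbatch_group_size_per_vp_stage` microbatches):

```python
def schedule_table_sketch(num_microbatches, num_model_chunks, group_size):
    table = []
    for lo in range(0, num_microbatches, group_size):
        # The last group runs through num_microbatches; earlier groups are
        # assumed to span a full group_size window.
        hi = num_microbatches if lo + group_size >= num_microbatches else lo + group_size
        table.extend(
            (microbatch_id, model_chunk_id)
            for model_chunk_id in range(num_model_chunks)
            for microbatch_id in range(lo, hi)
        )
    return table

# PP2 N3M5 with VP2: num_microbatches=5, num_model_chunks=2, group_size=3
print(schedule_table_sketch(5, 2, 3))
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (3, 0), (4, 0), (3, 1), (4, 1)]
```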
@@ -196,7 +211,9 @@ def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
     return schedule_table
 
 
-def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
+def _convert_schedule_table_to_order(
+    num_warmup_microbatches, num_model_chunks, schedule_table
+):
     """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
     order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
     virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
@@ -225,7 +242,7 @@ def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
 # Some codes are copied from Megatron-LM
 def generate_1f1b_interleave_schedule(config: ScheduleConfig):
     schedule = Schedule(config)
-
+
     for device_id in range(config.num_devices):
         microbatch_group_size_per_vp_stage = config.num_devices
         num_warmup_microbatches = _get_pp_rank_microbatches(
@@ -244,25 +261,29 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
 
         order = _convert_schedule_table_to_order(
             num_warmup_microbatches,
-            num_model_chunks=config.num_stages_per_device,
+            num_model_chunks=config.num_stages_per_device,
             schedule_table=schedule_table,
         )
 
         cur_stage_microbatch_id = {}
-        for i in range(1, config.num_stages_per_device+1):
+        for i in range(1, config.num_stages_per_device + 1):
            cur_stage_microbatch_id[i] = 0
            cur_stage_microbatch_id[-i] = 0
        for order_item in order:
-            stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
+            stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
 
            if order_item > 0:
                op_type = "forward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
-                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+                cur_stage_microbatch_id[order_item] = (
+                    cur_stage_microbatch_id[order_item] + 1
+                )
            elif order_item < 0:
                op_type = "backward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
-                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+                cur_stage_microbatch_id[order_item] = (
+                    cur_stage_microbatch_id[order_item] + 1
+                )
            else:
                raise ValueError(f"Invalid order item: {order_item}")
            schedule.device_queues[device_id].add_operation(
@@ -270,6 +291,7 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
             )
     return schedule
 
+
 def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
     schedule = Schedule(config)
 
@@ -290,15 +312,15 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
         )
 
         # NOTE: Add one more warmup microbatch for overlapped operations!
-        num_warmup_microbatches += 1
+        num_warmup_microbatches += 1
         order = _convert_schedule_table_to_order(
             num_warmup_microbatches,
-            num_model_chunks=config.num_stages_per_device,
+            num_model_chunks=config.num_stages_per_device,
             schedule_table=schedule_table,
         )
 
         cur_stage_microbatch_id = {}
-        for i in range(1, config.num_stages_per_device+1):
+        for i in range(1, config.num_stages_per_device + 1):
             cur_stage_microbatch_id[i] = 0
             cur_stage_microbatch_id[-i] = 0
         i = 0
@@ -310,27 +332,40 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
                 assert order_item > 0
                 op_type = "forward"
                 micro_batch_id = cur_stage_microbatch_id[order_item]
-                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+                cur_stage_microbatch_id[order_item] = (
+                    cur_stage_microbatch_id[order_item] + 1
+                )
 
-                stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
+                stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
                 schedule.device_queues[device_id].add_operation(
                     schedule.get_op(micro_batch_id, stage_id, op_type)
                 )
                 i += 1
-            elif i >= num_warmup_microbatches and i < num_warmup_microbatches + num_overlapped_batches - 1:
+            elif (
+                i >= num_warmup_microbatches
+                and i < num_warmup_microbatches + num_overlapped_batches - 1
+            ):
                 order_item_a = order[i]
-                order_item_b = order[i+1]
+                order_item_b = order[i + 1]
 
                 op_type_a = "forward" if order_item_a > 0 else "backward"
                 micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
-                cur_stage_microbatch_id[order_item_a] = cur_stage_microbatch_id[order_item_a] + 1
+                cur_stage_microbatch_id[order_item_a] = (
+                    cur_stage_microbatch_id[order_item_a] + 1
+                )
 
                 op_type_b = "forward" if order_item_b > 0 else "backward"
                 micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
-                cur_stage_microbatch_id[order_item_b] = cur_stage_microbatch_id[order_item_b] + 1
+                cur_stage_microbatch_id[order_item_b] = (
+                    cur_stage_microbatch_id[order_item_b] + 1
+                )
 
-                stage_id_a = schedule.device_queues[device_id].stages[abs(order_item_a)-1]
-                stage_id_b = schedule.device_queues[device_id].stages[abs(order_item_b)-1]
+                stage_id_a = schedule.device_queues[device_id].stages[
+                    abs(order_item_a) - 1
+                ]
+                stage_id_b = schedule.device_queues[device_id].stages[
+                    abs(order_item_b) - 1
+                ]
 
                 op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
                 op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
@@ -345,14 +380,15 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
                 assert order_item < 0
                 op_type = "backward"
                 micro_batch_id = cur_stage_microbatch_id[order_item]
-                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+                cur_stage_microbatch_id[order_item] = (
+                    cur_stage_microbatch_id[order_item] + 1
+                )
 
-                stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
+                stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
                 schedule.device_queues[device_id].add_operation(
                     schedule.get_op(micro_batch_id, stage_id, op_type)
                 )
                 i += 1
-
 
     return schedule
 
@@ -365,23 +401,23 @@ def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
     # Get the operations from the schedule
     op1 = schedule.ops[(batch_id1, stage_id, type1)]
     op2 = schedule.ops[(batch_id2, stage_id, type2)]
-
+
     # Create the overlapped operation
     overlapped_op = OverlappedOperation([op1, op2])
-
+
     # Register in the schedule to ensure proper tracking
     schedule.register_overlapped_operation(overlapped_op)
-
+
     return overlapped_op
 
 
 def generate_dualpipe_schedule(config: ScheduleConfig):
     """
     Implements the DualPipe scheduling strategy.
-
+
     DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
     and backward computation-communication phases and reduces pipeline bubbles.
-
+
     The DualPipe strategy has the following characteristics:
     1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
     2. Each device handles both a forward stage and a reverse stage
@@ -396,15 +432,27 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         A Schedule object with the DualPipe scheduling
     """
     # Ensure placement strategy is set for Schedule initialization
-    assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
+    assert (
+        config.placement_strategy == "dualpipe"
+    ), "DualPipe schedule currently only supports placement_strategy='dualpipe'"
     # Assertions based on DualPipe requirements
-    assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
-    assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
-    assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
+    assert (
+        config.num_stages % 2 == 0
+    ), "DualPipe requires an even number of stages (and devices)"
+    assert (
+        config.num_devices == config.num_stages
+    ), "DualPipe requires num_devices == num_stages"
+    assert (
+        config.num_batches % 2 == 0
+    ), "DualPipe requires an even number of microbatches (config.num_batches)"
     # Assertion based on original implementation: num_chunks >= num_ranks * 2
     # Here, M (config.num_batches) corresponds to half_num_chunks
-    assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
-    assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
+    assert (
+        config.num_batches >= config.num_devices
+    ), "DualPipe requires config.num_batches >= config.num_devices"
+    assert (
+        config.split_backward
+    ), "DualPipe schedule currently only supports split_backward=True"
 
     schedule = Schedule(config, init_ops=False)
 
@@ -414,10 +462,12 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
     half_num_chunks = config.num_batches // 2
     num_half_ranks = num_devices // 2
 
-    fwd_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
-    bwd_d_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
+    fwd_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+    bwd_d_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
 
-    waited_weight_grad = [deque() for _ in range(num_devices)] # (device_id, ) -> List[(stage_id, batch_id)]
+    waited_weight_grad = [
+        deque() for _ in range(num_devices)
+    ]  # (device_id, ) -> List[(stage_id, batch_id)]
 
     for device_id in range(num_devices):
         is_in_second_half = device_id >= num_half_ranks
@@ -431,16 +481,18 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         fwd_batch_ids[device_id, 1] = config.num_batches // 2
         bwd_d_batch_ids[device_id, 0] = 0
         bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
+
     def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
-        stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
-        stage_rev_dir = num_stages - 1 - device_id # Stage handled when moving backward (N-1 to 0)
+        stage_fwd_dir = device_id  # Stage handled when moving forward (0 to N-1)
+        stage_rev_dir = (
+            num_stages - 1 - device_id
+        )  # Stage handled when moving backward (N-1 to 0)
         if not is_in_second_half:
             # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
             return stage_fwd_dir if phase == 0 else stage_rev_dir
         else:
             # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
             return stage_rev_dir if phase == 0 else stage_fwd_dir
-
 
     def add_op_to_queue(device_id, stage_id, op_type, batch_id):
         # Retrieve the correct pre-initialized Operation object
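The `get_stage_for_phase` helper above is the heart of DualPipe's bidirectionality: ranks in the first half treat phase 0 as their forward-direction stage, while ranks in the second half treat phase 0 as their reverse-direction stage. A tabulation sketch for num_devices = num_stages = 8:

```python
def stage_for_phase(device_id, phase, num_stages, is_in_second_half):
    stage_fwd_dir = device_id
    stage_rev_dir = num_stages - 1 - device_id
    if not is_in_second_half:
        return stage_fwd_dir if phase == 0 else stage_rev_dir
    return stage_rev_dir if phase == 0 else stage_fwd_dir

num_devices = num_stages = 8
for d in range(num_devices):
    second_half = d >= num_devices // 2
    p0, p1 = (stage_for_phase(d, p, num_stages, second_half) for p in (0, 1))
    print(f"device {d}: phase 0 -> stage {p0}, phase 1 -> stage {p1}")
# device 0: phase 0 -> stage 0, phase 1 -> stage 7
# device 3: phase 0 -> stage 3, phase 1 -> stage 4
# device 4: phase 0 -> stage 3, phase 1 -> stage 4
# device 7: phase 0 -> stage 0, phase 1 -> stage 7
```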
@@ -462,7 +514,7 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         batch_id = bwd_d_batch_ids[device_id, phase]
         add_op_to_queue(device_id, stage_id, "backward", batch_id)
         bwd_d_batch_ids[device_id, phase] += 1
-
+
     def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
         """Schedules a backward_D compute operation."""
         stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
@@ -476,11 +528,17 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         stage_id, batch_id = waited_weight_grad[device_id].popleft()
         add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
 
-    def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
+    def _schedule_forward_backward_chunk(
+        device_id, fwd_phase, bwd_phase, is_in_second_half
+    ):
         """Schedules an overlapped forward and backward_D compute operation."""
-        fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
-        bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
-
+        fwd_stage_id = get_stage_for_phase(
+            device_id, fwd_phase, num_stages, is_in_second_half
+        )
+        bwd_stage_id = get_stage_for_phase(
+            device_id, bwd_phase, num_stages, is_in_second_half
+        )
+
         fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
 
         fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
@@ -493,58 +551,67 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         bwd_d_batch_ids[device_id, bwd_phase] += 1
 
         # Create and register the overlapped operation
-        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
+        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
         schedule.register_overlapped_operation(overlapped_op)
-
+
         # Add the overlapped operation to the queue
         schedule.device_queues[device_id].add_operation(overlapped_op)
 
-
     # Process each device (rank in original code)
     for device_id in range(num_devices):
         half_rank = min(device_id, num_devices - 1 - device_id)
         is_in_second_half = device_id >= num_half_ranks
-        is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
+        is_middle_rank = (device_id == num_half_ranks - 1) or (
+            device_id == num_half_ranks
+        )
 
         # Map original steps to operation additions
         # Step 1: nF0
         step_1_count = (num_half_ranks - half_rank - 1) * 2
         for _ in range(step_1_count):
-            _schedule_forward_chunk(device_id, 0, is_in_second_half)
+            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
 
         # Step 2: nF0F1
         step_2_count = half_rank + 1
         for i in range(step_2_count):
-            _schedule_forward_chunk(device_id, 0, is_in_second_half)
-            _schedule_forward_chunk(device_id, 1, is_in_second_half)
+            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1
 
         # Step 3: nB1W1F1
         step_3_count = num_half_ranks - half_rank - 1
         for _ in range(step_3_count):
-            _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
-            _schedule_backward_weight_chunk(device_id) # W1
+            _schedule_backward_input_chunk(device_id, 1, is_in_second_half)  # B1_D
+            _schedule_backward_weight_chunk(
+                device_id,
+            )  # W1
             _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1
 
         # Step 4 (Main step): nF0B1F1B0
         step_4_count = half_num_chunks - num_devices + half_rank + 1
         for i in range(step_4_count):
             # if i == 0 and is_middle_rank:
-            #     Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
-            #     _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
-            #     _schedule_backward_chunk(device_id, 1, is_in_second_half) # B1
-            #     _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
+            #     Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
+            #     _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+            #     _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
+            #     _schedule_backward_weight_chunk(device_id, 1, is_in_second_half)  # W1
             # else:
             # Overlap F0 and B1_D, then schedule W1
-            _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half) # F0+B1
-
+            _schedule_forward_backward_chunk(
+                device_id, 0, 1, is_in_second_half
+            )  # F0+B1
+
             # Overlap F1 and B0_D, then schedule W0
-            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
+            _schedule_forward_backward_chunk(
+                device_id, 1, 0, is_in_second_half
+            )  # F1+B0
 
         # Step 5: nB1F1B0
         step_5_count = num_half_ranks - half_rank - 1
         for _ in range(step_5_count):
-            _schedule_backward_chunk(device_id, 1, is_in_second_half)
-            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
+            _schedule_backward_chunk(device_id, 1, is_in_second_half)  # B1_D + B1_W
+            _schedule_forward_backward_chunk(
+                device_id, 1, 0, is_in_second_half
+            )  # F1+B0
 
         # Step 6: nB1B0
         step_6_count = half_rank + 1
@@ -566,8 +633,10 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         # Step 7: nWB0
         step_7_count = num_half_ranks - half_rank - 1
         for _ in range(step_7_count):
-            _schedule_backward_weight_chunk(device_id) # W1 (use gradient from B1_D scheduled previously)
-            _schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
+            _schedule_backward_weight_chunk(
+                device_id
+            )  # W1 (use gradient from B1_D scheduled previously)
+            _schedule_backward_input_chunk(device_id, 0, is_in_second_half)  # B0_D
 
         # Step 8: nW
         step_8_count = half_rank + 1
@@ -575,7 +644,200 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
         # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
         # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
         # The last W0 gradients correspond to B0_D from step 6 or 7.
-        _schedule_backward_weight_chunk(device_id) # W0 (use gradient from B0_D scheduled previously)
+        _schedule_backward_weight_chunk(
+            device_id
+        )  # W0 (use gradient from B0_D scheduled previously)
 
     return schedule
+
+
+def generate_dualpipe_v_schedule(config: ScheduleConfig):
+    """
+    Implements the DualPipe-V scheduling strategy based on dualpipe_v.py.
+
+    DualPipe-V aims to improve upon DualPipe by utilizing Zero Bubble (ZB)
+    techniques, further reducing pipeline bubbles by overlapping gradient
+    computation (backward_D) and weight updates (backward_W).
+
+    Key characteristics:
+    1. Requires placement_strategy="dualpipe_v".
+    2. Each device handles a forward stage and a reverse stage.
+    3. Requires split_backward=True.
+    4. Overlaps forward (F) and backward_D (B_D) operations.
+    5. Schedules backward_W (W) operations separately.
+    6. Uses Zero Bubble logic in later steps to delay W operations.
+    7. Assumes config.num_batches corresponds to the total number of microbatches (`num_chunks` in dualpipe_v.py).
+
+    Args:
+        config: The scheduling configuration.
+
+    Returns:
+        A Schedule object with the DualPipe-V scheduling.
+    """
+    schedule = Schedule(config, init_ops=False)
+
+    assert config.num_stages == config.num_devices * 2, "num_stages must be equal to num_devices * 2 for DualPipe-V"
+    assert config.split_backward, "DualPipe-V requires split_backward=True"
+
+    num_stages = config.num_stages
+    num_devices = config.num_devices
+
+    fwd_batch_ids = defaultdict(int)  # (device_id, chunk_id) -> batch_id
+    bwd_d_batch_ids = defaultdict(int)  # (device_id, chunk_id) -> batch_id
+
+    waited_weight_grad = [
+        deque() for _ in range(num_devices)
+    ]  # (device_id, ) -> List[(stage_id, batch_id)]
+
+    for device_id in range(num_devices):
+        fwd_batch_ids[device_id, 0] = 0
+        fwd_batch_ids[device_id, 1] = 0
+        bwd_d_batch_ids[device_id, 0] = 0
+        bwd_d_batch_ids[device_id, 1] = 0
+
+
+    def add_op_to_queue(device_id, stage_id, op_type, batch_id):
+        # Create and register the Operation object
+        op = Operation(batch_id, stage_id, op_type)
+        schedule.register_operation(op)
+        # Add to the device queue
+        schedule.device_queues[device_id].add_operation(op)
+
+    def get_stage_for_chunk(device_id, chunk_id):
+        if chunk_id == 0:
+            # Forward direction stage for this device
+            return device_id
+        else:
+            # Reverse direction stage for this device
+            return num_stages - 1 - device_id
+
+    def _schedule_forward_chunk(device_id, chunk_id):
+        """Schedules a forward compute operation."""
+        stage_id = get_stage_for_chunk(device_id, chunk_id)
+        batch_id = fwd_batch_ids[device_id, chunk_id]
+        add_op_to_queue(device_id, stage_id, "forward", batch_id)
+        fwd_batch_ids[device_id, chunk_id] += 1
+
+    def _schedule_backward_chunk(device_id, chunk_id, enable_zb=False):
+        """Schedules a backward_D compute operation."""
+        stage_id = get_stage_for_chunk(device_id, chunk_id)
+        batch_id = bwd_d_batch_ids[device_id, chunk_id]
+        if enable_zb:
+            add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
+            waited_weight_grad[device_id].append((stage_id, batch_id))
+        else:
+            add_op_to_queue(device_id, stage_id, "backward", batch_id)
+        bwd_d_batch_ids[device_id, chunk_id] += 1
+
+    def _schedule_backward_weight_chunk(device_id):
+        """Schedules a backward_W compute operation."""
+        assert waited_weight_grad[device_id], f"Device {device_id} has no waited weight grads to schedule"
+        stage_id, batch_id = waited_weight_grad[device_id].popleft()
+        add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
+
+    def _schedule_forward_backward_chunk(
+        device_id, fwd_chunk_id, bwd_chunk_id
+    ):
+        """Schedules an overlapped forward and backward_D compute operation."""
+        fwd_stage_id = get_stage_for_chunk(device_id, fwd_chunk_id)
+        bwd_stage_id = get_stage_for_chunk(device_id, bwd_chunk_id)
+
+        fwd_batch_id = fwd_batch_ids[device_id, fwd_chunk_id]
+        fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
+        schedule.register_operation(fwd_op)
+        fwd_batch_ids[device_id, fwd_chunk_id] += 1
+
+        bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_chunk_id]
+        # Schedule backward_D
+        bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
+        schedule.register_operation(bwd_op)
+        bwd_d_batch_ids[device_id, bwd_chunk_id] += 1
+
+        # Create and register the overlapped operation
+        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
+        schedule.register_overlapped_operation(overlapped_op)
+
+        # Add the overlapped operation to the queue
+        schedule.device_queues[device_id].add_operation(overlapped_op)
+
+    # Process each device (rank in original code)
+    for device_id in range(num_devices):
+        # Step 1: nF0
+        step_1_count = (num_devices - device_id - 1) * 2
+        for _ in range(step_1_count):
+            _schedule_forward_chunk(device_id, 0)  # F0
+
+        # Step 2: nF0F1
+        step_2_count = device_id + 1
+        for i in range(step_2_count):
+            _schedule_forward_chunk(device_id, 0)  # F0
+            _schedule_forward_chunk(device_id, 1)  # F1
+
+        # Step 3: nB1W1F1 (Use zero bubble for B1)
+        step_3_count = num_devices - device_id - 1
+        for _ in range(step_3_count):
+            _schedule_backward_chunk(device_id, 1, enable_zb=True)  # B1_D (ZB enabled)
+            _schedule_backward_weight_chunk(device_id)  # W1
+            _schedule_forward_chunk(device_id, 1)  # F1
+
+        # Step 4 (Main step): nF0B1F1B0 (Overlapped F and B_D)
+        num_batches = config.num_batches
+        step_4_count = num_batches - num_devices * 2 + device_id + 1
+        is_last_rank = (device_id == num_devices - 1)  # Check if it's the last rank
+
+        for i in range(step_4_count):
+            if i == 0:
+                if is_last_rank:
+                    # Special handling for the first iteration on the last rank:
+                    # schedule F0 and B1 sequentially
+                    _schedule_forward_chunk(device_id, 0)  # F0
+                    _schedule_backward_chunk(device_id, 1, enable_zb=False)  # B1
+                else:
+                    # Overlap F0 and B1
+                    _schedule_forward_backward_chunk(device_id, 0, 1)  # F0+B1
+            else:
+                _schedule_forward_backward_chunk(device_id, 0, 1)  # F0+B1
+            # Overlap F1 and B0
+            _schedule_forward_backward_chunk(device_id, 1, 0)  # F1+B0
+
+        # Step 5: nB1F1B0
+        step_5_count = num_devices - device_id - 1
+        for _ in range(step_5_count):
+            # Schedule B1 (B1_D + B1_W) sequentially
+            _schedule_backward_chunk(device_id, 1, enable_zb=False)  # B1
+
+            # Overlap F1 and B0
+            _schedule_forward_backward_chunk(device_id, 1, 0)  # F1+B0
+
+        # Step 6: nB1B0 (The second half of the chunks use zero bubble)
+        step_6_count = device_id + 1
+        enable_zb = False
+        for i in range(step_6_count):
+            # Determine if ZB should be enabled for B1
+            if i == step_6_count // 2 and device_id % 2 == 1:
+                enable_zb = True
+            _schedule_backward_chunk(device_id, 1, enable_zb=enable_zb)  # B1_D
+
+            # Determine if ZB should be enabled for B0
+            # ZB is enabled after the midpoint check for B0
+            if i == step_6_count // 2 and device_id % 2 == 0:
+                enable_zb = True  # Enable ZB for the rest, including B0
+            _schedule_backward_chunk(device_id, 0, enable_zb=enable_zb)  # B0_D
+
+        # Step 7: nWB0 (Use zero bubble for B0)
+        step_7_count = num_devices - device_id - 1
+        for _ in range(step_7_count):
+            _schedule_backward_weight_chunk(device_id)  # W1 (from ZB B1_D in Step 6 or Step 3)
+            _schedule_backward_chunk(device_id, 0, enable_zb=True)  # B0_D
+
+        # Step 8: nW
+        step_8_count = device_id + 1
+        for _ in range(step_8_count):
+            _schedule_backward_weight_chunk(device_id)  # W0 (from ZB B0_D in Step 6 or 7) or W1 (from ZB B1_D in Step 6)
+
+        # Final check: Ensure all waited gradients are processed
+        assert not waited_weight_grad[device_id], f"Device {device_id} has remaining waited weight grads: {waited_weight_grad[device_id]}"
+
+    return schedule
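As a sanity check on the step counts in `generate_dualpipe_v_schedule`, each device should run every microbatch exactly once per chunk in each direction. An arithmetic sketch with the README example (4 devices, 10 microbatches); the variable names are ad hoc:

```python
num_devices, num_batches = 4, 10
for d in range(num_devices):
    s1 = (num_devices - d - 1) * 2              # Step 1: nF0
    s2 = d + 1                                  # Step 2: nF0F1
    s3 = num_devices - d - 1                    # Step 3: nB1W1F1
    s4 = num_batches - num_devices * 2 + d + 1  # Step 4: nF0B1F1B0
    s5 = num_devices - d - 1                    # Step 5: nB1F1B0
    s6 = d + 1                                  # Step 6: nB1B0
    s7 = num_devices - d - 1                    # Step 7: nWB0
    f0 = s1 + s2 + s4       # F0 appears in steps 1, 2, 4
    f1 = s2 + s3 + s4 + s5  # F1 appears in steps 2, 3, 4, 5
    b1 = s3 + s4 + s5 + s6  # B1 appears in steps 3, 4, 5, 6
    b0 = s4 + s5 + s6 + s7  # B0 appears in steps 4, 5, 6, 7
    assert f0 == f1 == b1 == b0 == num_batches, (d, f0, f1, b1, b0)
print("op counts balance on every device")
```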