Add support for DualPipe.

Files changed:
- .gitignore              +1 -0
- README.md               +21 -6
- assets/dualpipe.png     +3 -0
- conf/config.yaml        +3 -0
- main.py                 +23 -0
- src/execution_model.py  +81 -19
- src/strategies.py       +227 -2
- src/visualizer.py       +2 -12
.gitignore  CHANGED

@@ -3,6 +3,7 @@
 uv.lock
 outputs/
 .cursor/*
+*.json
 
 # Uncomment below if you want to include these files
 # !assets/*.png
README.md  CHANGED

@@ -18,6 +18,7 @@ Pipeline parallelism is a technique used to train large models by partitioning t
 - Zero-Bubble 1F1B (ZB-1P)
 - 1F1B with computation-communication overlap
 - Interleaved 1F1B with computation-communication overlap
+- DualPipe (bidirectional pipeline parallelism with full forward-backward overlap)
 
 - **Visualization**:
   - Interactive visualization dashboard using Plotly/Dash

@@ -56,6 +57,12 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
 ```
 (schedule image)
 
+### Running for DualPipe strategy:
+```bash
+uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=20
+```
+(image: DualPipe schedule visualization)
+
 ### Running for 1F1B-batch-overlap strategy:
 ```bash
 uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8

@@ -68,10 +75,24 @@ uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=
 ```
 (schedule image)
 
 ## Configuration
 
 The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
 
+#### Override Specific Parameters
+
+You can override specific parameters at runtime:
+```bash
+uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
+```
+
+Using DualPipe as an example, you can manually set different times for forward/backward/backward_D/backward_W/overlapped_forward_backward:
+```bash
+uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=32 op_times.forward=1.0 op_times.backward=2.0 op_times.backward_D=1.0 op_times.backward_W=1.0 op_times.overlapped_forward_backward=2.5
+```
+
 ### Using Different Configuration Files
 
 You can use different configuration files with Hydra in several ways:

@@ -90,12 +111,6 @@ You can use different configuration files with Hydra in several ways:
 uv run python main.py --config-name=model_A
 ```
 
-#### Override Specific Parameters
-
-You can also override specific parameters at runtime:
-```bash
-uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
-```
 
 ## Project Structure
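The `op_times.*` overrides in the README above reach the scheduler as a plain dictionary: `run_dualpipe` in `main.py` converts the Hydra config with `OmegaConf.to_container` before passing it to `ScheduleConfig`. A minimal sketch of that conversion (not part of the commit; the values are illustrative):

```python
from omegaconf import OmegaConf

# Sketch only: mirrors what run_dualpipe() does with the Hydra config.
# Keys match the op_times section of conf/config.yaml; values are examples.
cfg = OmegaConf.create({
    "op_times": {
        "forward": 1.0,
        "backward": 2.0,
        "backward_D": 1.0,
        "backward_W": 1.0,
        "overlapped_forward_backward": 2.5,
    }
})
op_times = OmegaConf.to_container(cfg.op_times)  # plain dict handed to ScheduleConfig
print(op_times["overlapped_forward_backward"])   # 2.5
```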
assets/dualpipe.png  ADDED
(binary image, stored via Git LFS)
conf/config.yaml  CHANGED

@@ -11,6 +11,9 @@ op_times:
   # Option 1: Simple configuration (same time for all stages)
   forward: 1.0
   backward: 2.0
+  backward_D: 1.0
+  backward_W: 1.0
+  overlapped_forward_backward: 2.0
 
   # Option 2: Commented example of stage-specific configuration
   # forward:
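These entries are read back through `ScheduleConfig.get_op_time` (see the `src/execution_model.py` diff below), which accepts either a single float per operation type or a `{stage_id: time}` dict. A small sketch of that lookup rule, with a hypothetical `lookup` helper standing in for the real method:

```python
# Sketch only: how the flat op_times values from conf/config.yaml are resolved.
# A stage-specific override would replace a float with a {stage_id: time} dict.
op_times = {
    "forward": 1.0,
    "backward": 2.0,
    "backward_D": 1.0,
    "backward_W": 1.0,
    "overlapped_forward_backward": 2.0,
}

def lookup(op_type, stage_id):
    value = op_times[op_type]
    return value[stage_id] if isinstance(value, dict) else value

print(lookup("overlapped_forward_backward", 0))  # 2.0
```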
main.py  CHANGED

@@ -5,6 +5,7 @@ from src.strategies import (
     generate_1f1b_overlap_schedule,
     generate_1f1b_schedule,
     generate_zero_bubble_1p_schedule,
+    generate_dualpipe_schedule,
 )
 from src.visualizer import visualize_pipeline_parallelism_dash
 import hydra

@@ -26,6 +27,8 @@ def main(cfg: DictConfig) -> None:
         run_1f1b_overlap(cfg)
     elif cfg.strategy == "1f1b_interleave_overlap":
         run_1f1b_interleave_overlap(cfg)
+    elif cfg.strategy == "dualpipe":
+        run_dualpipe(cfg)
     else:
         raise ValueError(f"Unknown strategy: {cfg.strategy}")
 

@@ -129,5 +132,25 @@ def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
     schedule.execute()
     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
+def run_dualpipe(cfg: DictConfig) -> None:
+    """Run DualPipe pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=True,
+        placement_strategy="dualpipe",
+    )
+    schedule = generate_dualpipe_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
 if __name__ == "__main__":
     main()
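The new `run_dualpipe` entry point can also be mirrored without Hydra by constructing `ScheduleConfig` directly. A minimal sketch assuming the repository layout in this commit; the numeric values and port are illustrative:

```python
from src.execution_model import ScheduleConfig
from src.strategies import generate_dualpipe_schedule
from src.visualizer import visualize_pipeline_parallelism_dash

# Sketch only: mirrors run_dualpipe() from main.py with hard-coded values.
config = ScheduleConfig(
    num_devices=8,
    num_stages=8,
    num_batches=20,
    p2p_latency=0.0,
    op_times={
        "forward": 1.0,
        "backward": 2.0,
        "backward_D": 1.0,
        "backward_W": 1.0,
        "overlapped_forward_backward": 2.0,
    },
    split_backward=True,
    placement_strategy="dualpipe",
)
schedule = generate_dualpipe_schedule(config)
schedule.execute()
visualize_pipeline_parallelism_dash(schedule, port=8050)
```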
src/execution_model.py  CHANGED

@@ -69,7 +69,7 @@ class DeviceQueue:
     def add_operation(self, op: Operation):
         assert op.stage_id in self.stages
         self.ops.append(op)
-        assert op.device_id is None
+        assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
         op.device_id = self.device_id
 

@@ -97,6 +97,7 @@ class ScheduleConfig:
                 "forward": 1.0,
                 "backward_D": 1.0,
                 "backward_W": 1.0,
+                "backward": 2.0,
             }
         else:
             self.op_times = {

@@ -128,9 +129,14 @@ class ScheduleConfig:
         self.num_stages_per_device = num_stages // num_devices
 
         self.init_device_to_stages()
+        if self.placement_strategy == "dualpipe":
+            assert (
+                sum(len(stages) for stages in self.device_to_stages.values()) == num_stages * 2
+            )
+        else:
+            assert (
+                sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
+            )
 
     def init_device_to_stages(self):
         if self.placement_strategy == "standard":

@@ -145,14 +151,27 @@ class ScheduleConfig:
             for i in range(self.num_stages):
                 device_to_put = i % self.num_devices
                 self.device_to_stages[device_to_put].append(i)
+        elif self.placement_strategy == "dualpipe":
+            # For DualPipe, each device has two stages
+            assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
+            assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_stages):
+                self.device_to_stages[i] = [i, self.num_stages - i - 1]
         else:
             raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
 
     def get_op_time(self, op_type: str, stage_id: int):
         # For overlapped operations, extract the original operation types
         if op_type.startswith("overlapped_"):
-            if op_type in self.op_times
+            if op_type in self.op_times:
+                if isinstance(self.op_times[op_type], dict):
+                    if stage_id in self.op_times[op_type]:
+                        return self.op_times[op_type][stage_id]
+                    else:
+                        raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+                else:
+                    return self.op_times[op_type]
             else:
                 op_parts = op_type.split("_")[1:]
                 if len(op_parts) >= 2:

@@ -173,20 +192,25 @@
 
 
 class Schedule:
-    def __init__(self, config: ScheduleConfig):
+    def __init__(self, config: ScheduleConfig, init_ops: bool = True):
         self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
         self.device_queues: List[DeviceQueue] = []
         for dev_id in range(config.num_devices):
             self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
         self.config = config
 
+        if init_ops:
+            self.init_operations()
         self.op_to_overlapped = {}
 
     def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
         for op in overlapped_op.operations:
             self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
             self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
+
+    def register_operation(self, op: Operation):
+        assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
+        self.ops[(op.batch_id, op.stage_id, op.op_type)] = op
 
     def init_operations(self):
         op_types = ["forward", "backward"]

@@ -199,9 +223,12 @@
                 batch_id, stage_id, op_type
             )
 
-    def get_op(self, batch_id: int, stage_id: int, op_type: str):
+    def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
         if (batch_id, stage_id, op_type) in self.op_to_overlapped:
             return self.op_to_overlapped[(batch_id, stage_id, op_type)]
+        if allow_none:
+            if (batch_id, stage_id, op_type) not in self.ops:
+                return None
         return self.ops[(batch_id, stage_id, op_type)]
 
     def get_dependencies(self, op: Operation, include_device_dependency=True):

@@ -226,20 +253,55 @@
         if self.config.split_backward:
             if op.op_type == "backward_D":
                 if op.stage_id < self.config.num_stages - 1:
+                    op_bwd_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
+                    if op_bwd_d is not None:
+                        deps.append(
+                            (
+                                op_bwd_d,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                                self.config.p2p_latency,
+                            )
+                        )
             elif op.op_type == "backward_W":
                 if op.stage_id < self.config.num_stages - 1:
+                    op_bwd_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
+                    if op_bwd_d is not None:
+                        deps.append(
+                            (
+                                op_bwd_d,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id, "backward"),
+                                self.config.p2p_latency,
+                            )
+                        )
+            elif op.op_type == "backward":
+                if op.stage_id < self.config.num_stages - 1:
+                    op_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
+                    if op_bwd is not None:
+                        deps.append(
+                            (
+                                op_bwd,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
+                                self.config.p2p_latency,
+                            )
+                        )
         else:
             if op.op_type == "backward":
                 if op.stage_id < self.config.num_stages - 1:
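The new "dualpipe" placement pairs each device with a stage from each end of the pipeline. A standalone sketch of the resulting mapping, using a hypothetical `dualpipe_placement` helper that mirrors `init_device_to_stages`:

```python
from collections import defaultdict

def dualpipe_placement(num_stages):
    # Mirrors ScheduleConfig.init_device_to_stages for placement_strategy="dualpipe":
    # device i owns stage i (forward direction) and stage num_stages-1-i (reverse direction).
    device_to_stages = defaultdict(list)
    for i in range(num_stages):
        device_to_stages[i] = [i, num_stages - i - 1]
    return dict(device_to_stages)

print(dualpipe_placement(8))
# {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4], 4: [4, 3], 5: [5, 2], 6: [6, 1], 7: [7, 0]}
# Each device holds two stages, matching the num_stages * 2 assertion in ScheduleConfig.
```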
src/strategies.py  CHANGED

@@ -1,5 +1,5 @@
-from collections import defaultdict
-from src.execution_model import OverlappedOperation, Schedule, ScheduleConfig
+from collections import defaultdict, deque
+from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
 
 
 def generate_1f1b_schedule(config: ScheduleConfig):

@@ -43,6 +43,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
     schedule = Schedule(config)
     total_batches = config.num_batches
     assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
+    assert config.split_backward, "ZB-1P requires split_backward=True"
 
     for i in range(config.num_devices):
         fwd_batch_id = 0

@@ -354,3 +355,227 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
 
 
     return schedule
+
+
+def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
+    """
+    Helper function to create overlapped operations correctly.
+    This handles the underlying operation creation and registration to avoid device_id issues.
+    """
+    # Get the operations from the schedule
+    op1 = schedule.ops[(batch_id1, stage_id, type1)]
+    op2 = schedule.ops[(batch_id2, stage_id, type2)]
+
+    # Create the overlapped operation
+    overlapped_op = OverlappedOperation([op1, op2])
+
+    # Register in the schedule to ensure proper tracking
+    schedule.register_overlapped_operation(overlapped_op)
+
+    return overlapped_op
+
+
+def generate_dualpipe_schedule(config: ScheduleConfig):
+    """
+    Implements the DualPipe scheduling strategy.
+
+    DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
+    and backward computation-communication phases and reduces pipeline bubbles.
+
+    The DualPipe strategy has the following characteristics:
+    1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
+    2. Each device handles both a forward stage and a reverse stage
+    3. Overlaps forward and backward operations to reduce bubble size
+    4. Assumes config.num_batches corresponds to half the total microbatches (M) in the original paper
+    5. Currently only supports split_backward=True
+
+    Args:
+        config: The scheduling configuration
+
+    Returns:
+        A Schedule object with the DualPipe scheduling
+    """
+    # Ensure placement strategy is set for Schedule initialization
+    assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
+    # Assertions based on DualPipe requirements
+    assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
+    assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
+    assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
+    # Assertion based on original implementation: num_chunks >= num_ranks * 2
+    # Here, M (config.num_batches) corresponds to half_num_chunks
+    assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
+    assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
+
+    schedule = Schedule(config, init_ops=False)
+
+    num_stages = config.num_stages
+    num_devices = config.num_devices
+    # config.num_batches is M in the original paper, which corresponds to half_num_chunks
+    half_num_chunks = config.num_batches // 2
+    num_half_ranks = num_devices // 2
+
+    fwd_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+    bwd_d_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+
+    waited_weight_grad = [deque() for _ in range(num_devices)]  # (device_id, ) -> List[(stage_id, batch_id)]
+
+    for device_id in range(num_devices):
+        is_in_second_half = device_id >= num_half_ranks
+        if is_in_second_half:
+            fwd_batch_ids[device_id, 1] = 0
+            fwd_batch_ids[device_id, 0] = config.num_batches // 2
+            bwd_d_batch_ids[device_id, 1] = 0
+            bwd_d_batch_ids[device_id, 0] = config.num_batches // 2
+        else:
+            fwd_batch_ids[device_id, 0] = 0
+            fwd_batch_ids[device_id, 1] = config.num_batches // 2
+            bwd_d_batch_ids[device_id, 0] = 0
+            bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
+
+    def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
+        stage_fwd_dir = device_id  # Stage handled when moving forward (0 to N-1)
+        stage_rev_dir = num_stages - 1 - device_id  # Stage handled when moving backward (N-1 to 0)
+        if not is_in_second_half:
+            # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
+            return stage_fwd_dir if phase == 0 else stage_rev_dir
+        else:
+            # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
+            return stage_rev_dir if phase == 0 else stage_fwd_dir
+
+    def add_op_to_queue(device_id, stage_id, op_type, batch_id):
+        # Create and register the Operation object
+        op = Operation(batch_id, stage_id, op_type)
+        schedule.register_operation(op)
+        # Add to the device queue
+        schedule.device_queues[device_id].add_operation(op)
+
+    def _schedule_forward_chunk(device_id, phase, is_in_second_half):
+        """Schedules a forward compute operation."""
+        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
+        batch_id = fwd_batch_ids[device_id, phase]
+        add_op_to_queue(device_id, stage_id, "forward", batch_id)
+        fwd_batch_ids[device_id, phase] += 1
+
+    def _schedule_backward_chunk(device_id, phase, is_in_second_half):
+        """Schedules a backward_D with backward_W compute operation."""
+        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
+        batch_id = bwd_d_batch_ids[device_id, phase]
+        add_op_to_queue(device_id, stage_id, "backward", batch_id)
+        bwd_d_batch_ids[device_id, phase] += 1
+
+    def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
+        """Schedules a backward_D compute operation."""
+        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
+        batch_id = bwd_d_batch_ids[device_id, phase]
+        add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
+        bwd_d_batch_ids[device_id, phase] += 1
+        waited_weight_grad[device_id].append((stage_id, batch_id))
+
+    def _schedule_backward_weight_chunk(device_id):
+        """Schedules a backward_W compute operation."""
+        stage_id, batch_id = waited_weight_grad[device_id].popleft()
+        add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
+
+    def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
+        """Schedules an overlapped forward and backward_D compute operation."""
+        fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
+        bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
+
+        fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
+
+        fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
+        schedule.register_operation(fwd_op)
+        fwd_batch_ids[device_id, fwd_phase] += 1
+
+        bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
+        bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
+        schedule.register_operation(bwd_op)
+        bwd_d_batch_ids[device_id, bwd_phase] += 1
+
+        # Create and register the overlapped operation
+        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
+        schedule.register_overlapped_operation(overlapped_op)
+
+        # Add the overlapped operation to the queue
+        schedule.device_queues[device_id].add_operation(overlapped_op)
+
+    # Process each device (rank in original code)
+    for device_id in range(num_devices):
+        half_rank = min(device_id, num_devices - 1 - device_id)
+        is_in_second_half = device_id >= num_half_ranks
+        is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
+
+        # Map original steps to operation additions
+        # Step 1: nF0
+        step_1_count = (num_half_ranks - half_rank - 1) * 2
+        for _ in range(step_1_count):
+            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+
+        # Step 2: nF0F1
+        step_2_count = half_rank + 1
+        for i in range(step_2_count):
+            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1
+
+        # Step 3: nB1W1F1
+        step_3_count = num_half_ranks - half_rank - 1
+        for _ in range(step_3_count):
+            _schedule_backward_input_chunk(device_id, 1, is_in_second_half)  # B1_D
+            _schedule_backward_weight_chunk(device_id)  # W1
+            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1
+
+        # Step 4 (Main step): nF0B1F1B0
+        step_4_count = half_num_chunks - num_devices + half_rank + 1
+        for i in range(step_4_count):
+            # if i == 0 and is_middle_rank:
+            #     Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
+            #     _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+            #     _schedule_backward_chunk(device_id, 1, is_in_second_half)  # B1
+            #     _schedule_backward_weight_chunk(device_id, 1, is_in_second_half)  # W1
+            # else:
+            # Overlap F0 and B1_D, then schedule W1
+            _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half)  # F0+B1
+
+            # Overlap F1 and B0_D, then schedule W0
+            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1+B0
+
+        # Step 5: nB1F1B0
+        step_5_count = num_half_ranks - half_rank - 1
+        for _ in range(step_5_count):
+            _schedule_backward_chunk(device_id, 1, is_in_second_half)  # B1_D + B1_W
+            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1+B0
+
+        # Step 6: nB1B0
+        step_6_count = half_rank + 1
+        enable_zb = False
+        for i in range(step_6_count):
+            if i == step_6_count // 2 and half_rank % 2 == 1:
+                enable_zb = True
+            if enable_zb:
+                _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
+            else:
+                _schedule_backward_chunk(device_id, 1, is_in_second_half)
+            if i == step_6_count // 2 and half_rank % 2 == 0:
+                enable_zb = True
+            if enable_zb:
+                _schedule_backward_input_chunk(device_id, 0, is_in_second_half)
+            else:
+                _schedule_backward_chunk(device_id, 0, is_in_second_half)
+
+        # Step 7: nWB0
+        step_7_count = num_half_ranks - half_rank - 1
+        for _ in range(step_7_count):
+            _schedule_backward_weight_chunk(device_id)  # W1 (use gradient from B1_D scheduled previously)
+            _schedule_backward_input_chunk(device_id, 0, is_in_second_half)  # B0_D
+
+        # Step 8: nW
+        step_8_count = half_rank + 1
+        for _ in range(step_8_count):
+            # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
+            # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
+            # The last W0 gradients correspond to B0_D from step 6 or 7.
+            _schedule_backward_weight_chunk(device_id)  # W0 (use gradient from B0_D scheduled previously)
+
+    return schedule
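The eight steps in `generate_dualpipe_schedule` are sized only by a device's distance from the middle of the pipeline and by the microbatch count. A sketch with a hypothetical `dualpipe_step_counts` helper that reuses the same count formulas to show how the work is distributed, assuming 8 devices and 20 microbatches:

```python
def dualpipe_step_counts(num_devices, num_batches):
    # Same count formulas as generate_dualpipe_schedule; num_batches is M (half the chunks).
    half_num_chunks = num_batches // 2
    num_half_ranks = num_devices // 2
    counts = []
    for device_id in range(num_devices):
        half_rank = min(device_id, num_devices - 1 - device_id)
        counts.append({
            "device": device_id,
            "step1_nF0": (num_half_ranks - half_rank - 1) * 2,
            "step2_nF0F1": half_rank + 1,
            "step3_nB1W1F1": num_half_ranks - half_rank - 1,
            "step4_main": half_num_chunks - num_devices + half_rank + 1,
            "step5_nB1F1B0": num_half_ranks - half_rank - 1,
            "step6_nB1B0": half_rank + 1,
            "step7_nWB0": num_half_ranks - half_rank - 1,
            "step8_nW": half_rank + 1,
        })
    return counts

for row in dualpipe_step_counts(num_devices=8, num_batches=20):
    print(row)
```

Devices nearest the ends of the pipeline get the largest warm-up (step 1) and cool-down (step 7) counts, while the main overlapped step 4 grows toward the middle ranks.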
src/visualizer.py  CHANGED

@@ -89,11 +89,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
 
     # Improved teal/turquoise palette with low saturation and high brightness
     backward_d_colors = [
-        "#ccffff",  # Very light cyan
-        "#b3ffff",  # Pale cyan
-        "#99ffff",  # Light cyan
-        "#80ffff",  # Cyan
-        "#66e6e6",  # Soft teal
         "#4dcccc",  # Light teal
         "#33b3b3",  # Teal
         "#009999",  # Medium teal

@@ -102,12 +97,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
 
     # Improved green palette with low saturation and high brightness
     backward_w_colors = [
-        "#ccffe6",  # Very light mint
-        "#b3ffd9",  # Pale mint
-        "#99ffcc",  # Light mint
-        "#80ffbf",  # Mint green
-        "#66e6a6",  # Soft green
-        "#4dcc8c",  # Light green
         "#33b373",  # Medium green
         "#009959",  # Forest green
         "#008040",  # Dark green

@@ -162,7 +151,8 @@ def create_pipeline_figure(
         max_batch = max(max_batch, task["batch"])
 
     # Flag to determine whether to show text labels
+    num_operations_per_device = len(schedule_data[0])
+    show_text_labels = num_operations_per_device <= 64
 
     # Create a figure
     fig = go.Figure()