Spaces:
Running
Running
Add 1F1B-overlap implementation.
Browse files- .gitignore +1 -0
- README.md +7 -0
- assets/1f1b.png +2 -2
- assets/1f1b_overlap.png +3 -0
- main.py +43 -10
- src/strategies.py +36 -0
- src/visualizer.py +225 -165
.gitignore
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
./venv
|
3 |
uv.lock
|
4 |
outputs/
|
|
|
5 |
|
6 |
# Uncomment below if you want to include these files
|
7 |
# !assets/*.png
|
|
|
2 |
./venv
|
3 |
uv.lock
|
4 |
outputs/
|
5 |
+
.cursor/*
|
6 |
|
7 |
# Uncomment below if you want to include these files
|
8 |
# !assets/*.png
|
README.md
CHANGED
@@ -50,6 +50,13 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
|
|
50 |
```
|
51 |

|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
## Configuration
|
54 |
|
55 |
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
|
|
50 |
```
|
51 |

|
52 |
|
53 |
+
|
54 |
+
Running for 1F1B-batch-overlap strategy:
|
55 |
+
```bash
|
56 |
+
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
|
57 |
+
```
|
58 |
+

|
59 |
+
|
60 |
## Configuration
|
61 |
|
62 |
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
assets/1f1b.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
assets/1f1b_overlap.png
ADDED
![]() |
Git LFS Details
|
main.py
CHANGED
@@ -1,5 +1,10 @@
|
|
1 |
from src.execution_model import ScheduleConfig
|
2 |
-
from src.strategies import
|
|
|
|
|
|
|
|
|
|
|
3 |
from src.visualizer import visualize_pipeline_parallelism_dash
|
4 |
import hydra
|
5 |
from omegaconf import DictConfig, OmegaConf
|
@@ -16,6 +21,8 @@ def main(cfg: DictConfig) -> None:
|
|
16 |
run_interleave(cfg)
|
17 |
elif cfg.strategy == "zb1p":
|
18 |
run_zero_bubble_1p(cfg)
|
|
|
|
|
19 |
else:
|
20 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
21 |
|
@@ -23,7 +30,9 @@ def main(cfg: DictConfig) -> None:
|
|
23 |
def run_1f1b(cfg: DictConfig) -> None:
|
24 |
"""Run 1F1B pipeline parallelism simulation."""
|
25 |
# Convert OmegaConf to dict for op_times if it exists
|
26 |
-
op_times =
|
|
|
|
|
27 |
|
28 |
schedule_config = ScheduleConfig(
|
29 |
num_devices=cfg.num_devices,
|
@@ -31,7 +40,7 @@ def run_1f1b(cfg: DictConfig) -> None:
|
|
31 |
num_batches=cfg.num_batches,
|
32 |
p2p_latency=cfg.p2p_latency,
|
33 |
op_times=op_times,
|
34 |
-
placement_strategy="standard"
|
35 |
)
|
36 |
schedule = generate_1f1b_schedule(schedule_config)
|
37 |
schedule.execute()
|
@@ -42,15 +51,17 @@ def run_1f1b(cfg: DictConfig) -> None:
|
|
42 |
def run_interleave(cfg: DictConfig) -> None:
|
43 |
"""Run interleaved pipeline parallelism simulation."""
|
44 |
# Convert OmegaConf to dict for op_times if it exists
|
45 |
-
op_times =
|
46 |
-
|
|
|
|
|
47 |
schedule_config = ScheduleConfig(
|
48 |
num_devices=cfg.num_devices,
|
49 |
num_stages=cfg.num_stages,
|
50 |
num_batches=cfg.num_batches,
|
51 |
p2p_latency=cfg.p2p_latency,
|
52 |
placement_strategy="interleave",
|
53 |
-
op_times=op_times
|
54 |
)
|
55 |
schedule = generate_1f1b_interleave_schedule(schedule_config)
|
56 |
schedule.execute()
|
@@ -60,20 +71,42 @@ def run_interleave(cfg: DictConfig) -> None:
|
|
60 |
def run_zero_bubble_1p(cfg: DictConfig) -> None:
|
61 |
"""Run zero bubble 1P pipeline parallelism simulation."""
|
62 |
# Convert OmegaConf to dict for op_times if it exists
|
63 |
-
op_times =
|
64 |
-
|
|
|
|
|
65 |
schedule_config = ScheduleConfig(
|
66 |
num_devices=cfg.num_devices,
|
67 |
num_stages=cfg.num_stages,
|
68 |
num_batches=cfg.num_batches,
|
69 |
p2p_latency=cfg.p2p_latency,
|
70 |
op_times=op_times,
|
71 |
-
split_backward=True
|
72 |
)
|
73 |
schedule = generate_zero_bubble_1p_schedule(schedule_config)
|
74 |
schedule.execute()
|
75 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
if __name__ == "__main__":
|
79 |
-
main()
|
|
|
1 |
from src.execution_model import ScheduleConfig
|
2 |
+
from src.strategies import (
|
3 |
+
generate_1f1b_interleave_schedule,
|
4 |
+
generate_1f1b_overlap_schedule,
|
5 |
+
generate_1f1b_schedule,
|
6 |
+
generate_zero_bubble_1p_schedule,
|
7 |
+
)
|
8 |
from src.visualizer import visualize_pipeline_parallelism_dash
|
9 |
import hydra
|
10 |
from omegaconf import DictConfig, OmegaConf
|
|
|
21 |
run_interleave(cfg)
|
22 |
elif cfg.strategy == "zb1p":
|
23 |
run_zero_bubble_1p(cfg)
|
24 |
+
elif cfg.strategy == "1f1b_overlap":
|
25 |
+
run_1f1b_overlap(cfg)
|
26 |
else:
|
27 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
28 |
|
|
|
30 |
def run_1f1b(cfg: DictConfig) -> None:
|
31 |
"""Run 1F1B pipeline parallelism simulation."""
|
32 |
# Convert OmegaConf to dict for op_times if it exists
|
33 |
+
op_times = (
|
34 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
35 |
+
)
|
36 |
|
37 |
schedule_config = ScheduleConfig(
|
38 |
num_devices=cfg.num_devices,
|
|
|
40 |
num_batches=cfg.num_batches,
|
41 |
p2p_latency=cfg.p2p_latency,
|
42 |
op_times=op_times,
|
43 |
+
placement_strategy="standard",
|
44 |
)
|
45 |
schedule = generate_1f1b_schedule(schedule_config)
|
46 |
schedule.execute()
|
|
|
51 |
def run_interleave(cfg: DictConfig) -> None:
|
52 |
"""Run interleaved pipeline parallelism simulation."""
|
53 |
# Convert OmegaConf to dict for op_times if it exists
|
54 |
+
op_times = (
|
55 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
56 |
+
)
|
57 |
+
|
58 |
schedule_config = ScheduleConfig(
|
59 |
num_devices=cfg.num_devices,
|
60 |
num_stages=cfg.num_stages,
|
61 |
num_batches=cfg.num_batches,
|
62 |
p2p_latency=cfg.p2p_latency,
|
63 |
placement_strategy="interleave",
|
64 |
+
op_times=op_times,
|
65 |
)
|
66 |
schedule = generate_1f1b_interleave_schedule(schedule_config)
|
67 |
schedule.execute()
|
|
|
71 |
def run_zero_bubble_1p(cfg: DictConfig) -> None:
|
72 |
"""Run zero bubble 1P pipeline parallelism simulation."""
|
73 |
# Convert OmegaConf to dict for op_times if it exists
|
74 |
+
op_times = (
|
75 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
76 |
+
)
|
77 |
+
|
78 |
schedule_config = ScheduleConfig(
|
79 |
num_devices=cfg.num_devices,
|
80 |
num_stages=cfg.num_stages,
|
81 |
num_batches=cfg.num_batches,
|
82 |
p2p_latency=cfg.p2p_latency,
|
83 |
op_times=op_times,
|
84 |
+
split_backward=True,
|
85 |
)
|
86 |
schedule = generate_zero_bubble_1p_schedule(schedule_config)
|
87 |
schedule.execute()
|
88 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
89 |
|
90 |
|
91 |
+
def run_1f1b_overlap(cfg: DictConfig) -> None:
    """Run 1F1B overlap pipeline parallelism simulation."""
    # cfg.op_times is optional; convert the OmegaConf node to a plain
    # container when present so ScheduleConfig receives ordinary dicts.
    if hasattr(cfg, "op_times"):
        op_times = OmegaConf.to_container(cfg.op_times)
    else:
        op_times = None

    schedule_config = ScheduleConfig(
        num_devices=cfg.num_devices,
        num_stages=cfg.num_stages,
        num_batches=cfg.num_batches,
        p2p_latency=cfg.p2p_latency,
        op_times=op_times,
        split_backward=False,
    )
    schedule = generate_1f1b_overlap_schedule(schedule_config)
    schedule.execute()
    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
109 |
+
|
110 |
+
|
111 |
if __name__ == "__main__":
|
112 |
+
main()
|
src/strategies.py
CHANGED
@@ -94,6 +94,42 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
94 |
return schedule
|
95 |
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
# Some codes are copied from Megatron-LM
|
98 |
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
99 |
schedule = Schedule(config)
|
|
|
94 |
return schedule
|
95 |
|
96 |
|
97 |
+
def generate_1f1b_overlap_schedule(config: ScheduleConfig):
    """Build a 1F1B schedule whose steady phase overlaps forward and backward.

    Each device's queue is filled in three phases:
      1. warmup   — forward passes only (deeper devices need more of them),
      2. steady   — alternating one forward and one backward per iteration,
      3. cooldown — the remaining backward passes.

    Args:
        config: Schedule configuration; ``num_stages`` must equal
            ``num_devices`` (one stage per device).

    Returns:
        The populated ``Schedule`` (operations queued per device; call
        ``execute()`` on the result to assign start/end times).
    """
    schedule = Schedule(config)

    assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"

    for i in range(config.num_devices):
        fwd_batch_id = 0
        bwd_batch_id = 0
        # Device i (0-indexed from the first stage) needs 2*(D-i-1)+1 warmup
        # forwards; the last device starts with exactly one. Clamp to
        # num_batches so a small micro-batch count cannot make the warmup
        # loop request forward batch ids that do not exist (the original
        # code would enqueue out-of-range batches and leave steady_batches
        # negative).
        cooldown_batches = warmup_batches = min(
            2 * (config.num_devices - i - 1) + 1, config.num_batches
        )
        steady_batches = config.num_batches - warmup_batches

        for _ in range(warmup_batches):
            schedule.dev_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            fwd_batch_id += 1

        for _ in range(steady_batches):
            schedule.dev_queues[i].add_operation(
                schedule.get_op(fwd_batch_id, i, "forward")
            )
            fwd_batch_id += 1
            schedule.dev_queues[i].add_operation(
                schedule.get_op(bwd_batch_id, i, "backward")
            )
            bwd_batch_id += 1

        for _ in range(cooldown_batches):
            schedule.dev_queues[i].add_operation(
                schedule.get_op(bwd_batch_id, i, "backward")
            )
            bwd_batch_id += 1

    return schedule
|
131 |
+
|
132 |
+
|
133 |
# Some codes are copied from Megatron-LM
|
134 |
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
135 |
schedule = Schedule(config)
|
src/visualizer.py
CHANGED
@@ -12,30 +12,34 @@ from src.execution_model import Schedule
|
|
12 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
13 |
"""
|
14 |
Converts a Schedule object to the format needed for visualization.
|
15 |
-
|
16 |
Returns:
|
17 |
Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
|
18 |
"""
|
19 |
# Make sure all operations have start and end times
|
20 |
for op in schedule.ops.values():
|
21 |
if op.start_time is None or op.end_time is None:
|
22 |
-
raise ValueError(
|
23 |
-
|
|
|
|
|
24 |
visualization_data = {}
|
25 |
-
|
26 |
# Organize operations by device
|
27 |
for device_id, device_queue in enumerate(schedule.dev_queues):
|
28 |
visualization_data[device_id] = []
|
29 |
-
|
30 |
for op in device_queue.ops:
|
31 |
-
visualization_data[device_id].append(
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
39 |
return visualization_data
|
40 |
|
41 |
|
@@ -44,58 +48,58 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
44 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
45 |
# A more harmonious blue palette with better progression for forward operations
|
46 |
forward_colors = [
|
47 |
-
"#5c88f2",
|
48 |
-
"#1a53ff",
|
49 |
-
"#b3c6ff",
|
50 |
-
"#4d79ff",
|
51 |
-
"#809fff",
|
52 |
-
"#0039e6",
|
53 |
-
"#002db3",
|
54 |
-
"#264db3",
|
55 |
-
"#7094db",
|
56 |
-
"#99b3e6"
|
57 |
]
|
58 |
-
|
59 |
# Orange palette for backward operations
|
60 |
backward_colors = [
|
61 |
-
"#ff9933",
|
62 |
-
"#ffad5c",
|
63 |
-
"#ffc285",
|
64 |
-
"#ffd6ad",
|
65 |
-
"#ff8000",
|
66 |
-
"#cc6600",
|
67 |
-
"#ff9933",
|
68 |
-
"#ffb366",
|
69 |
-
"#cc9966",
|
70 |
-
"#ffd699"
|
71 |
]
|
72 |
-
|
73 |
# Improved teal/turquoise palette with better progression for backward_D operations
|
74 |
backward_d_colors = [
|
75 |
-
"#80ffff",
|
76 |
-
"#00cccc",
|
77 |
-
"#00e6e6",
|
78 |
-
"#33ffff",
|
79 |
-
"#00b3b3",
|
80 |
-
"#008080",
|
81 |
-
"#00e6cc",
|
82 |
-
"#4ddbbd",
|
83 |
-
"#80d4c8",
|
84 |
-
"#b3e6e0"
|
85 |
]
|
86 |
-
|
87 |
# Improved green palette with better progression for backward_W operations
|
88 |
backward_w_colors = [
|
89 |
-
"#00cc66",
|
90 |
-
"#00e673",
|
91 |
-
"#33ff99",
|
92 |
-
"#80ffbf",
|
93 |
-
"#009933",
|
94 |
-
"#006622",
|
95 |
-
"#33cc33",
|
96 |
-
"#66cc66",
|
97 |
-
"#99cc99",
|
98 |
-
"#c6e6c6"
|
99 |
]
|
100 |
|
101 |
virtual_stage = stage_id // num_devices
|
@@ -115,7 +119,9 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
115 |
raise ValueError(f"Invalid operation type: {op_type}")
|
116 |
|
117 |
|
118 |
-
def create_pipeline_figure(
|
|
|
|
|
119 |
"""
|
120 |
Create a Plotly figure for pipeline parallelism scheduling.
|
121 |
|
@@ -126,9 +132,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
126 |
"""
|
127 |
# Find the number of devices
|
128 |
num_devices = len(schedule_data)
|
129 |
-
|
130 |
empty_color = "whitesmoke"
|
131 |
-
|
132 |
# Find the maximum time in the schedule if not provided
|
133 |
if max_time is None:
|
134 |
max_time = 0
|
@@ -146,7 +152,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
146 |
tasks_processed = 0
|
147 |
|
148 |
if show_progress:
|
149 |
-
progress_bar = tqdm(
|
|
|
|
|
150 |
|
151 |
# Create a custom y-axis with no gaps between devices
|
152 |
y_spacing = 1.0 # Use 1.0 for no gaps
|
@@ -159,7 +167,7 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
159 |
# Add rectangles for each task
|
160 |
for device_idx, device in enumerate(schedule_data):
|
161 |
device_idx_reversed = num_devices - device_idx - 1
|
162 |
-
|
163 |
# Sort tasks by start time to ensure correct rendering
|
164 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
165 |
|
@@ -189,44 +197,50 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
189 |
# Add rectangle for the task
|
190 |
start_time = task["start_time"]
|
191 |
duration = task["duration"]
|
192 |
-
|
193 |
# Calculate y positions with no gaps
|
194 |
y_pos = device_idx_reversed * y_spacing
|
195 |
-
|
196 |
# Create rectangle using shape (batch-add later)
|
197 |
-
shapes.append(
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
208 |
# Add batch number text (batch-add later)
|
209 |
-
annotations.append(
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
217 |
# Prepare hover data (add traces in batches later)
|
218 |
hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
|
219 |
-
|
220 |
-
hover_traces.append(
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
|
|
230 |
# Update progress
|
231 |
if show_progress:
|
232 |
tasks_processed += 1
|
@@ -234,63 +248,83 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
234 |
|
235 |
# Add all shapes at once for better performance
|
236 |
fig.update_layout(shapes=shapes)
|
237 |
-
|
238 |
# Add all annotations at once
|
239 |
fig.update_layout(annotations=annotations)
|
240 |
-
|
241 |
# Add all hover traces at once
|
242 |
for trace in hover_traces:
|
243 |
fig.add_trace(go.Scatter(**trace))
|
244 |
|
245 |
# Add custom legend
|
246 |
legend_items = []
|
247 |
-
|
248 |
# Find the maximum virtual stage in the data
|
249 |
max_virtual_stage = 0
|
250 |
for device in schedule_data:
|
251 |
for task in schedule_data[device]:
|
252 |
virtual_stage = task["stage"] // num_devices
|
253 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
254 |
-
|
255 |
# Add forward and backward items for each virtual stage
|
256 |
for vs in range(max_virtual_stage + 1):
|
257 |
-
legend_items.append(
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
265 |
# Add entries for split backward operations if this is a zb1p schedule
|
266 |
-
if any(
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
legend_items.append(
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
# If no tasks found, add default legend items
|
277 |
if not legend_items:
|
278 |
legend_items = [
|
279 |
dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
|
280 |
dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
|
281 |
-
dict(
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
]
|
284 |
-
|
285 |
for i, item in enumerate(legend_items):
|
286 |
-
fig.add_trace(
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
294 |
if show_progress and i < len(legend_items) - 1:
|
295 |
progress_bar.update(1)
|
296 |
|
@@ -299,11 +333,15 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
299 |
# Modify the ordering to put Device 1 at the top, then Device 0, then the rest
|
300 |
if num_devices >= 2:
|
301 |
# Move Device 1 to the top, followed by Device 0
|
302 |
-
device_labels =
|
303 |
-
|
|
|
|
|
|
|
|
|
304 |
# Calculate tick positions with no gaps
|
305 |
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
|
306 |
-
|
307 |
# Adjust the range to ensure there are no empty spaces at the end
|
308 |
x_end = max_time * 1.05 # Add a small margin
|
309 |
|
@@ -323,17 +361,17 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
323 |
text=title_text,
|
324 |
x=0.5,
|
325 |
y=0.98, # Move title position closer to the top
|
326 |
-
font=dict(size=20)
|
327 |
),
|
328 |
legend=dict(
|
329 |
orientation="v", # Changed from horizontal to vertical
|
330 |
yanchor="top",
|
331 |
y=1.02, # Position at the top
|
332 |
xanchor="right",
|
333 |
-
x=1.20,
|
334 |
title=dict(text="<b>Operation Types:</b>"),
|
335 |
itemsizing="constant",
|
336 |
-
tracegroupgap=0
|
337 |
),
|
338 |
width=2000, # Increase width to accommodate the expanded legend
|
339 |
height=400, # Maintain current height
|
@@ -351,10 +389,13 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
351 |
# Cache for storing processed schedule data
|
352 |
_schedule_data_cache = {}
|
353 |
|
354 |
-
|
|
|
|
|
|
|
355 |
"""
|
356 |
Create a Dash app to visualize the pipeline schedule.
|
357 |
-
|
358 |
Args:
|
359 |
schedule: Schedule object to visualize
|
360 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
@@ -363,7 +404,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
363 |
# Process schedule data only once and cache it
|
364 |
global _schedule_data_cache
|
365 |
cache_key = id(schedule)
|
366 |
-
|
367 |
if enable_caching and cache_key in _schedule_data_cache:
|
368 |
schedule_data = _schedule_data_cache[cache_key]
|
369 |
print("Using cached schedule data")
|
@@ -372,7 +413,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
372 |
if enable_caching:
|
373 |
_schedule_data_cache[cache_key] = schedule_data
|
374 |
print("Cached schedule data")
|
375 |
-
|
376 |
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
|
377 |
print(f"Total tasks in schedule: {total_tasks}")
|
378 |
|
@@ -380,31 +421,48 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
380 |
app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
|
381 |
|
382 |
# Create a more informative layout with data size information
|
383 |
-
app.layout = html.Div(
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
# Cache for storing figure to avoid regenerating it
|
406 |
figure_cache = {}
|
407 |
-
|
408 |
@app.callback(
|
409 |
Output("pipeline-graph", "figure"),
|
410 |
Input("graph-container", "children"),
|
@@ -416,15 +474,15 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
416 |
if enable_caching and cache_key in figure_cache:
|
417 |
print("Using cached figure")
|
418 |
return figure_cache[cache_key]
|
419 |
-
|
420 |
# Create the figure
|
421 |
figure = create_pipeline_figure(schedule_data, show_progress=True)
|
422 |
-
|
423 |
# Cache the figure
|
424 |
if enable_caching:
|
425 |
figure_cache[cache_key] = figure
|
426 |
print("Cached figure")
|
427 |
-
|
428 |
return figure
|
429 |
|
430 |
return app
|
@@ -435,11 +493,11 @@ def visualize_pipeline_parallelism_dash(
|
|
435 |
port: int = 8050,
|
436 |
debug: bool = False,
|
437 |
enable_caching: bool = True,
|
438 |
-
schedule_type="1f1b"
|
439 |
):
|
440 |
"""
|
441 |
Launch a Dash app to visualize the pipeline schedule interactively.
|
442 |
-
|
443 |
Args:
|
444 |
schedule: Schedule object to visualize
|
445 |
port: Port to run the Dash app on
|
@@ -447,6 +505,8 @@ def visualize_pipeline_parallelism_dash(
|
|
447 |
enable_caching: Whether to cache schedule data and figures
|
448 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
449 |
"""
|
450 |
-
app = create_dash_app(
|
|
|
|
|
451 |
print(f"Starting Dash app on http://localhost:{port}/")
|
452 |
app.run_server(debug=debug, port=port)
|
|
|
12 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
13 |
"""
|
14 |
Converts a Schedule object to the format needed for visualization.
|
15 |
+
|
16 |
Returns:
|
17 |
Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
|
18 |
"""
|
19 |
# Make sure all operations have start and end times
|
20 |
for op in schedule.ops.values():
|
21 |
if op.start_time is None or op.end_time is None:
|
22 |
+
raise ValueError(
|
23 |
+
"Operations must have start and end times. Run ScheduleExecutor.execute() first."
|
24 |
+
)
|
25 |
+
|
26 |
visualization_data = {}
|
27 |
+
|
28 |
# Organize operations by device
|
29 |
for device_id, device_queue in enumerate(schedule.dev_queues):
|
30 |
visualization_data[device_id] = []
|
31 |
+
|
32 |
for op in device_queue.ops:
|
33 |
+
visualization_data[device_id].append(
|
34 |
+
{
|
35 |
+
"type": op.op_type,
|
36 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
37 |
+
"stage": op.stage_id,
|
38 |
+
"start_time": op.start_time,
|
39 |
+
"duration": op.end_time - op.start_time,
|
40 |
+
}
|
41 |
+
)
|
42 |
+
|
43 |
return visualization_data
|
44 |
|
45 |
|
|
|
48 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
49 |
# A more harmonious blue palette with better progression for forward operations
|
50 |
forward_colors = [
|
51 |
+
"#5c88f2", # Periwinkle blue
|
52 |
+
"#1a53ff", # Deep blue
|
53 |
+
"#b3c6ff", # Light blue
|
54 |
+
"#4d79ff", # Strong blue
|
55 |
+
"#809fff", # Medium blue
|
56 |
+
"#0039e6", # Rich navy
|
57 |
+
"#002db3", # Dark navy
|
58 |
+
"#264db3", # Royal blue
|
59 |
+
"#7094db", # Steel blue
|
60 |
+
"#99b3e6", # Pale blue
|
61 |
]
|
62 |
+
|
63 |
# Orange palette for backward operations
|
64 |
backward_colors = [
|
65 |
+
"#ff9933", # Bright orange
|
66 |
+
"#ffad5c", # Medium orange
|
67 |
+
"#ffc285", # Light orange
|
68 |
+
"#ffd6ad", # Pale orange
|
69 |
+
"#ff8000", # Deep orange
|
70 |
+
"#cc6600", # Dark orange
|
71 |
+
"#ff9933", # Vivid orange
|
72 |
+
"#ffb366", # Soft orange
|
73 |
+
"#cc9966", # Muted orange
|
74 |
+
"#ffd699", # Light amber
|
75 |
]
|
76 |
+
|
77 |
# Improved teal/turquoise palette with better progression for backward_D operations
|
78 |
backward_d_colors = [
|
79 |
+
"#80ffff", # Light cyan
|
80 |
+
"#00cccc", # Teal
|
81 |
+
"#00e6e6", # Bright teal
|
82 |
+
"#33ffff", # Cyan
|
83 |
+
"#00b3b3", # Medium teal
|
84 |
+
"#008080", # Dark teal
|
85 |
+
"#00e6cc", # Turquoise
|
86 |
+
"#4ddbbd", # Aqua
|
87 |
+
"#80d4c8", # Pale teal
|
88 |
+
"#b3e6e0", # Ice
|
89 |
]
|
90 |
+
|
91 |
# Improved green palette with better progression for backward_W operations
|
92 |
backward_w_colors = [
|
93 |
+
"#00cc66", # Medium green
|
94 |
+
"#00e673", # Bright green
|
95 |
+
"#33ff99", # Mint green
|
96 |
+
"#80ffbf", # Light green
|
97 |
+
"#009933", # Forest green
|
98 |
+
"#006622", # Dark green
|
99 |
+
"#33cc33", # True green
|
100 |
+
"#66cc66", # Sage green
|
101 |
+
"#99cc99", # Pale green
|
102 |
+
"#c6e6c6", # Pastel green
|
103 |
]
|
104 |
|
105 |
virtual_stage = stage_id // num_devices
|
|
|
119 |
raise ValueError(f"Invalid operation type: {op_type}")
|
120 |
|
121 |
|
122 |
+
def create_pipeline_figure(
|
123 |
+
schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True
|
124 |
+
):
|
125 |
"""
|
126 |
Create a Plotly figure for pipeline parallelism scheduling.
|
127 |
|
|
|
132 |
"""
|
133 |
# Find the number of devices
|
134 |
num_devices = len(schedule_data)
|
135 |
+
|
136 |
empty_color = "whitesmoke"
|
137 |
+
|
138 |
# Find the maximum time in the schedule if not provided
|
139 |
if max_time is None:
|
140 |
max_time = 0
|
|
|
152 |
tasks_processed = 0
|
153 |
|
154 |
if show_progress:
|
155 |
+
progress_bar = tqdm(
|
156 |
+
total=total_tasks + num_devices + 3, desc="Creating visualization"
|
157 |
+
)
|
158 |
|
159 |
# Create a custom y-axis with no gaps between devices
|
160 |
y_spacing = 1.0 # Use 1.0 for no gaps
|
|
|
167 |
# Add rectangles for each task
|
168 |
for device_idx, device in enumerate(schedule_data):
|
169 |
device_idx_reversed = num_devices - device_idx - 1
|
170 |
+
|
171 |
# Sort tasks by start time to ensure correct rendering
|
172 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
173 |
|
|
|
197 |
# Add rectangle for the task
|
198 |
start_time = task["start_time"]
|
199 |
duration = task["duration"]
|
200 |
+
|
201 |
# Calculate y positions with no gaps
|
202 |
y_pos = device_idx_reversed * y_spacing
|
203 |
+
|
204 |
# Create rectangle using shape (batch-add later)
|
205 |
+
shapes.append(
|
206 |
+
dict(
|
207 |
+
type="rect",
|
208 |
+
x0=start_time,
|
209 |
+
y0=y_pos - 0.5,
|
210 |
+
x1=start_time + duration,
|
211 |
+
y1=y_pos + 0.5,
|
212 |
+
line=dict(color="black", width=0.5),
|
213 |
+
fillcolor=color,
|
214 |
+
layer="above",
|
215 |
+
)
|
216 |
+
)
|
217 |
+
|
218 |
# Add batch number text (batch-add later)
|
219 |
+
annotations.append(
|
220 |
+
dict(
|
221 |
+
x=start_time + duration / 2,
|
222 |
+
y=y_pos,
|
223 |
+
text=f"{task['batch']}",
|
224 |
+
showarrow=False,
|
225 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
226 |
+
)
|
227 |
+
)
|
228 |
+
|
229 |
# Prepare hover data (add traces in batches later)
|
230 |
hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
|
231 |
+
|
232 |
+
hover_traces.append(
|
233 |
+
dict(
|
234 |
+
x=[start_time + duration / 2],
|
235 |
+
y=[y_pos],
|
236 |
+
mode="markers",
|
237 |
+
marker=dict(opacity=0), # Invisible marker
|
238 |
+
hoverinfo="text",
|
239 |
+
text=hover_text,
|
240 |
+
showlegend=False,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
|
244 |
# Update progress
|
245 |
if show_progress:
|
246 |
tasks_processed += 1
|
|
|
248 |
|
249 |
# Add all shapes at once for better performance
|
250 |
fig.update_layout(shapes=shapes)
|
251 |
+
|
252 |
# Add all annotations at once
|
253 |
fig.update_layout(annotations=annotations)
|
254 |
+
|
255 |
# Add all hover traces at once
|
256 |
for trace in hover_traces:
|
257 |
fig.add_trace(go.Scatter(**trace))
|
258 |
|
259 |
# Add custom legend
|
260 |
legend_items = []
|
261 |
+
|
262 |
# Find the maximum virtual stage in the data
|
263 |
max_virtual_stage = 0
|
264 |
for device in schedule_data:
|
265 |
for task in schedule_data[device]:
|
266 |
virtual_stage = task["stage"] // num_devices
|
267 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
268 |
+
|
269 |
# Add forward and backward items for each virtual stage
|
270 |
for vs in range(max_virtual_stage + 1):
|
271 |
+
legend_items.append(
|
272 |
+
dict(
|
273 |
+
name=f"Forward (VS {vs})",
|
274 |
+
color=get_color("forward", vs * num_devices, num_devices),
|
275 |
+
)
|
276 |
+
)
|
277 |
+
legend_items.append(
|
278 |
+
dict(
|
279 |
+
name=f"Backward (VS {vs})",
|
280 |
+
color=get_color("backward", vs * num_devices, num_devices),
|
281 |
+
)
|
282 |
+
)
|
283 |
# Add entries for split backward operations if this is a zb1p schedule
|
284 |
+
if any(
|
285 |
+
task["type"] in ["backward_D", "backward_W"]
|
286 |
+
for device in schedule_data
|
287 |
+
for task in schedule_data[device]
|
288 |
+
):
|
289 |
+
legend_items.append(
|
290 |
+
dict(
|
291 |
+
name=f"Backward Grad (VS {vs})",
|
292 |
+
color=get_color("backward_D", vs * num_devices, num_devices),
|
293 |
+
)
|
294 |
+
)
|
295 |
+
legend_items.append(
|
296 |
+
dict(
|
297 |
+
name=f"Backward Weight (VS {vs})",
|
298 |
+
color=get_color("backward_W", vs * num_devices, num_devices),
|
299 |
+
)
|
300 |
+
)
|
301 |
+
|
302 |
# If no tasks found, add default legend items
|
303 |
if not legend_items:
|
304 |
legend_items = [
|
305 |
dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
|
306 |
dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
|
307 |
+
dict(
|
308 |
+
name="Backward Grad (VS 0)",
|
309 |
+
color=get_color("backward_D", 0, num_devices),
|
310 |
+
),
|
311 |
+
dict(
|
312 |
+
name="Backward Weight (VS 0)",
|
313 |
+
color=get_color("backward_W", 0, num_devices),
|
314 |
+
),
|
315 |
]
|
316 |
+
|
317 |
for i, item in enumerate(legend_items):
|
318 |
+
fig.add_trace(
|
319 |
+
go.Scatter(
|
320 |
+
x=[None],
|
321 |
+
y=[None],
|
322 |
+
mode="markers",
|
323 |
+
marker=dict(size=10, color=item["color"]),
|
324 |
+
name=item["name"],
|
325 |
+
showlegend=True,
|
326 |
+
)
|
327 |
+
)
|
328 |
if show_progress and i < len(legend_items) - 1:
|
329 |
progress_bar.update(1)
|
330 |
|
|
|
333 |
# Modify the ordering to put Device 1 at the top, then Device 0, then the rest
|
334 |
if num_devices >= 2:
|
335 |
# Move Device 1 to the top, followed by Device 0
|
336 |
+
device_labels = (
|
337 |
+
[device_labels[1], device_labels[0]] + device_labels[2:]
|
338 |
+
if num_devices > 1
|
339 |
+
else device_labels
|
340 |
+
)
|
341 |
+
|
342 |
# Calculate tick positions with no gaps
|
343 |
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
|
344 |
+
|
345 |
# Adjust the range to ensure there are no empty spaces at the end
|
346 |
x_end = max_time * 1.05 # Add a small margin
|
347 |
|
|
|
361 |
text=title_text,
|
362 |
x=0.5,
|
363 |
y=0.98, # Move title position closer to the top
|
364 |
+
font=dict(size=20),
|
365 |
),
|
366 |
legend=dict(
|
367 |
orientation="v", # Changed from horizontal to vertical
|
368 |
yanchor="top",
|
369 |
y=1.02, # Position at the top
|
370 |
xanchor="right",
|
371 |
+
x=1.20, # Position further to the right to accommodate more items
|
372 |
title=dict(text="<b>Operation Types:</b>"),
|
373 |
itemsizing="constant",
|
374 |
+
tracegroupgap=0,
|
375 |
),
|
376 |
width=2000, # Increase width to accommodate the expanded legend
|
377 |
height=400, # Maintain current height
|
|
|
389 |
# Cache for storing processed schedule data
|
390 |
_schedule_data_cache = {}
|
391 |
|
392 |
+
|
393 |
+
def create_dash_app(
|
394 |
+
schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True
|
395 |
+
):
|
396 |
"""
|
397 |
Create a Dash app to visualize the pipeline schedule.
|
398 |
+
|
399 |
Args:
|
400 |
schedule: Schedule object to visualize
|
401 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
|
|
404 |
# Process schedule data only once and cache it
|
405 |
global _schedule_data_cache
|
406 |
cache_key = id(schedule)
|
407 |
+
|
408 |
if enable_caching and cache_key in _schedule_data_cache:
|
409 |
schedule_data = _schedule_data_cache[cache_key]
|
410 |
print("Using cached schedule data")
|
|
|
413 |
if enable_caching:
|
414 |
_schedule_data_cache[cache_key] = schedule_data
|
415 |
print("Cached schedule data")
|
416 |
+
|
417 |
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
|
418 |
print(f"Total tasks in schedule: {total_tasks}")
|
419 |
|
|
|
421 |
app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
|
422 |
|
423 |
# Create a more informative layout with data size information
|
424 |
+
app.layout = html.Div(
|
425 |
+
[
|
426 |
+
html.H1(
|
427 |
+
f"Pipeline Parallelism Visualization - {schedule_type}",
|
428 |
+
style={"textAlign": "center"},
|
429 |
+
),
|
430 |
+
html.Div(
|
431 |
+
[
|
432 |
+
html.P(
|
433 |
+
f"Number of devices: {len(schedule_data)}",
|
434 |
+
style={"display": "inline-block", "marginRight": "20px"},
|
435 |
+
),
|
436 |
+
html.P(
|
437 |
+
f"Total tasks: {total_tasks}",
|
438 |
+
style={"display": "inline-block", "marginRight": "20px"},
|
439 |
+
),
|
440 |
+
],
|
441 |
+
style={"marginBottom": "20px"},
|
442 |
+
),
|
443 |
+
html.Div(id="graph-container", children=[]),
|
444 |
+
dcc.Loading(
|
445 |
+
id="loading-graph",
|
446 |
+
type="circle",
|
447 |
+
children=[
|
448 |
+
dcc.Graph(
|
449 |
+
id="pipeline-graph",
|
450 |
+
config={
|
451 |
+
"displayModeBar": True,
|
452 |
+
"toImageButtonOptions": {
|
453 |
+
"format": "png",
|
454 |
+
"filename": "pipeline_visualization",
|
455 |
+
},
|
456 |
+
},
|
457 |
+
),
|
458 |
+
],
|
459 |
+
),
|
460 |
+
]
|
461 |
+
)
|
462 |
+
|
463 |
# Cache for storing figure to avoid regenerating it
|
464 |
figure_cache = {}
|
465 |
+
|
466 |
@app.callback(
|
467 |
Output("pipeline-graph", "figure"),
|
468 |
Input("graph-container", "children"),
|
|
|
474 |
if enable_caching and cache_key in figure_cache:
|
475 |
print("Using cached figure")
|
476 |
return figure_cache[cache_key]
|
477 |
+
|
478 |
# Create the figure
|
479 |
figure = create_pipeline_figure(schedule_data, show_progress=True)
|
480 |
+
|
481 |
# Cache the figure
|
482 |
if enable_caching:
|
483 |
figure_cache[cache_key] = figure
|
484 |
print("Cached figure")
|
485 |
+
|
486 |
return figure
|
487 |
|
488 |
return app
|
|
|
493 |
port: int = 8050,
|
494 |
debug: bool = False,
|
495 |
enable_caching: bool = True,
|
496 |
+
schedule_type="1f1b",
|
497 |
):
|
498 |
"""
|
499 |
Launch a Dash app to visualize the pipeline schedule interactively.
|
500 |
+
|
501 |
Args:
|
502 |
schedule: Schedule object to visualize
|
503 |
port: Port to run the Dash app on
|
|
|
505 |
enable_caching: Whether to cache schedule data and figures
|
506 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
507 |
"""
|
508 |
+
app = create_dash_app(
|
509 |
+
schedule, schedule_type=schedule_type, enable_caching=enable_caching
|
510 |
+
)
|
511 |
print(f"Starting Dash app on http://localhost:{port}/")
|
512 |
app.run_server(debug=debug, port=port)
|