Victarry committed on
Commit ec19476 · 0 Parent(s)

Initial commit: 1F1B PP schedule visualization.

Files changed (6)
  1. .gitattributes +2 -0
  2. .gitignore +78 -0
  3. README.md +77 -0
  4. configs/standard.json +8 -0
  5. pipeline.py +477 -0
  6. visualizer.py +97 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ assets/*.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,78 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ env/
+ ENV/
+ .env
+
+ # IDE specific files
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ .DS_Store
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Pipeline visualization outputs
+ *.png
+ *.jpg
+ *.jpeg
+ *.pdf
+ *.svg
+
+ # Local configuration
+ config.ini
+ secrets.json
README.md ADDED
@@ -0,0 +1,77 @@
+ # Pipeline Parallelism Scheduler and Visualizer
+
+ This tool simulates and visualizes pipeline parallelism scheduling strategies, focusing on the 1F1B (One-Forward-One-Backward) scheduling algorithm commonly used in distributed deep learning.
+
+ ## Usage
+
+ ### Example Output
+
+ ```bash
+ python pipeline.py --num-stages 4 --num-batches 8
+ ```
+ ![Example 1F1B schedule](pipeline_1f1b.png)
+
+ ### Command Line Interface
+
+ | Option | Short | Description |
+ |--------|-------|-------------|
+ | `--config` | `-c` | Path to config file (JSON or YAML) |
+ | `--num-stages` | `-s` | Number of pipeline stages (devices) |
+ | `--num-batches` | `-b` | Number of micro-batches |
+ | `--forward-times` | `-f` | Time for forward pass at each stage (space-separated list) |
+ | `--backward-times` | `-bw` | Time for backward pass at each stage (space-separated list) |
+ | `--output` | `-o` | Output file path for visualization |
+ | `--no-visualization` | | Skip visualization generation |
+ | `--p2p-time` | | Point-to-point communication time between pipeline stages |
+
+ ### Using Configuration Files
+
+ You can use either JSON or YAML configuration files:
+
+ Example JSON configuration (sample_config.json):
+ ```json
+ {
+     "num_stages": 6,
+     "num_batches": 12,
+     "forward_times": [0.8, 1.0, 1.2, 1.0, 0.9, 1.1],
+     "backward_times": [1.6, 2.0, 2.4, 2.0, 1.8, 2.2],
+     "output_file": "pipeline_1f1b_custom.png"
+ }
+ ```
+
+ Example YAML configuration (sample_config.yaml):
+ ```yaml
+ # Pipeline Parallelism Configuration
+ num_stages: 5
+ num_batches: 8
+ forward_times:
+   - 0.9
+   - 1.1
+   - 1.0
+   - 0.8
+   - 1.2
+ backward_times:
+   - 1.8
+   - 2.2
+   - 2.0
+   - 1.6
+   - 2.4
+ output_file: "pipeline_1f1b_yaml.png"
+ ```
+
+ ## About Pipeline Parallelism
+
+ Pipeline parallelism is a distributed deep learning training strategy that splits model layers across multiple devices. Each device processes a different stage of the neural network, creating a pipeline where multiple micro-batches can be processed simultaneously.
+
+ The 1F1B (One-Forward-One-Backward) scheduling algorithm is an efficient strategy for pipeline parallelism that balances throughput with memory usage. It follows these phases:
+ 1. **Warmup Phase**: Forward passes for the first several micro-batches
+ 2. **Steady State**: Each device alternates between forward and backward passes
+ 3. **Cooldown Phase**: Backward passes to complete the computation for remaining micro-batches
+
+ The "bubble rate" metric measures the inefficiency in the pipeline, representing the percentage of time devices spend idle waiting for dependencies.
+
+ ## References
+
+ - PipeDream: Generalized Pipeline Parallelism for DNN Training (SOSP'19)
+ - GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (NeurIPS'19)
+ - Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
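
For the uniform-time case (equal per-stage forward times, equal backward times, zero P2P cost), the bubble rate described in the README has a well-known closed form, (p − 1)/m for p stages and m micro-batches. A minimal sketch, independent of the simulator:

```python
def analytic_1f1b_bubble_rate(num_stages: int, num_batches: int) -> float:
    """Idle fraction of an ideal 1F1B schedule with uniform stage times.

    With equal per-stage times t_f and t_b and zero P2P cost, the makespan
    is (m + p - 1) * (t_f + t_b) for p stages and m micro-batches, so
    idle time / compute time simplifies to (p - 1) / m.
    """
    return (num_stages - 1) / num_batches

# The README example (4 stages, 8 micro-batches) gives 3/8
print(analytic_1f1b_bubble_rate(4, 8))  # 0.375
```

This matches the simulator's measured bubble rate whenever all forward times are equal, all backward times are equal, and `p2p_time` is zero; with heterogeneous times the simulated value is the one to trust.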
configs/standard.json ADDED
@@ -0,0 +1,8 @@
+ {
+     "num_stages": 4,
+     "num_batches": 8,
+     "forward_times": [1.0, 1.0, 1.0, 1.0],
+     "backward_times": [2.0, 2.0, 2.0, 2.0],
+     "output_file": "pipeline_1f1b.png",
+     "p2p_time": 0.0
+ }
pipeline.py ADDED
@@ -0,0 +1,477 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import argparse
+ import json
+ import yaml
+ import os
+ from matplotlib.patches import Rectangle
+ from typing import List, Tuple, Dict, Literal
+
+ # Import visualization function from the new module
+ from visualizer import visualize_pipeline_parallelism
+
+
+ def create_1f1b_schedule(
+     num_stages: int,
+     num_batches: int,
+     forward_times: List[float],
+     backward_times: List[float],
+     p2p_time: float = 0.0,
+ ) -> Dict[int, List[Dict]]:
+     """
+     Create a 1F1B (One-Forward-One-Backward) schedule for pipeline parallelism.
+
+     This implementation takes a data-centric approach:
+     1. First determine the operation sequence for each pipeline stage (which microbatch to process when)
+     2. Then calculate timing based on dependencies between operations
+
+     The 1F1B pattern has three phases:
+     - Warmup: Forward passes for first num_stages microbatches
+     - Steady state: Alternating between forward and backward passes
+     - Cooldown: Backward passes for remaining microbatches
+
+     Returns:
+         A dictionary mapping device IDs to lists of tasks.
+         Each task is a dictionary with keys:
+         - 'type': 'forward' or 'backward'
+         - 'batch': batch number
+         - 'start_time': start time of the task
+         - 'duration': duration of the task
+     """
+     # Initialize empty schedule
+     schedule = {stage: [] for stage in range(num_stages)}
+
+     # Step 1: Determine operation sequence for each stage
+     # This will generate the sequence of operations (forward/backward on which microbatch)
+     # that each stage should perform, without timing information yet
+     operation_sequence = determine_1f1b_operation_sequence(num_stages, num_batches)
+
+     # Step 2: Convert operation sequence to schedule with timing
+     # Taking into account dependencies between operations
+     schedule = calculate_operation_timing(
+         operation_sequence, num_stages, forward_times, backward_times, p2p_time
+     )
+
+     return schedule
+
+
+ def determine_1f1b_operation_sequence(
+     num_stages: int, num_batches: int
+ ) -> Dict[int, List[Dict]]:
+     """
+     Determine the sequence of operations (forward/backward) for each stage in 1F1B scheduling.
+
+     Args:
+         num_stages: Number of pipeline stages
+         num_batches: Number of micro-batches
+
+     Returns:
+         Dictionary mapping stage ID to a list of operations in sequence.
+         Each operation is a dict with keys 'type' ('forward' or 'backward') and 'batch'.
+     """
+     operation_sequence = {i: [] for i in range(num_stages)}
+     for current_stage in range(num_stages):
+         warmup_batches = num_stages - current_stage
+         for j in range(1, warmup_batches + 1):
+             operation_sequence[current_stage].append({"type": "forward", "batch": j})
+         steady_batches = num_batches - warmup_batches
+         for j in range(warmup_batches + 1, warmup_batches + steady_batches + 1):
+             operation_sequence[current_stage].append(
+                 {"type": "backward", "batch": j - warmup_batches}
+             )
+             operation_sequence[current_stage].append({"type": "forward", "batch": j})
+         for j in range(warmup_batches):
+             operation_sequence[current_stage].append(
+                 {"type": "backward", "batch": j + steady_batches + 1}
+             )
+
+     return operation_sequence
+
+
+ def calculate_operation_timing(
+     operation_sequence: Dict[int, List[Dict]],
+     num_stages: int,
+     forward_times: List[float],
+     backward_times: List[float],
+     p2p_time: float = 0.0,
+ ) -> Dict[int, List[Dict]]:
+     """
+     Recursively calculate the specific timing of each operation in a 1F1B schedule.
+
+     When encountering an operation that depends on a previous operation that hasn't been calculated yet,
+     it will recursively calculate the timing of those operations.
+
+     Args:
+         operation_sequence: Operation sequence for each stage
+         num_stages: Number of pipeline stages
+         forward_times: Forward propagation time for each stage
+         backward_times: Backward propagation time for each stage
+         p2p_time: Point-to-point communication time between stages
+
+     Returns:
+         Complete schedule with timing information, each operation includes start_time and duration
+     """
+     # Initialize schedule with timing information
+     schedule = {i: [] for i in range(num_stages)}
+
+     # For recording already computed operation end times
+     # Format: {(stage, batch, op_type): (start_time, end_time)}
+     computed_ops = {}
+
+     # For recording the end time of the last operation for each stage
+     stage_last_end_time = [0.0] * num_stages
+
+     # Helper function: recursively calculate the time for an operation
+     def compute_op_time(stage, batch, op_type):
+         # Check if this operation has already been calculated
+         key = (stage, batch, op_type)
+         if key in computed_ops:
+             return computed_ops[key]
+
+         # Get operation duration
+         duration = (
+             forward_times[stage] if op_type == "forward" else backward_times[stage]
+         )
+
+         # Determine start time (dependent on other operations)
+         # 1. Consider sequential dependencies on the stage (must wait for previous operation to complete)
+         start_time = stage_last_end_time[stage]
+
+         # 2. Forward pass also depends on forward pass of previous stage (if not the first stage)
+         if op_type == "forward" and stage > 0:
+             # Recursively calculate the time for the forward pass of the previous stage (if not calculated yet)
+             prev_stage_key = (stage - 1, batch, "forward")
+             if prev_stage_key not in computed_ops:
+                 prev_start, prev_end = compute_op_time(stage - 1, batch, "forward")
+             else:
+                 _, prev_end = computed_ops[prev_stage_key]
+             # Update start time
+             start_time = max(start_time, prev_end + p2p_time)
+
+         # 3. Backward pass depends on:
+         elif op_type == "backward":
+             # a. Forward pass of the same stage
+             same_stage_forward_key = (stage, batch, "forward")
+             if same_stage_forward_key not in computed_ops:
+                 _, forward_end = compute_op_time(stage, batch, "forward")
+             else:
+                 _, forward_end = computed_ops[same_stage_forward_key]
+
+             start_time = max(start_time, forward_end)
+
+             # b. Backward pass of the next stage (if not the last stage)
+             if stage < num_stages - 1:
+                 next_stage_backward_key = (stage + 1, batch, "backward")
+                 if next_stage_backward_key not in computed_ops:
+                     _, next_backward_end = compute_op_time(stage + 1, batch, "backward")
+                 else:
+                     _, next_backward_end = computed_ops[next_stage_backward_key]
+
+                 start_time = max(start_time, next_backward_end + p2p_time)
+
+         # Calculate end time
+         end_time = start_time + duration
+
+         # Store calculation results
+         computed_ops[key] = (start_time, end_time)
+
+         # Update the end time of the last operation for this stage
+         stage_last_end_time[stage] = end_time
+
+         return start_time, end_time
+
+     # Calculate time for each operation in the operation_sequence
+     for i in range(len(operation_sequence[0])):
+         for stage in range(num_stages):
+             batch = operation_sequence[stage][i]["batch"]
+             op_type = operation_sequence[stage][i]["type"]
+
+             # Recursively calculate the time for this operation
+             start_time, _ = compute_op_time(stage, batch, op_type)
+
+             # Fill in scheduling information
+             op_with_timing = operation_sequence[stage][i].copy()
+             op_with_timing["start_time"] = start_time
+             op_with_timing["duration"] = (
+                 forward_times[stage] if op_type == "forward" else backward_times[stage]
+             )
+             schedule[stage].append(op_with_timing)
+
+     return schedule
+
+
+ def get_bubble_rate(schedule: Dict[int, List[Dict]]):
+     num_stages = len(schedule)
+
+     max_time = 0
+     for device in schedule:
+         for task in schedule[device]:
+             end_time = task["start_time"] + task["duration"]
+             if end_time > max_time:
+                 max_time = end_time
+
+     total_execution_time = max_time * num_stages
+
+     total_computation_time = 0
+     device_computation_times = {}
+
+     for device in schedule:
+         device_computation_time = 0
+         for task in schedule[device]:
+             device_computation_time += task["duration"]
+         device_computation_times[device] = device_computation_time
+         total_computation_time += device_computation_time
+
+     bubble_rate = (
+         total_execution_time - total_computation_time
+     ) / total_computation_time
+     return bubble_rate
+
+
+ def read_config_file(config_path):
+     """
+     Read configuration from a JSON or YAML file.
+
+     Args:
+         config_path: Path to the config file (JSON or YAML)
+
+     Returns:
+         Dictionary containing configuration parameters
+     """
+     if not os.path.exists(config_path):
+         raise FileNotFoundError(f"Config file not found: {config_path}")
+
+     file_ext = os.path.splitext(config_path)[1].lower()
+
+     try:
+         with open(config_path, "r") as f:
+             if file_ext == ".json":
+                 config = json.load(f)
+             elif file_ext in (".yaml", ".yml"):
+                 config = yaml.safe_load(f)
+             else:
+                 raise ValueError(
+                     f"Unsupported config file format: {file_ext}. Use .json, .yaml, or .yml"
+                 )
+         return config
+     except Exception as e:
+         raise ValueError(f"Error reading config file: {str(e)}")
+
+
+ def parse_args():
+     """
+     Parse command-line arguments for the pipeline parallelism tool.
+
+     Returns:
+         Parsed arguments namespace
+     """
+     parser = argparse.ArgumentParser(
+         description="Pipeline Parallelism Scheduler and Visualizer",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     # Config file option
+     parser.add_argument(
+         "--config", "-c", type=str, help="Path to config file (JSON or YAML)"
+     )
+
+     # Main parameters
+     parser.add_argument(
+         "--num-stages",
+         "-s",
+         type=int,
+         default=None,
+         help="Number of pipeline stages (devices)",
+     )
+
+     parser.add_argument(
+         "--num-batches", "-b", type=int, default=None, help="Number of micro-batches"
+     )
+
+     # Forward and backward times
+     parser.add_argument(
+         "--forward-times",
+         "-f",
+         type=float,
+         nargs="+",
+         help="Time for forward pass at each stage (space-separated list)",
+     )
+
+     parser.add_argument(
+         "--backward-times",
+         "-bw",
+         type=float,
+         nargs="+",
+         help="Time for backward pass at each stage (space-separated list)",
+     )
+
+     # Output options
+     parser.add_argument(
+         "--output",
+         "-o",
+         type=str,
+         default=None,
+         help="Output file path for visualization",
+     )
+
+     parser.add_argument(
+         "--no-visualization", action="store_true", help="Skip visualization generation"
+     )
+
+     parser.add_argument(
+         "--p2p-time",
+         type=float,
+         default=None,
+         help="Time for point-to-point communication between stages",
+     )
+
+     return parser.parse_args()
+
+
+ def example_usage():
+     """Example usage of the visualization function and testing the scheduling algorithms."""
+     # Example parameters
+     num_stages = 4  # Number of pipeline stages (devices)
+     num_batches = 10  # Number of micro-batches
+
+     # Example times for forward and backward passes for each stage
+     forward_times = [1.0, 1.0, 1.0, 1.0]  # Time for forward pass at each stage
+     backward_times = [2.0, 2.0, 2.0, 2.0]  # Time for backward pass at each stage
+
+     # Create 1F1B schedule
+     schedule = create_1f1b_schedule(
+         num_stages=num_stages,
+         num_batches=num_batches,
+         forward_times=forward_times,
+         backward_times=backward_times,
+     )
+
+     # Create visualization with the schedule
+     visualize_pipeline_parallelism(
+         schedule=schedule, schedule_type="1f1b", output_file="pipeline_1f1b.png"
+     )
+
+     # Analyze the schedule
+     bubble_rate = get_bubble_rate(schedule)
+     print(f"Bubble rate: {bubble_rate:.4f}")
+
+
+ def main():
+     """
+     Main function that parses arguments and runs the pipeline parallelism analysis.
+     """
+     args = parse_args()
+
+     # Initialize with default values
+     num_stages = 4
+     num_batches = 10
+     forward_times = None
+     backward_times = None
+     output_file = "pipeline_1f1b.png"
+     p2p_time = 0.0
+     # Read from config file if provided
+     if args.config:
+         try:
+             print(f"Reading configuration from {args.config}")
+             config = read_config_file(args.config)
+
+             # Update parameters from config
+             num_stages = config.get("num_stages", num_stages)
+             num_batches = config.get("num_batches", num_batches)
+             forward_times = config.get("forward_times")
+             backward_times = config.get("backward_times")
+             output_file = config.get("output_file", output_file)
+             p2p_time = config.get("p2p_time", 0.0)
+
+         except Exception as e:
+             print(f"Error reading config file: {str(e)}")
+             print("Falling back to command line arguments or defaults")
+
+     # Command line arguments override config file
+     if args.num_stages is not None:
+         num_stages = args.num_stages
+
+     if args.num_batches is not None:
+         num_batches = args.num_batches
+
+     if args.forward_times:
+         forward_times = args.forward_times
+
+     if args.backward_times:
+         backward_times = args.backward_times
+
+     if args.output is not None:
+         output_file = args.output
+
+     if args.p2p_time is not None:
+         p2p_time = args.p2p_time
+
+     # Validate inputs
+     if forward_times is None:
+         forward_times = [1.0] * num_stages
+     elif len(forward_times) != num_stages:
+         print(
+             f"Warning: forward_times length ({len(forward_times)}) doesn't match num_stages ({num_stages})"
+         )
+         if len(forward_times) < num_stages:
+             # Extend with repeats of the last value
+             forward_times = list(forward_times) + [forward_times[-1]] * (
+                 num_stages - len(forward_times)
+             )
+         else:
+             # Truncate
+             forward_times = forward_times[:num_stages]
+         print(f"Adjusted forward_times: {forward_times}")
+
+     if backward_times is None:
+         backward_times = [2.0] * num_stages
+     elif len(backward_times) != num_stages:
+         print(
+             f"Warning: backward_times length ({len(backward_times)}) doesn't match num_stages ({num_stages})"
+         )
+         if len(backward_times) < num_stages:
+             # Extend with repeats of the last value
+             backward_times = list(backward_times) + [backward_times[-1]] * (
+                 num_stages - len(backward_times)
+             )
+         else:
+             # Truncate
+             backward_times = backward_times[:num_stages]
+         print(f"Adjusted backward_times: {backward_times}")
+
+     print(f"Running with parameters:")
+     print(f"  num_stages: {num_stages}")
+     print(f"  num_batches: {num_batches}")
+     print(f"  forward_times: {forward_times}")
+     print(f"  backward_times: {backward_times}")
+     print(f"  output_file: {output_file}")
+
+     # Create 1F1B schedule
+     schedule = create_1f1b_schedule(
+         num_stages=num_stages,
+         num_batches=num_batches,
+         forward_times=forward_times,
+         backward_times=backward_times,
+         p2p_time=p2p_time,
+     )
+
+     # Create visualization unless --no-visualization is specified
+     if not args.no_visualization:
+         visualize_pipeline_parallelism(
+             schedule=schedule, schedule_type="1f1b", output_file=output_file
+         )
+
+     # Analyze the schedule
+     bubble_rate = get_bubble_rate(schedule)
+     print(f"Bubble rate: {bubble_rate:.4f}")
+
+     return {
+         "schedule": schedule,
+         "bubble_rate": bubble_rate,
+         "num_stages": num_stages,
+         "num_batches": num_batches,
+     }
+
+
+ if __name__ == "__main__":
+     main()
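
To make the warmup/steady/cooldown split in `determine_1f1b_operation_sequence` concrete, here is a standalone re-derivation of the same per-stage sequence (using `(type, batch)` tuples instead of dicts) for a tiny 2-stage, 3-micro-batch pipeline:

```python
def one_f1b_sequence(num_stages, num_batches):
    """Per-stage 1F1B operation order, mirroring the generator above."""
    seq = {s: [] for s in range(num_stages)}
    for s in range(num_stages):
        warmup = num_stages - s                      # earlier stages warm up longer
        for j in range(1, warmup + 1):               # warmup: forwards only
            seq[s].append(("forward", j))
        steady = num_batches - warmup
        for j in range(warmup + 1, warmup + steady + 1):
            seq[s].append(("backward", j - warmup))  # steady state: 1B then 1F per step
            seq[s].append(("forward", j))
        for j in range(warmup):                      # cooldown: drain remaining backwards
            seq[s].append(("backward", j + steady + 1))
    return seq

# The last stage (stage 1 of 2) alternates immediately: F1 B1 F2 B2 F3 B3
print(one_f1b_sequence(2, 3)[1])
```

The first stage instead warms up with two forwards (F1 F2 B1 F3 B2 B3), which is exactly why it holds the most in-flight activations in 1F1B.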
visualizer.py ADDED
@@ -0,0 +1,97 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from matplotlib.patches import Rectangle
+ from typing import List, Dict, Literal
+
+
+ def visualize_pipeline_parallelism(
+     schedule: Dict[int, List[Dict]],
+     schedule_type: Literal["simple", "1f1b"] = "1f1b",
+     output_file: str = "pipeline_visualization.png",
+ ):
+     """
+     Visualize pipeline parallelism scheduling.
+
+     Args:
+         schedule: Dictionary mapping device IDs to lists of tasks.
+             Each task is a dictionary with keys:
+             - 'type': 'forward' or 'backward'
+             - 'batch': batch number
+             - 'start_time': start time of the task
+             - 'duration': duration of the task
+         schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
+         output_file: Path to save the visualization
+     """
+     # Colors for forward and backward passes
+     forward_color = "royalblue"
+     backward_color = "lightgreen"
+     empty_color = "lightgray"
+
+     # Find the number of stages (devices)
+     num_stages = len(schedule)
+
+     # Find the maximum time in the schedule
+     max_time = 0
+     for device in schedule:
+         for task in schedule[device]:
+             end_time = task["start_time"] + task["duration"]
+             if end_time > max_time:
+                 max_time = end_time
+
+     # Create figure and axis
+     fig, ax = plt.subplots(figsize=(15, 5))
+
+     # Plot the schedule
+     for device_idx, device in enumerate(schedule):
+         for task in schedule[device]:
+             color = forward_color if task["type"] == "forward" else backward_color
+             rect = Rectangle(
+                 (task["start_time"], device_idx),
+                 task["duration"],
+                 0.8,
+                 edgecolor="black",
+                 facecolor=color,
+                 alpha=0.8,
+             )
+             ax.add_patch(rect)
+
+             # Add text (batch number)
+             ax.text(
+                 task["start_time"] + task["duration"] / 2,
+                 device_idx + 0.4,
+                 str(task["batch"]),
+                 ha="center",
+                 va="center",
+                 fontsize=10,
+                 fontweight="bold",
+                 color="white" if task["type"] == "forward" else "black",
+             )
+
+     # Set axis limits and labels
+     ax.set_xlim(0, max_time * 1.05)
+     ax.set_ylim(-0.2, num_stages + 0.2)
+     ax.set_yticks(np.arange(num_stages) + 0.4)
+     ax.set_yticklabels([f"Device {i+1}" for i in range(num_stages)])
+     ax.set_xlabel("Time")
+     ax.set_title(f"Pipeline Parallelism Schedule ({schedule_type})")
+
+     # Add a legend
+     forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
+     backward_patch = Rectangle((0, 0), 1, 1, facecolor=backward_color)
+     ax.legend(
+         [forward_patch, backward_patch],
+         ["Forward Pass", "Backward Pass"],
+         loc="upper center",
+         bbox_to_anchor=(0.5, -0.15),
+         ncol=2,
+     )
+
+     # Add grid
+     ax.grid(True, linestyle="--", alpha=0.7)
+
+     # Save the figure
+     plt.tight_layout()
+     plt.savefig(output_file, dpi=300, bbox_inches="tight")
+     plt.close()
+
+     print(f"Visualization saved to {output_file}")
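
The visualizer sizes its x-axis from the schedule's makespan (the `max_time` scan above). The same reduction is handy on its own, e.g. when matplotlib is unavailable; a standalone sketch over the task-dict format used throughout this repo:

```python
def makespan(schedule):
    """Latest end time across all devices' tasks (0.0 for an empty schedule)."""
    return max(
        (task["start_time"] + task["duration"]
         for tasks in schedule.values()
         for task in tasks),
        default=0.0,
    )

# Two tasks on one device: the second ends at 3.0 + 2.0
tiny = {0: [{"start_time": 0.0, "duration": 1.0},
            {"start_time": 3.0, "duration": 2.0}]}
print(makespan(tiny))  # 5.0
```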