Victarry committed on
Commit a49be3b · 1 Parent(s): 7a4895e

Add VPP support and refactor project.
.gitignore CHANGED
@@ -1,78 +1,10 @@
  # Python
- __pycache__/
- *.py[cod]
- *$py.class
- *.so
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- *.egg-info/
- .installed.cfg
- *.egg
-
- # Virtual Environment
- venv/
- env/
- ENV/
- .env
-
- # IDE specific files
- .idea/
- .vscode/
- *.swp
- *.swo
- .DS_Store
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # Distribution / packaging
- .Python
- env/
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- *.egg-info/
- .installed.cfg
- *.egg
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- .hypothesis/
-
- # Pipeline visualization outputs
- *.png
- *.jpg
- *.jpeg
- *.pdf
- *.svg
-
- # Local configuration
- config.ini
- secrets.json
+ ./venv
+ uv.lock
+ outputs/
+
+ # Uncomment below if you want to include these files
+ # !assets/*.png
+ # !assets/*.jpg
+ # !docs/*.png
+ # !docs/*.jpg
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README-dash-visualizer.md DELETED
@@ -1,91 +0,0 @@
- # Pipeline Parallelism Dash Visualizer
-
- This is an interactive Dash-based visualizer for pipeline parallelism scheduling, complementing the existing Matplotlib-based visualization.
-
- ## Features
-
- - **Static image generation** similar to the Matplotlib version
- - **Interactive web-based visualization** with Dash
- - **Download functionality** to save the visualization as PNG
- - **Progress indication** during figure creation and image generation
- - **Compatible API** with the existing visualizer
-
- ## Installation
-
- Install the required dependencies:
-
- ```bash
- pip install -r requirements-dash.txt
- ```
-
- ## Usage
-
- ### From Python
-
- ```python
- from pipeline import create_1f1b_schedule
- from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
-
- # Create a schedule
- schedule = create_1f1b_schedule(
-     num_stages=4,
-     num_batches=8,
-     forward_times=[1.0, 1.0, 1.0, 1.0],
-     backward_times=[2.0, 2.0, 2.0, 2.0],
- )
-
- # Generate a static image
- save_pipeline_visualization_plotly(
-     schedule=schedule,
-     schedule_type="1f1b",
-     output_file="pipeline_plotly.png"
- )
-
- # OR launch an interactive Dash app
- visualize_pipeline_parallelism_dash(
-     schedule=schedule,
-     schedule_type="1f1b",
-     port=8050,
-     debug=False
- )
- ```
-
- ### Using the Command Line
-
- You can use the updated command line interface:
-
- ```bash
- # Generate a static image with Dash/Plotly
- python pipeline.py --visualizer dash --output-file pipeline_viz.png
-
- # Launch an interactive Dash app
- python pipeline.py --visualizer dash-interactive
-
- # Use the original Matplotlib visualizer
- python pipeline.py --visualizer matplotlib
- ```
-
- You can also use the dash_visualizer.py script directly for testing:
-
- ```bash
- # Generate a static image
- python dash_visualizer.py --output test_viz.png
-
- # Launch an interactive app
- python dash_visualizer.py --interactive
- ```
-
- ## Differences from Matplotlib Visualizer
-
- The Dash-based visualizer provides all the same visual elements as the Matplotlib version:
- - Color-coded rectangles for forward, backward, and optimizer operations
- - Batch numbers displayed inside each rectangle
- - Device labels on the y-axis
- - Clear legend
-
- Additional features:
- - Interactive web interface
- - Hovering over elements to see details
- - Download button to save the visualization
- - Progress bars for tracking visualization creation
- - Responsive layout that works well on different screen sizes
README.md CHANGED
@@ -1,77 +1,95 @@
- # Pipeline Parallelism Scheduler and Visualizer
-
- This tool simulates and visualizes pipeline parallelism scheduling strategies, focusing on the 1F1B (One-Forward-One-Backward) scheduling algorithm commonly used in distributed deep learning.
-
- ## Usage
-
- ### Example
-
- ```bash
- python pipeline.py --num-stages 4 --num-batches 8
- ```
- ![Example 1F1B schedule](pipeline_1f1b.png)
-
- ### Command Line Interface
-
- | Option | Short | Description |
- |--------|-------|-------------|
- | `--config` | `-c` | Path to config file (JSON or YAML) |
- | `--num-stages` | `-s` | Number of pipeline stages (devices) |
- | `--num-batches` | `-b` | Number of micro-batches |
- | `--forward-times` | `-f` | Time for forward pass at each stage (space-separated list) |
- | `--backward-times` | `-bw` | Time for backward pass at each stage (space-separated list) |
- | `--output` | `-o` | Output file path for visualization |
- | `--no-visualization` | | Skip visualization generation |
- | `--p2p-time` | | P2P communication time of PP |
-
- ### Using Configuration Files
-
- You can use either JSON or YAML configuration files:
-
- Example JSON configuration (sample_config.json):
- ```json
- {
-     "num_stages": 6,
-     "num_batches": 12,
-     "forward_times": [0.8, 1.0, 1.2, 1.0, 0.9, 1.1],
-     "backward_times": [1.6, 2.0, 2.4, 2.0, 1.8, 2.2],
-     "output_file": "pipeline_1f1b_custom.png"
- }
- ```
-
- Example YAML configuration (sample_config.yaml):
- ```yaml
- # Pipeline Parallelism Configuration
- num_stages: 5
- num_batches: 8
- forward_times:
-   - 0.9
-   - 1.1
-   - 1.0
-   - 0.8
-   - 1.2
- backward_times:
-   - 1.8
-   - 2.2
-   - 2.0
-   - 1.6
-   - 2.4
- output_file: "pipeline_1f1b_yaml.png"
- ```
-
- ## About Pipeline Parallelism
-
- Pipeline parallelism is a distributed deep learning training strategy that splits model layers across multiple devices. Each device processes a different stage of the neural network, creating a pipeline where multiple micro-batches can be processed simultaneously.
-
- The 1F1B (One-Forward-One-Backward) scheduling algorithm is an efficient strategy for pipeline parallelism that balances throughput with memory usage. It follows these phases:
- 1. **Warmup Phase**: Forward passes for the first several micro-batches
- 2. **Steady State**: Each device alternates between forward and backward passes
- 3. **Cooldown Phase**: Backward passes to complete the computation for remaining micro-batches
-
- The "bubble rate" metric measures the inefficiency in the pipeline, representing the percentage of time devices spend idle waiting for dependencies.
-
- ## References
-
- - PipeDream: Generalized Pipeline Parallelism for DNN Training (SOSP'19)
- - GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (NeurIPS'19)
- - Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
+ # Pipeline Parallelism Emulation
+
+ This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
+
+ ## Overview
+
+ Pipeline parallelism is a technique used to train large models by partitioning the model across multiple devices and processing data in a pipelined fashion. This project allows you to:
+
+ - Simulate different pipeline parallelism strategies (1F1B, Interleaved)
+ - Visualize the execution schedule on multiple devices
+ - Compare different strategies for efficiency
+
+ ## Features
+ - Supported pipeline strategies:
+   - 1F1B
+   - Interleaved 1F1B
+ - Visualization:
+   - Interactive visualization dashboard using Plotly/Dash
+ - Configuration:
+   - Configurable simulation parameters through Hydra
+   - Stage-specific operation times
+
+ ## Installation
+
+ This project uses [uv](https://github.com/astral-sh/uv) for dependency management.
+
+ Set up `uv` if it is not already installed on your machine:
+ ```bash
+ # On macOS and Linux.
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ ```
+
+ ## Usage
+
+ Run the 1F1B strategy:
+ ```bash
+ uv run python main.py strategy=1f1b num_devices=4 num_stages=4 num_batches=8
+ ```
+
+ Run the interleaved 1F1B strategy:
+ ```bash
+ uv run python main.py strategy=interleave num_devices=4 num_stages=8 num_batches=8
+ ```
+
+ ## Configuration
+
+ The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
+
+ ### Using Different Configuration Files
+
+ You can use different configuration files with Hydra in several ways:
+
+ #### Recommended Approach
+
+ 1. Create multiple configuration files in the `conf` directory for different use cases:
+    ```
+    conf/
+    ├── config.yaml   # Default configuration
+    └── model_A.yaml  # Your own config with stage-specific latency for performance projection
+    ```
+
+ 2. Run with your desired configuration using the `--config-name` flag:
+    ```bash
+    uv run python main.py --config-name=model_A
+    ```
+
+ #### Override Specific Parameters
+
+ You can also override specific parameters at runtime:
+ ```bash
+ uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
+ ```
+
+ ## Project Structure
+
+ ```
+ PP-Emulation/
+ ├── conf/                   # Hydra configuration files
+ │   └── config.yaml         # Default configuration
+ ├── src/                    # Source code
+ │   ├── __init__.py         # Package initialization
+ │   ├── execution_model.py  # Schedule execution models
+ │   ├── strategies.py       # Pipeline parallelism strategies
+ │   └── visualizer.py       # Visualization utilities
+ ├── main.py                 # Main entry point
+ ├── pyproject.toml          # Project metadata and dependencies
+ └── README.md               # This file
+ ```
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
+
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
 
 
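The warmup, steady-state, and cooldown phases behind the 1F1B strategy above can be sketched as a small ordering routine. This is an illustrative sketch adapted from the `determine_1f1b_operation_sequence` logic in the removed `pipeline.py` (tuples instead of dicts, for brevity); it is not code from the new `src/` modules:

```python
def determine_1f1b_operation_sequence(num_stages, num_batches):
    """Order of forward/backward ops per stage under 1F1B (no timing)."""
    seq = {i: [] for i in range(num_stages)}
    for stage in range(num_stages):
        warmup = num_stages - stage           # warmup: forwards only
        for b in range(1, warmup + 1):
            seq[stage].append(("forward", b))
        steady = num_batches - warmup         # steady state: one backward, one forward
        for b in range(warmup + 1, warmup + steady + 1):
            seq[stage].append(("backward", b - warmup))
            seq[stage].append(("forward", b))
        for b in range(warmup):               # cooldown: remaining backwards
            seq[stage].append(("backward", b + steady + 1))
    return seq

seq = determine_1f1b_operation_sequence(num_stages=4, num_batches=8)
# Stage 0 warms up with 4 forwards before its first backward.
print(seq[0][:5])
```

Note that earlier stages have longer warmups (stage 0 runs `num_stages` forwards before its first backward), while the last stage alternates immediately; every stage ends up with exactly `2 * num_batches` operations.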
conf/config.yaml ADDED
@@ -0,0 +1,22 @@
+ # Default configuration for Pipeline Parallelism Emulation
+ num_devices: 4
+ num_stages: 4
+ num_batches: 12
+ visualization_port: 8050
+ strategy: "1f1b"  # Options: "1f1b", "interleave"
+ p2p_latency: 0.0
+
+ # Operation time configurations
+ op_times:
+   # Option 1: Simple configuration (same time for all stages)
+   forward: 1.0
+   backward: 2.0
+
+   # Option 2: Commented example of stage-specific configuration
+   # forward:
+   #   0: 0.8  # Stage 0 forward time
+   #   1: 1.2  # Stage 1 forward time
+   #   2: 1.5  # Stage 2 forward time
+   #   3: 0.9  # Stage 3 forward time
+   # backward:
+   #   0: 1.0  # Stage 0 backward time
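The two `op_times` shapes above (a single float per op, or a per-stage mapping) can be consumed with a small lookup helper. This is an illustrative sketch, not code from the repository; `resolve_op_time` and its default values are hypothetical:

```python
def resolve_op_time(op_times, op, stage, default):
    """Look up the time for `op` at `stage`, handling both config shapes."""
    value = op_times.get(op, default)
    if isinstance(value, dict):   # Option 2: stage-specific mapping
        return value.get(stage, default)
    return value                  # Option 1: same time for all stages

simple = {"forward": 1.0, "backward": 2.0}
per_stage = {"forward": {0: 0.8, 1: 1.2, 2: 1.5, 3: 0.9}, "backward": 2.0}

print(resolve_op_time(simple, "forward", 2, 1.0))     # 1.0
print(resolve_op_time(per_stage, "forward", 2, 1.0))  # 1.5
```

Keeping the lookup in one place lets the rest of the scheduler treat both options uniformly.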
configs/standard.json DELETED
@@ -1,8 +0,0 @@
1
- {
2
- "num_stages": 4,
3
- "num_batches": 8,
4
- "forward_times": [1.0, 1.0, 1.0, 1.0],
5
- "backward_times": [2.0, 2.0, 2.0, 2.0],
6
- "output_file": "pipeline_1f1b.png",
7
- "p2p_time": 0.0
8
- }
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,62 @@
+ from src.execution_model import ScheduleConfig, ScheduleExecutor
+ from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
+ from src.visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
+ import hydra
+ from omegaconf import DictConfig, OmegaConf
+
+
+ @hydra.main(config_path="conf", config_name="config", version_base=None)
+ def main(cfg: DictConfig) -> None:
+     """Run pipeline parallelism simulation with the specified configuration."""
+     print(f"Running with configuration: {cfg}")
+
+     if cfg.strategy == "1f1b":
+         run_1f1b(cfg)
+     elif cfg.strategy == "interleave":
+         run_interleave(cfg)
+     else:
+         raise ValueError(f"Unknown strategy: {cfg.strategy}")
+
+
+ def run_1f1b(cfg: DictConfig) -> None:
+     """Run 1F1B pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         op_times=op_times,
+         placement_strategy="1f1b"
+     )
+     schedule = generate_1f1b_schedule(schedule_config)
+     executor = ScheduleExecutor(schedule)
+     executor.execute()
+
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ def run_interleave(cfg: DictConfig) -> None:
+     """Run interleaved pipeline parallelism simulation."""
+     # Convert OmegaConf to dict for op_times if it exists
+     op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
+
+     schedule_config = ScheduleConfig(
+         num_devices=cfg.num_devices,
+         num_stages=cfg.num_stages,
+         num_batches=cfg.num_batches,
+         p2p_latency=cfg.p2p_latency,
+         placement_strategy="interleave",
+         op_times=op_times
+     )
+     schedule = generate_1f1b_interleave_schedule(schedule_config)
+     executor = ScheduleExecutor(schedule)
+     executor.execute()
+
+     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
+
+ if __name__ == "__main__":
+     main()
pipeline.py DELETED
@@ -1,491 +0,0 @@
1
- import argparse
2
- import json
3
- import yaml
4
- import os
5
- from typing import List, Dict
6
-
7
- # Import visualization function from the new module
8
- from visualizer import visualize_pipeline_parallelism
9
- try:
10
- from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
11
- DASH_AVAILABLE = True
12
- except ImportError:
13
- DASH_AVAILABLE = False
14
-
15
-
16
- def create_1f1b_schedule(
17
- num_stages: int,
18
- num_batches: int,
19
- forward_times: List[float],
20
- backward_times: List[float],
21
- p2p_time: float = 0.0,
22
- ) -> Dict[int, List[Dict]]:
23
- """
24
- Create a 1F1B (One-Forward-One-Backward) schedule for pipeline parallelism.
25
-
26
- This implementation takes a data-centric approach:
27
- 1. First determine the operation sequence for each pipeline stage (which microbatch to process when)
28
- 2. Then calculate timing based on dependencies between operations
29
-
30
- The 1F1B pattern has three phases:
31
- - Warmup: Forward passes for first num_stages microbatches
32
- - Steady state: Alternating between forward and backward passes
33
- - Cooldown: Backward passes for remaining microbatches
34
-
35
- Returns:
36
- A dictionary mapping device IDs to lists of tasks.
37
- Each task is a dictionary with keys:
38
- - 'type': 'forward' or 'backward'
39
- - 'batch': batch number
40
- - 'start_time': start time of the task
41
- - 'duration': duration of the task
42
- """
43
- # Initialize empty schedule
44
- schedule = {stage: [] for stage in range(num_stages)}
45
-
46
- # Step 1: Determine operation sequence for each stage
47
- # This will generate the sequence of operations (forward/backward on which microbatch)
48
- # that each stage should perform, without timing information yet
49
- operation_sequence = determine_1f1b_operation_sequence(num_stages, num_batches)
50
-
51
- # Step 2: Convert operation sequence to schedule with timing
52
- # Taking into account dependencies between operations
53
- schedule = calculate_operation_timing(
54
- operation_sequence, num_stages, forward_times, backward_times, p2p_time
55
- )
56
-
57
- return schedule
58
-
59
-
60
- def determine_1f1b_operation_sequence(
61
- num_stages: int, num_batches: int
62
- ) -> Dict[int, List[Dict]]:
63
- """
64
- Determine the sequence of operations (forward/backward) for each stage in 1F1B scheduling.
65
-
66
- Args:
67
- num_stages: Number of pipeline stages
68
- num_batches: Number of micro-batches
69
-
70
- Returns:
71
- Dictionary mapping stage ID to a list of operations in sequence.
72
- Each operation is a dict with keys 'type' ('forward' or 'backward') and 'batch'.
73
- """
74
- operation_sequence = {i: [] for i in range(num_stages)}
75
- for current_stage in range(num_stages):
76
- warmup_batches = num_stages - current_stage
77
- for j in range(1, warmup_batches + 1):
78
- operation_sequence[current_stage].append({"type": "forward", "batch": j})
79
- steady_batches = num_batches - warmup_batches
80
- for j in range(warmup_batches + 1, warmup_batches + steady_batches + 1):
81
- operation_sequence[current_stage].append(
82
- {"type": "backward", "batch": j - warmup_batches}
83
- )
84
- operation_sequence[current_stage].append({"type": "forward", "batch": j})
85
- for j in range(warmup_batches):
86
- operation_sequence[current_stage].append(
87
- {"type": "backward", "batch": j + steady_batches + 1}
88
- )
89
-
90
- return operation_sequence
91
-
92
-
93
- def calculate_operation_timing(
94
- operation_sequence: Dict[int, List[Dict]],
95
- num_stages: int,
96
- forward_times: List[float],
97
- backward_times: List[float],
98
- p2p_time: float = 0.0,
99
- ) -> Dict[int, List[Dict]]:
100
- """
101
- Recursively calculate the specific timing of each operation in a 1F1B schedule.
102
-
103
- When encountering an operation that depends on a previous operation that hasn't been calculated yet,
104
- it will recursively calculate the timing of those operations.
105
-
106
- Args:
107
- operation_sequence: Operation sequence for each stage
108
- num_stages: Number of pipeline stages
109
- forward_times: Forward propagation time for each stage
110
- backward_times: Backward propagation time for each stage
111
- p2p_time: Point-to-point communication time between stages
112
-
113
- Returns:
114
- Complete schedule with timing information, each operation includes start_time and duration
115
- """
116
- # Initialize schedule with timing information
117
- schedule = {i: [] for i in range(num_stages)}
118
-
119
- # For recording already computed operation end times
120
- # Format: {(stage, batch, op_type): (start_time, end_time)}
121
- computed_ops = {}
122
-
123
- # For recording the end time of the last operation for each stage
124
- stage_last_end_time = [0.0] * num_stages
125
-
126
- # Helper function: recursively calculate the time for an operation
127
- def compute_op_time(stage, batch, op_type):
128
- # Check if this operation has already been calculated
129
- key = (stage, batch, op_type)
130
- if key in computed_ops:
131
- return computed_ops[key]
132
-
133
- # Get operation duration
134
- duration = (
135
- forward_times[stage] if op_type == "forward" else backward_times[stage]
136
- )
137
-
138
- # Determine start time (dependent on other operations)
139
- # 1. Consider sequential dependencies on the stage (must wait for previous operation to complete)
140
- start_time = stage_last_end_time[stage]
141
-
142
- # 2. Forward pass also depends on forward pass of previous stage (if not the first stage)
143
- if op_type == "forward" and stage > 0:
144
- # Recursively calculate the time for the forward pass of the previous stage (if not calculated yet)
145
- prev_stage_key = (stage - 1, batch, "forward")
146
- if prev_stage_key not in computed_ops:
147
- prev_start, prev_end = compute_op_time(stage - 1, batch, "forward")
148
- else:
149
- _, prev_end = computed_ops[prev_stage_key]
150
- # Update start time
151
- start_time = max(start_time, prev_end + p2p_time)
152
-
153
- # 3. Backward pass depends on:
154
- elif op_type == "backward":
155
- # a. Forward pass of the same stage
156
- same_stage_forward_key = (stage, batch, "forward")
157
- if same_stage_forward_key not in computed_ops:
158
- _, forward_end = compute_op_time(stage, batch, "forward")
159
- else:
160
- _, forward_end = computed_ops[same_stage_forward_key]
161
-
162
- start_time = max(start_time, forward_end)
163
-
164
- # b. Backward pass of the next stage (if not the last stage)
165
- if stage < num_stages - 1:
166
- next_stage_backward_key = (stage + 1, batch, "backward")
167
- if next_stage_backward_key not in computed_ops:
168
- _, next_backward_end = compute_op_time(stage + 1, batch, "backward")
169
- else:
170
- _, next_backward_end = computed_ops[next_stage_backward_key]
171
-
172
- start_time = max(start_time, next_backward_end + p2p_time)
173
-
174
- # Calculate end time
175
- end_time = start_time + duration
176
-
177
- # Store calculation results
178
- computed_ops[key] = (start_time, end_time)
179
-
180
- # Update the end time of the last operation for this stage
181
- stage_last_end_time[stage] = end_time
182
-
183
- return start_time, end_time
184
-
185
- # Calculate time for each operation in the operation_sequence
186
- for i in range(len(operation_sequence[0])):
187
- for stage in range(num_stages):
188
- batch = operation_sequence[stage][i]["batch"]
189
- op_type = operation_sequence[stage][i]["type"]
190
-
191
- # Recursively calculate the time for this operation
192
- start_time, _ = compute_op_time(stage, batch, op_type)
193
-
194
- # Fill in scheduling information
195
- op_with_timing = operation_sequence[stage][i].copy()
196
- op_with_timing["start_time"] = start_time
197
- op_with_timing["duration"] = (
198
- forward_times[stage] if op_type == "forward" else backward_times[stage]
199
- )
200
- schedule[stage].append(op_with_timing)
201
-
202
- return schedule
203
-
204
-
205
- def get_schedule_info(schedule: Dict[int, List[Dict]]):
206
- num_stages = len(schedule)
207
-
208
- max_time = 0
209
- for device in schedule:
210
- for task in schedule[device]:
211
- end_time = task["start_time"] + task["duration"]
212
- if end_time > max_time:
213
- max_time = end_time
214
-
215
- total_execution_time = max_time * num_stages
216
-
217
- total_computation_time = 0
218
- device_computation_times = {}
219
-
220
- for device in schedule:
221
- device_computation_time = 0
222
- for task in schedule[device]:
223
- device_computation_time += task["duration"]
224
- device_computation_times[device] = device_computation_time
225
- total_computation_time += device_computation_time
226
-
227
- bubble_rate = (
228
- total_execution_time - total_computation_time
229
- ) / total_computation_time
230
-
231
- return {
232
- "bubble_rate": f"{bubble_rate*100:.2f}%",
233
- "execution_time": f"{max_time / 1000:.2f} s",
234
- }
235
-
236
-
237
- def read_config_file(config_path):
238
- """
239
- Read configuration from a JSON or YAML file.
240
-
241
- Args:
242
- config_path: Path to the config file (JSON or YAML)
243
-
244
- Returns:
245
- Dictionary containing configuration parameters
246
- """
247
- if not os.path.exists(config_path):
248
- raise FileNotFoundError(f"Config file not found: {config_path}")
249
-
250
- file_ext = os.path.splitext(config_path)[1].lower()
251
-
252
- try:
253
- with open(config_path, "r") as f:
254
- if file_ext == ".json":
255
- config = json.load(f)
256
- elif file_ext in (".yaml", ".yml"):
257
- config = yaml.safe_load(f)
258
- else:
259
- raise ValueError(
260
- f"Unsupported config file format: {file_ext}. Use .json, .yaml, or .yml"
261
- )
262
- return config
263
- except Exception as e:
264
- raise ValueError(f"Error reading config file: {str(e)}")
265
-
266
-
267
- def parse_args():
268
- """
269
- Parse command-line arguments for the pipeline parallelism tool.
270
-
271
- Returns:
272
- Parsed arguments namespace
273
- """
274
- parser = argparse.ArgumentParser(
275
- description="Pipeline Parallelism Scheduler and Visualizer",
276
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
277
- )
278
-
279
- # Config file option
280
- parser.add_argument(
281
- "--config", "-c", type=str, help="Path to config file (JSON or YAML)"
282
- )
283
-
284
- # Main parameters
285
- parser.add_argument(
286
- "--num-stages",
287
- "-s",
288
- type=int,
289
- default=0,
290
- help="Number of pipeline stages (devices)",
291
- )
292
-
293
- parser.add_argument(
294
- "--num-batches", "-b", type=int, default=0, help="Number of micro-batches"
295
- )
296
-
297
- # Forward and backward times
298
- parser.add_argument(
299
- "--forward-times",
300
- "-f",
301
- type=float,
302
- nargs="+",
303
- help="Time for forward pass at each stage (space-separated list)",
304
- )
305
-
306
- parser.add_argument(
307
- "--backward-times",
308
- "-bw",
309
- type=float,
310
- nargs="+",
311
- help="Time for backward pass at each stage (space-separated list)",
312
- )
313
-
314
- # Output options
315
- parser.add_argument(
316
- "--output",
317
- "-o",
318
- type=str,
319
- default="pipeline_1f1b.png",
320
- help="Output file path for visualization",
321
- )
322
-
323
- parser.add_argument(
324
- "--no-visualization", action="store_true", help="Skip visualization generation"
325
- )
326
-
327
- parser.add_argument(
328
- "--p2p-time",
329
- type=float,
330
- default=0.0,
331
- help="Time for point-to-point communication between stages",
332
- )
333
-
334
- parser.add_argument("--visualizer", choices=["matplotlib", "dash", "dash-interactive"],
335
- default="matplotlib", help="Visualization library to use")
336
-
337
- return parser.parse_args()
338
-
339
-
340
- def example_usage():
341
- """Example usage of the visualization function and testing the scheduling algorithms."""
342
- # Example parameters
343
- num_stages = 4 # Number of pipeline stages (devices)
344
- num_batches = 10 # Number of micro-batches
345
-
346
- # Example times for forward and backward passes for each stage
347
- forward_times = [1.0, 1.0, 1.0, 1.0] # Time for forward pass at each stage
348
- backward_times = [2.0, 2.0, 2.0, 2.0] # Time for backward pass at each stage
349
-
350
- # Create 1F1B schedule
351
- schedule = create_1f1b_schedule(
352
- num_stages=num_stages,
353
- num_batches=num_batches,
354
- forward_times=forward_times,
355
- backward_times=backward_times,
356
- )
357
-
358
- # Create visualization with the schedule
359
- visualize_pipeline_parallelism(
360
- schedule=schedule, schedule_type="1f1b", output_file="pipeline_1f1b.png"
361
- )
362
-
363
- # Analyze the schedule
364
- schedule_info = get_schedule_info(schedule)
365
- print(schedule_info)
366
-
367
-
368
- def main():
369
- """
370
- Main function that parses arguments and runs the pipeline parallelism analysis.
371
- """
372
- args = parse_args()
373
-
374
- # Initialize with default values
375
- num_stages = 4
376
- num_batches = 10
377
- forward_times = None
378
- backward_times = None
379
- output_file = "pipeline_1f1b.png"
380
- p2p_time = 0.0
381
-
382
- # Command line arguments override config file
383
- num_stages = args.num_stages
384
- num_batches = args.num_batches
385
- forward_times = args.forward_times
386
- backward_times = args.backward_times
387
- output_file = args.output
388
- p2p_time = args.p2p_time
389
-
390
- # Read from config file if provided
391
- if args.config:
392
- try:
393
- print(f"Reading configuration from {args.config}")
394
- config = read_config_file(args.config)
395
-
396
- # Update parameters from config
397
- num_stages = config.get("num_stages", num_stages)
398
- num_batches = config.get("num_batches", num_batches)
399
- forward_times = config.get("forward_times")
400
- backward_times = config.get("backward_times")
401
- output_file = config.get("output_file", output_file)
402
- p2p_time = config.get("p2p_time", 0.0)
403
-
404
- except Exception as e:
405
- print(f"Error reading config file: {str(e)}")
406
- print("Falling back to command line arguments or defaults")
407
-
408
- # Validate inputs
409
- if forward_times is None:
410
- forward_times = [1.0] * num_stages
411
- elif len(forward_times) != num_stages:
412
- print(
413
- f"Warning: forward_times length ({len(forward_times)}) doesn't match num_stages ({num_stages})"
414
- )
415
-         if len(forward_times) < num_stages:
-             # Extend with repeats of the last value
-             forward_times = list(forward_times) + [forward_times[-1]] * (
-                 num_stages - len(forward_times)
-             )
-         else:
-             # Truncate
-             forward_times = forward_times[:num_stages]
-         print(f"Adjusted forward_times: {forward_times}")
-
-     if backward_times is None:
-         backward_times = [2.0] * num_stages
-     elif len(backward_times) != num_stages:
-         print(
-             f"Warning: backward_times length ({len(backward_times)}) doesn't match num_stages ({num_stages})"
-         )
-         if len(backward_times) < num_stages:
-             # Extend with repeats of the last value
-             backward_times = list(backward_times) + [backward_times[-1]] * (
-                 num_stages - len(backward_times)
-             )
-         else:
-             # Truncate
-             backward_times = backward_times[:num_stages]
-         print(f"Adjusted backward_times: {backward_times}")
-
-     print(f"Running with parameters:")
-     print(f"  num_stages: {num_stages}")
-     print(f"  num_batches: {num_batches}")
-     print(f"  forward_times: {forward_times}")
-     print(f"  backward_times: {backward_times}")
-     print(f"  output_file: {output_file}")
-
-     # Create 1F1B schedule
-     schedule = create_1f1b_schedule(
-         num_stages=num_stages,
-         num_batches=num_batches,
-         forward_times=forward_times,
-         backward_times=backward_times,
-         p2p_time=p2p_time,
-     )
-
-     # Create visualization unless --no-visualization is specified
-     if not args.no_visualization:
-         if args.visualizer == "matplotlib" or not DASH_AVAILABLE:
-             if not DASH_AVAILABLE and args.visualizer in ["dash", "dash-interactive"]:
-                 print("Warning: Dash not available. Falling back to matplotlib.")
-             visualize_pipeline_parallelism(
-                 schedule=schedule, schedule_type="1f1b", output_file=output_file
-             )
-         elif args.visualizer == "dash":
-             # Get output file name without extension to use the appropriate extension
-             output_base = os.path.splitext(output_file)[0]
-             output_dash = f"{output_base}_plotly.png"
-             save_pipeline_visualization_plotly(
-                 schedule=schedule, schedule_type="1f1b", output_file=output_dash
-             )
-         elif args.visualizer == "dash-interactive":
-             print("Using Dash interactive visualization")
-             visualize_pipeline_parallelism_dash(
-                 schedule=schedule, schedule_type="1f1b", port=8050, debug=False
-             )
-
-     # Analyze the schedule
-     schedule_info = get_schedule_info(schedule)
-     print(schedule_info)
-
-     return {
-         "schedule": schedule,
-         "schedule_info": schedule_info,
-         "num_stages": num_stages,
-         "num_batches": num_batches,
-     }
-
-
- if __name__ == "__main__":
-     main()
pipeline_1f1b.png DELETED

Git LFS Details

  • SHA256: ff047349dfa8f855aca47e233be6a5b12b45441c7f45bbe69509d0602dc1a127
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
pyproject.toml ADDED
@@ -0,0 +1,67 @@
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "pp-emulation"
+ version = "0.1.0"
+ description = "Pipeline Parallelism Emulation and Visualization"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ authors = [
+     {name = "Project Author"}
+ ]
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+ ]
+ dependencies = [
+     "dash>=2.14.0",
+     "hydra-core>=1.3.2",
+     "omegaconf>=2.3.0",
+     "plotly>=5.18.0",
+     "pandas>=2.1.0",
+     "numpy>=1.26.0",
+     "tqdm>=4.67.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.4.0",
+     "black>=23.7.0",
+     "isort>=5.12.0",
+     "mypy>=1.5.1",
+ ]
+
+ # Add Hatch configuration to explicitly define where source code is located
+ [tool.hatch.build.targets.wheel]
+ packages = ["src"]
+
+ [tool.hatch.build.targets.sdist]
+ include = [
+     "src",
+     "main.py",
+     "conf",
+     "LICENSE",
+     "README.md",
+ ]
+
+ [tool.black]
+ line-length = 88
+ target-version = ["py310"]
+
+ [tool.isort]
+ profile = "black"
+ line_length = 88
+
+ [tool.mypy]
+ python_version = "3.10"
+ warn_return_any = true
+ warn_unused_configs = true
+ disallow_untyped_defs = true
+ disallow_incomplete_defs = true
+
+ [tool.pytest]
+ testpaths = ["tests"]
+ pythonpath = ["."]
requirements-dash.txt DELETED
@@ -1,5 +0,0 @@
- dash==2.13.0
- plotly==5.18.0
- numpy
- kaleido  # For static image export
- tqdm  # For progress bars
src/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """Pipeline Parallelism Emulation and Visualization package."""
+
+ __version__ = "0.1.0"
src/execution_model.py ADDED
@@ -0,0 +1,219 @@
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Union
+
+
+ class Operation:
+     """Operation is a single operation in the pipeline."""
+
+     def __init__(self, batch_id: int, stage_id: int, op_type: str):
+         self.batch_id = batch_id
+         self.stage_id = stage_id
+         self.op_type = op_type
+         self.device_id = None
+
+         self.start_time = None
+         self.end_time = None
+
+
+ class DeviceQueue:
+     def __init__(self, stages: List[int], device_id: int):
+         self.stages = stages
+         self.device_id = device_id
+         self.ops = []  # List of operations
+
+     def add_operation(self, op: Operation):
+         assert op.stage_id in self.stages
+         self.ops.append(op)
+         assert op.device_id is None
+         op.device_id = self.device_id
+
+
+ class ScheduleConfig:
+     def __init__(
+         self,
+         num_devices: int,
+         num_stages: int,
+         num_batches: int,
+         p2p_latency: float = 0.0,
+         placement_strategy: str = "normal",
+         op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
+     ):
+         self.num_devices = num_devices
+         self.num_stages = num_stages
+         self.num_batches = num_batches
+         self.p2p_latency = p2p_latency
+         self.placement_strategy = placement_strategy
+
+         # Initialize default operation times
+         self.op_times = {
+             "forward": 1.0,
+             "backward": 2.0,
+         }
+
+         # Update with user-provided operation times
+         if op_times:
+             for op_type, times in op_times.items():
+                 if isinstance(times, dict):
+                     # If a dict is provided, it maps stage_id -> time
+                     if op_type not in self.op_times:
+                         self.op_times[op_type] = {}
+                     elif not isinstance(self.op_times[op_type], dict):
+                         # Convert float to dict if needed
+                         self.op_times[op_type] = {i: self.op_times[op_type] for i in range(num_stages)}
+
+                     # Update with provided stage-specific times
+                     for stage_id, time in times.items():
+                         if not isinstance(self.op_times[op_type], dict):
+                             self.op_times[op_type] = {i: self.op_times[op_type] for i in range(num_stages)}
+                         self.op_times[op_type][stage_id] = time
+                 else:
+                     # If a float is provided, use same time for all stages
+                     self.op_times[op_type] = times
+
+         assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
+         self.num_stages_per_device = num_stages // num_devices
+
+         self.init_device_to_stages()
+         assert (
+             sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
+         )
+
+     def init_device_to_stages(self):
+         if self.placement_strategy == "normal":
+             # Evenly distributed
+             stages_per_device = self.num_stages // self.num_devices
+             self.device_to_stages = defaultdict(list)
+             for i in range(self.num_stages):
+                 device_to_put = i // stages_per_device
+                 self.device_to_stages[device_to_put].append(i)
+         elif self.placement_strategy == "interleave":
+             self.device_to_stages = defaultdict(list)
+             for i in range(self.num_stages):
+                 device_to_put = i % self.num_devices
+                 self.device_to_stages[device_to_put].append(i)
+         else:
+             raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
+
+     def get_op_time(self, op_type: str, stage_id: int):
+         if op_type not in self.op_times:
+             raise ValueError(f"Invalid operation type: {op_type}")
+
+         times = self.op_times[op_type]
+         if isinstance(times, dict):
+             # If we have stage-specific times, use those
+             if stage_id not in times:
+                 raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+             return times[stage_id]
+         else:
+             # If we have a single float, use the same value for all stages
+             return times
+
+
+ class Schedule:
+     def __init__(self, config: ScheduleConfig):
+         self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
+         self.dev_queues: List[DeviceQueue] = []
+         for dev_id in range(config.num_devices):
+             self.dev_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
+         self.config = config
+
+         self.init_operations()
+
+     def init_operations(self, op_types: Optional[List[str]] = None):
+         if op_types is None:
+             op_types = ["forward", "backward"]
+         for batch_id in range(self.config.num_batches):
+             for stage_id in range(self.config.num_stages):
+                 for op_type in op_types:
+                     self.ops[(batch_id, stage_id, op_type)] = Operation(
+                         batch_id, stage_id, op_type
+                     )
+
+     def get_op(self, batch_id: int, stage_id: int, op_type: str):
+         return self.ops[(batch_id, stage_id, op_type)]
+
+     def get_dependencies(self, op: Operation):
+         deps = []
+         if op.op_type == "forward":
+             if op.stage_id > 0:
+                 deps.append(
+                     (
+                         self.get_op(op.batch_id, op.stage_id - 1, "forward"),
+                         self.config.p2p_latency,
+                     )
+                 )
+         elif op.op_type == "backward":
+             if op.stage_id < self.config.num_stages - 1:
+                 deps.append(
+                     (
+                         self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                         self.config.p2p_latency,
+                     )
+                 )
+
+         device_index = self.dev_queues[op.device_id].ops.index(op)
+         if device_index > 0:
+             deps.append((self.dev_queues[op.device_id].ops[device_index - 1], 0.0))
+         return deps
+
+     def show(self):
+         """Display detailed information about the schedule for debugging purposes."""
+         print("\n=== SCHEDULE DETAILS ===")
+         print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
+         print(f"Placement Strategy: {self.config.placement_strategy}")
+         print("\n=== DEVICE QUEUES ===")
+
+         for dev_id in range(self.config.num_devices):
+             print(f"\nDEVICE {dev_id} (Stages: {self.dev_queues[dev_id].stages}):")
+             print("-" * 80)
+             print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
+             print("-" * 80)
+
+             for op in self.dev_queues[dev_id].ops:
+                 op_type = "Forward" if op.op_type == "forward" else "Backward"
+                 start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
+                 end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"
+
+                 duration = "N/A"
+                 if op.start_time is not None and op.end_time is not None:
+                     duration = f"{op.end_time - op.start_time:.2f}"
+
+                 print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")
+
+         # Find the total execution time (if timing info is available)
+         if all(op.end_time is not None for op in self.ops.values()):
+             total_time = max(op.end_time for op in self.ops.values())
+             print(f"\nTotal execution time: {total_time:.2f}")
+
+
+ class ScheduleExecutor:
+     def __init__(self, schedule: Schedule):
+         self.schedule = schedule
+
+     def execute(self):
+         def execute_op(op: Operation):
+             deps = self.schedule.get_dependencies(op)
+             if len(deps) == 0:
+                 op.start_time = 0.0
+             else:
+                 for dep, gap in deps:
+                     if dep.end_time is None or dep.start_time is None:
+                         execute_op(dep)
+                 op.start_time = max(dep.end_time + gap for dep, gap in deps)
+             op.end_time = op.start_time + self.schedule.config.get_op_time(
+                 op.op_type, op.stage_id
+             )
+
+         op_num = len(self.schedule.dev_queues[0].ops)
+         for i in range(op_num):
+             for dev_id in range(self.schedule.config.num_devices):
+                 op = self.schedule.dev_queues[dev_id].ops[i]
+                 execute_op(op)
+
+         for op in self.schedule.ops.values():
+             assert (
+                 op.start_time is not None
+             ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
+             assert (
+                 op.end_time is not None
+             ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
src/strategies.py ADDED
@@ -0,0 +1,192 @@
+ from collections import defaultdict
+ from src.execution_model import Schedule, ScheduleConfig
+
+
+ def generate_1f1b_schedule(config: ScheduleConfig):
+     schedule = Schedule(config)
+
+     for i in range(config.num_devices):
+         fwd_batch_id = 0
+         bwd_batch_id = 0
+         cooldown_batches = warmup_batches = config.num_devices - i - 1
+         steady_batches = config.num_batches - warmup_batches
+
+         for _ in range(warmup_batches):
+             for j in range(len(schedule.dev_queues[i].stages)):
+                 schedule.dev_queues[i].add_operation(
+                     schedule.get_op(fwd_batch_id, schedule.dev_queues[i].stages[j], "forward")
+                 )
+             fwd_batch_id += 1
+
+         for _ in range(steady_batches):
+             for j in range(len(schedule.dev_queues[i].stages)):
+                 schedule.dev_queues[i].add_operation(
+                     schedule.get_op(fwd_batch_id, schedule.dev_queues[i].stages[j], "forward")
+                 )
+             fwd_batch_id += 1
+             for j in range(len(schedule.dev_queues[i].stages) - 1, -1, -1):
+                 schedule.dev_queues[i].add_operation(
+                     schedule.get_op(bwd_batch_id, schedule.dev_queues[i].stages[j], "backward")
+                 )
+             bwd_batch_id += 1
+
+         for _ in range(cooldown_batches):
+             for j in range(len(schedule.dev_queues[i].stages) - 1, -1, -1):
+                 schedule.dev_queues[i].add_operation(
+                     schedule.get_op(bwd_batch_id, schedule.dev_queues[i].stages[j], "backward")
+                 )
+             bwd_batch_id += 1
+
+     return schedule
+
+
+ # Some codes are copied from Megatron-LM
+ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
+     schedule = Schedule(config)
+
+     def get_pp_rank_microbatches(
+         num_microbatches,
+         num_devices,
+         device_id,
+         num_stages_per_device,
+         microbatch_group_size_per_vp_stage,
+     ):
+         """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
+         total_num_microbatches = num_microbatches * num_stages_per_device
+         are_all_microbatches_in_warmup = False
+
+         if num_devices > 1:
+             if num_stages_per_device is None:
+                 # forward_backward_pipelining_without_interleaving
+                 num_warmup_microbatches = num_devices - device_id - 1
+             else:
+                 # forward_backward_pipelining_with_interleaving
+                 # Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
+                 # all workers, followed by more microbatches after depending on
+                 # stage ID (more forward passes for earlier stages, later stages can
+                 # immediately start with 1F1B).
+                 num_warmup_microbatches = (num_devices - device_id - 1) * 2
+                 num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
+         else:
+             # forward_backward_no_pipelining
+             num_warmup_microbatches = 1
+
+         if num_warmup_microbatches >= total_num_microbatches:
+             num_warmup_microbatches = total_num_microbatches
+             are_all_microbatches_in_warmup = True
+         num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
+
+         return (
+             total_num_microbatches,
+             are_all_microbatches_in_warmup,
+             num_warmup_microbatches,
+             num_microbatches_remaining,
+         )
+
+     def get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
+         """Get the schedule table for PP scheduling.
+
+         Create a tunable schedule lookup table.
+         The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
+         For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
+         virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
+         microbatch_id         | 0 1 2 0 1 2 3 4 3 4
+         model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+         """
+         schedule_table = []
+         for min_microbatch_id_in_group in range(
+             0, num_microbatches, microbatch_group_size_per_vp_stage
+         ):
+             if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
+                 # Construct schedule for the last microbatch group
+                 schedule_table.extend(
+                     [
+                         (microbatch_id, model_chunk_id)
+                         for model_chunk_id in range(num_model_chunks)
+                         for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
+                     ]
+                 )
+             else:
+                 # Construct schedule for other microbatch groups
+                 schedule_table.extend(
+                     [
+                         (microbatch_id, model_chunk_id)
+                         for model_chunk_id in range(num_model_chunks)
+                         for microbatch_id in range(
+                             min_microbatch_id_in_group,
+                             min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
+                         )
+                     ]
+                 )
+         return schedule_table
+
+     def convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
+         """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
+         order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
+         virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
+         microbatch_id         | 0 1 2 0 1 2 3 4 3 4
+         model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+
+         Then the forward backward separated order is:
+         forward  | 1 1 1 2 2 2 1 1 2 2
+         backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
+
+         If num_warmup_microbatches is 5, the output order is:
+         1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
+         """
+         _, model_chunk_id_table = zip(*schedule_table)
+         forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
+         backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
+         order = forward_order[:num_warmup_microbatches]
+         for i in range(num_warmup_microbatches, len(forward_order)):
+             order.append(forward_order[i])
+             order.append(backward_order[i - num_warmup_microbatches])
+         if num_warmup_microbatches > 0:
+             order.extend(backward_order[-num_warmup_microbatches:])
+         return order
+
+     for device_id in range(config.num_devices):
+         microbatch_group_size_per_vp_stage = config.num_devices
+         total_num_microbatches, are_all_microbatches_in_warmup, num_warmup_microbatches, num_microbatches_remaining = get_pp_rank_microbatches(
+             config.num_batches,
+             config.num_devices,
+             device_id,
+             config.num_stages_per_device,
+             microbatch_group_size_per_vp_stage,
+         )
+
+         schedule_table = get_schedule_table(
+             config.num_batches,
+             config.num_stages_per_device,
+             microbatch_group_size_per_vp_stage,
+         )
+
+         order = convert_schedule_table_to_order(
+             num_warmup_microbatches,
+             num_model_chunks=config.num_stages_per_device,
+             schedule_table=schedule_table,
+         )
+
+         cur_stage_microbatch_id = {}
+         for i in range(1, config.num_stages_per_device + 1):
+             cur_stage_microbatch_id[i] = 0
+             cur_stage_microbatch_id[-i] = 0
+         for order_item in order:
+             stage_id = schedule.dev_queues[device_id].stages[abs(order_item) - 1]
+
+             if order_item > 0:
+                 op_type = "forward"
+                 micro_batch_id = cur_stage_microbatch_id[order_item]
+                 cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+             elif order_item < 0:
+                 op_type = "backward"
+                 micro_batch_id = cur_stage_microbatch_id[order_item]
+                 cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+             else:
+                 raise ValueError(f"Invalid order item: {order_item}")
+             schedule.dev_queues[device_id].add_operation(
+                 schedule.get_op(micro_batch_id, stage_id, op_type)
+             )
+     return schedule
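The PP2 N3M5 VP2 example documented in the docstrings above can be reproduced with condensed, standalone versions of the two helpers (names shortened for the sketch; the logic mirrors the nested functions in the diff rather than importing them):

```python
# Condensed sketch of get_schedule_table() / convert_schedule_table_to_order().
def schedule_table(num_microbatches: int, num_chunks: int, group_size: int):
    # Map virtual_microbatch_id -> (microbatch_id, model_chunk_id).
    table = []
    for lo in range(0, num_microbatches, group_size):
        hi = num_microbatches if lo + group_size >= num_microbatches else lo + group_size
        table.extend((mb, chunk) for chunk in range(num_chunks) for mb in range(lo, hi))
    return table

def to_order(num_warmup: int, num_chunks: int, table):
    # Positive entries are forward chunks (1-based), negative entries backward.
    chunks = [chunk for _, chunk in table]
    fwd = [c + 1 for c in chunks]
    bwd = [c - num_chunks for c in chunks]
    order = fwd[:num_warmup]
    for i in range(num_warmup, len(fwd)):
        order += [fwd[i], bwd[i - num_warmup]]
    if num_warmup > 0:
        order += bwd[-num_warmup:]
    return order

# PP2 N3M5 with VP2: 5 microbatches, group size 3, 2 model chunks.
table = schedule_table(5, 2, 3)
print([mb for mb, _ in table])  # [0, 1, 2, 0, 1, 2, 3, 4, 3, 4]
print([c for _, c in table])    # [0, 0, 0, 1, 1, 1, 0, 0, 1, 1]
print(to_order(5, 2, table))    # [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
```

The three printed rows match the `microbatch_id`, `model_chunk_id`, and warmup-5 order lines given in the docstrings.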
dash_visualizer.py → src/visualizer.py RENAMED
@@ -1,41 +1,86 @@
  import dash
  from dash import dcc, html
- from dash.dependencies import Input, Output, State
  import plotly.graph_objects as go
- import numpy as np
- from typing import List, Dict, Literal
  from tqdm import tqdm
- import time


- def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_progress=True):
      """
      Create a Plotly figure for pipeline parallelism scheduling.

      Args:
-         schedule: Dictionary mapping device IDs to lists of tasks.
-             Each task is a dictionary with keys:
-             - 'type': 'forward', 'backward', or 'optimizer'
-             - 'batch': batch number
-             - 'start_time': start time of the task
-             - 'duration': duration of the task
          max_time: Optional maximum time to display
          show_progress: Whether to show a progress bar
      """
-     # Colors for task types
-     forward_color = "royalblue"
-     backward_color = "sandybrown"
-     optimizer_color = "#FFEFCF"
      empty_color = "whitesmoke"

-     # Find the number of stages (devices)
-     num_stages = len(schedule)

      # Find the maximum time in the schedule if not provided
      if max_time is None:
          max_time = 0
-         for device in schedule:
-             for task in schedule[device]:
                  end_time = task["start_time"] + task["duration"]
                  if end_time > max_time:
                      max_time = end_time
@@ -44,56 +89,51 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
      fig = go.Figure()

      # Initialize progress tracking
-     total_tasks = sum(len(tasks) for tasks in schedule.values())
      tasks_processed = 0

      if show_progress:
-         progress_bar = tqdm(total=total_tasks + num_stages + 3, desc="Creating visualization")

-     # Add background for empty cells
-     for device_idx in range(num_stages):
-         device_idx_reversed = num_stages - device_idx - 1  # Reverse for plotting
-         fig.add_trace(go.Scatter(
-             x=[0, max_time],
-             y=[device_idx_reversed, device_idx_reversed],
-             mode='lines',
-             line=dict(color='lightgray', width=0.5),
-             showlegend=False,
-             hoverinfo='none'
-         ))
-         if show_progress:
-             progress_bar.update(1)

      # Add rectangles for each task
-     for device_idx, device in enumerate(schedule):
-         device_idx_reversed = num_stages - device_idx - 1

-         for task in schedule[device]:
              # Determine task color and text color
              if task["type"] == "forward":
-                 color = forward_color
                  text_color = "white"
                  name = "Forward"
              elif task["type"] == "backward":
-                 color = backward_color
                  text_color = "black"
                  name = "Backward"
-             else:  # optimizer or any other type
-                 color = optimizer_color
                  text_color = "black"
-                 name = "Optimizer step"
-
              # Add rectangle for the task
              start_time = task["start_time"]
              duration = task["duration"]

              # Create rectangle using shape
              fig.add_shape(
                  type="rect",
                  x0=start_time,
-                 y0=device_idx_reversed - 0.4,
                  x1=start_time + duration,
-                 y1=device_idx_reversed + 0.4,
                  line=dict(color="black", width=0.5),
                  fillcolor=color,
                  layer="above",
@@ -102,12 +142,23 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
              # Add batch number text
              fig.add_annotation(
                  x=start_time + duration / 2,
-                 y=device_idx_reversed,
-                 text=str(task["batch"]),
                  showarrow=False,
-                 font=dict(color=text_color, size=10, family="Arial, bold"),
              )

              # Update progress
              if show_progress:
                  tasks_processed += 1
@@ -115,9 +166,8 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_

      # Add custom legend
      legend_items = [
-         dict(name="Forward", color=forward_color),
-         dict(name="Backward", color=backward_color),
-         dict(name="Optimizer step", color=optimizer_color)
      ]

      for i, item in enumerate(legend_items):
@@ -133,77 +183,98 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
          progress_bar.update(1)

      # Set axis properties
-     device_labels = [f"Device {i+1}" for i in range(num_stages)]
-     device_labels.reverse()  # Reverse to put Device 1 at the top

      fig.update_layout(
-         xaxis=dict(
-             showticklabels=False,
-             showgrid=False,
-             zeroline=False,
-             title="Time →",
-             range=[0, max_time + 0.5]
-         ),
          yaxis=dict(
              tickmode="array",
-             tickvals=list(range(num_stages)),
              ticktext=device_labels,
              showgrid=False,
              zeroline=False,
-             range=[-0.5, num_stages - 0.5]
          ),
-         margin=dict(l=50, r=50, t=50, b=50),
          plot_bgcolor="white",
          legend=dict(
              orientation="h",
-             yanchor="bottom",
-             y=-0.2,
              xanchor="center",
              x=0.5
-         )
      )

      if show_progress:
-         progress_bar.update(1)  # Final update for layout
          progress_bar.close()

      return fig


- def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
      """
-     Create a Dash app for interactive visualization of pipeline scheduling.
-
      Args:
-         schedule: Dictionary mapping device IDs to lists of tasks
-         schedule_type: Type of scheduling algorithm used
      """
-     app = dash.Dash(__name__, title="Pipeline Parallelism Visualization")

      app.layout = html.Div([
-         html.H1(f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
-                 style={'textAlign': 'center'}),
-
-         html.Div(id="loading-container", children=[
-             dcc.Loading(
-                 id="loading-graph",
-                 type="circle",
-                 children=[
-                     html.Div(id="graph-container", children=[
-                         dcc.Graph(
-                             id='pipeline-graph',
-                             style={'height': '600px'}
-                         )
-                     ])
-                 ]
-             )
-         ]),

          html.Div([
-             html.Button("Download PNG", id="btn-download",
-                         style={'margin': '10px'}),
-             dcc.Download(id="download-image")
-         ], style={'textAlign': 'center', 'marginTop': '20px'})
      ])

      @app.callback(
@@ -213,98 +284,65 @@ def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
      )
      def load_graph(_):
          # Create the figure when the app loads
-         return create_pipeline_figure(schedule, show_progress=True)
-
      @app.callback(
          Output("download-image", "data"),
          Input("btn-download", "n_clicks"),
          prevent_initial_call=True,
      )
      def download_image(n_clicks):
-         # Show progress in terminal for downloads
-         fig = create_pipeline_figure(schedule, show_progress=True)
-         img_bytes = fig.to_image(format="png", scale=3)
          return dict(
-             content=img_bytes,
-             filename="pipeline_visualization.png"
          )

      return app


  def visualize_pipeline_parallelism_dash(
-     schedule: Dict[int, List[Dict]],
-     schedule_type: Literal["simple", "1f1b"] = "1f1b",
      port: int = 8050,
      debug: bool = False
  ):
      """
-     Create an interactive Dash visualization for pipeline parallelism scheduling.
-
      Args:
-         schedule: Dictionary mapping device IDs to lists of tasks
-         schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
-         port: Port number to run the Dash app
-         debug: Whether to run the app in debug mode
      """
-     app = create_dash_app(schedule, schedule_type)
      print(f"Starting Dash app on http://localhost:{port}/")
      app.run_server(debug=debug, port=port)


  def save_pipeline_visualization_plotly(
-     schedule: Dict[int, List[Dict]],
-     schedule_type: Literal["simple", "1f1b"] = "1f1b",
      output_file: str = "pipeline_visualization_plotly.png",
  ):
      """
-     Save a static Plotly visualization of pipeline parallelism scheduling.
-
      Args:
-         schedule: Dictionary mapping device IDs to lists of tasks
-         schedule_type: Type of scheduling algorithm used
-         output_file: Path to save the visualization
      """
-     print(f"Creating visualization for {len(schedule)} devices...")
-     fig = create_pipeline_figure(schedule, show_progress=True)
-
-     # Update layout for static image
-     fig.update_layout(
-         title=f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
-         title_x=0.5
-     )

-     print(f"Saving image to {output_file}...")
-     # Save as image
-     fig.write_image(output_file, scale=3)
      print(f"Visualization saved to {output_file}")

-
- if __name__ == "__main__":
-     # Example usage
-     import argparse
-     from pipeline import create_1f1b_schedule
-
-     parser = argparse.ArgumentParser(description="Pipeline Parallelism Visualizer")
-     parser.add_argument("--num-stages", type=int, default=4, help="Number of pipeline stages")
-     parser.add_argument("--num-batches", type=int, default=8, help="Number of microbatches")
-     parser.add_argument("--interactive", action="store_true", help="Run interactive Dash app")
-     parser.add_argument("--port", type=int, default=8050, help="Port for Dash app")
-     parser.add_argument("--output", type=str, default="pipeline_visualization_plotly.png", help="Output file for static image")
-     args = parser.parse_args()
-
-     # Create an example schedule
-     forward_times = [1.0] * args.num_stages
-     backward_times = [2.0] * args.num_stages
-
-     schedule = create_1f1b_schedule(
-         num_stages=args.num_stages,
-         num_batches=args.num_batches,
-         forward_times=forward_times,
-         backward_times=backward_times,
-     )
-
-     if args.interactive:
-         visualize_pipeline_parallelism_dash(schedule, port=args.port)
-     else:
-         save_pipeline_visualization_plotly(schedule, output_file=args.output)
 
1
  import dash
2
  from dash import dcc, html
3
+ from dash.dependencies import Input, Output
4
  import plotly.graph_objects as go
5
+ import argparse
6
+ from typing import List, Dict, Literal, Optional
7
  from tqdm import tqdm
8
+ import base64
9
 
10
+ from src.execution_model import Schedule
11
 
12
+
13
+ def convert_schedule_to_visualization_format(schedule: Schedule):
14
+ """
15
+ Converts a Schedule object to the format needed for visualization.
16
+
17
+ Returns:
18
+ Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
19
+ """
20
+ # Make sure all operations have start and end times
21
+ for op in schedule.ops.values():
22
+ if op.start_time is None or op.end_time is None:
23
+ raise ValueError("Operations must have start and end times. Run ScheduleExecutor.execute() first.")
24
+
25
+ visualization_data = {}
26
+
27
+ # Organize operations by device
28
+ for device_id, device_queue in enumerate(schedule.dev_queues):
29
+ visualization_data[device_id] = []
30
+
31
+ for op in device_queue.ops:
32
+ visualization_data[device_id].append({
33
+ "type": op.op_type,
34
+ "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
35
+ "stage": op.stage_id,
36
+ "start_time": op.start_time,
37
+ "duration": op.end_time - op.start_time
38
+ })
39
+
40
+ return visualization_data
41
+
42
+
43
+ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
44
  """
45
  Create a Plotly figure for pipeline parallelism scheduling.
46
 
47
  Args:
48
+ schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
 
 
 
 
 
          max_time: Optional maximum time to display
          show_progress: Whether to show a progress bar
      """
+     # Find the number of devices
+     num_devices = len(schedule_data)
+
      empty_color = "whitesmoke"
+     # Colors for task types
+     def get_color(op_type: str, stage_id: int):
+         # Base colors
+         forward_base_color = "royalblue"
+         backward_base_color = "lightgreen"  # Changed from sandybrown to match the reference visualization
+
+         virtual_stage = stage_id // num_devices
+
+         if op_type == "forward":
+             if virtual_stage == 0:
+                 return forward_base_color
+             else:
+                 # Lighter shade for virtual_stage > 0
+                 return "lightskyblue"
+         elif op_type == "backward":
+             if virtual_stage == 0:
+                 return backward_base_color
+             else:
+                 # Lighter shade for virtual_stage > 0
+                 return "lightseagreen"
+         else:
+             raise ValueError(f"Invalid operation type: {op_type}")
 
      # Find the maximum time in the schedule if not provided
      if max_time is None:
          max_time = 0
+         for device in schedule_data:
+             for task in schedule_data[device]:
                  end_time = task["start_time"] + task["duration"]
                  if end_time > max_time:
                      max_time = end_time
 
      fig = go.Figure()
 
      # Initialize progress tracking
+     total_tasks = sum(len(tasks) for tasks in schedule_data.values())
      tasks_processed = 0
 
      if show_progress:
+         progress_bar = tqdm(total=total_tasks + num_devices + 3, desc="Creating visualization")
 
+     # Create a custom y-axis with no gaps between devices
+     y_spacing = 1.0  # Use 1.0 for no gaps
 
 
 
 
 
 
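The `get_color` helper above keys off the *virtual* stage, which is what makes the new VPP (interleaved virtual pipeline) support visible: with `num_devices` devices, `stage_id // num_devices` gives the virtual stage index, so stages from the second pass over a device get the lighter shade. A standalone sketch of that mapping (`num_devices` passed explicitly instead of captured from the enclosing scope):

```python
def get_color(op_type: str, stage_id: int, num_devices: int) -> str:
    # Mirrors the diff's helper: virtual stage 0 gets the base color,
    # later virtual stages get a lighter shade.
    virtual_stage = stage_id // num_devices
    if op_type == "forward":
        return "royalblue" if virtual_stage == 0 else "lightskyblue"
    elif op_type == "backward":
        return "lightgreen" if virtual_stage == 0 else "lightseagreen"
    raise ValueError(f"Invalid operation type: {op_type}")

# With 4 devices, stages 0-3 are virtual stage 0; stages 4-7 are virtual stage 1.
print(get_color("forward", 2, 4))  # → "royalblue"
print(get_color("forward", 6, 4))  # → "lightskyblue"
```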
 
 
 
 
 
 
      # Add rectangles for each task
+     for device_idx, device in enumerate(schedule_data):
+         device_idx_reversed = num_devices - device_idx - 1
+
+         # Sort tasks by start time to ensure correct rendering
+         sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
 
+         for task in sorted_tasks:
              # Determine task color and text color
              if task["type"] == "forward":
+                 color = get_color(task["type"], task["stage"])
                  text_color = "white"
                  name = "Forward"
              elif task["type"] == "backward":
+                 color = get_color(task["type"], task["stage"])
                  text_color = "black"
                  name = "Backward"
+             else:
+                 color = empty_color
                  text_color = "black"
+                 name = "Unknown"
+
              # Add rectangle for the task
              start_time = task["start_time"]
              duration = task["duration"]
 
+             # Calculate y positions with no gaps
+             y_pos = device_idx_reversed * y_spacing
+
              # Create rectangle using shape
              fig.add_shape(
                  type="rect",
                  x0=start_time,
+                 y0=y_pos - 0.5,
                  x1=start_time + duration,
+                 y1=y_pos + 0.5,
                  line=dict(color="black", width=0.5),
                  fillcolor=color,
                  layer="above",
 
              # Add batch number text
              fig.add_annotation(
                  x=start_time + duration / 2,
+                 y=y_pos,
+                 text=f"{task['batch']}",  # Only show batch ID
                  showarrow=False,
+                 font=dict(color=text_color, size=12, family="Arial, bold"),  # Increased font size
              )
 
+             # Add hover data with additional details
+             fig.add_trace(go.Scatter(
+                 x=[start_time + duration / 2],
+                 y=[y_pos],
+                 mode='markers',
+                 marker=dict(opacity=0),  # Invisible marker
+                 hoverinfo='text',
+                 text=f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}",
+                 showlegend=False
+             ))
+
              # Update progress
              if show_progress:
                  tasks_processed += 1
 
 
      # Add custom legend
      legend_items = [
+         dict(name="Forward", color=get_color("forward", 0)),
+         dict(name="Backward", color=get_color("backward", 0)),
      ]
 
      for i, item in enumerate(legend_items):
 
          progress_bar.update(1)
 
      # Set axis properties
+     device_labels = [f"Device {i}" for i in range(num_devices)]
+     device_labels.reverse()  # Reverse to put Device 0 at the top
+
+     # Calculate tick positions with no gaps
+     tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
 
+     # Adjust the range to ensure there are no empty spaces at the end
+     x_end = max_time * 1.05  # Add a small margin
+
      fig.update_layout(
 
 
 
 
 
 
 
          yaxis=dict(
              tickmode="array",
+             tickvals=tick_positions,
              ticktext=device_labels,
              showgrid=False,
              zeroline=False,
 
          ),
+         margin=dict(l=50, r=20, t=40, b=40),
          plot_bgcolor="white",
+         title=dict(
+             text="Pipeline Parallelism Schedule",
+             x=0.5,
+             y=0.98,  # Move title position closer to the top
+             font=dict(size=20)
+         ),
          legend=dict(
              orientation="h",
+             yanchor="top",
+             y=-0.1,  # Position below the plot
              xanchor="center",
              x=0.5
+         ),
+         width=1600,
+         height=400,  # Reduce height to make the visualization more compact
+         bargap=0,
+         bargroupgap=0,
      )
 
      if show_progress:
+         progress_bar.update(1)
          progress_bar.close()
 
      return fig
 
 
+ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
      """
+     Create a Dash app to visualize the pipeline schedule.
+
      Args:
+         schedule: Schedule object to visualize
+         schedule_type: Type of schedule ("1f1b" or other)
      """
+     # Convert schedule to visualization format
+     schedule_data = convert_schedule_to_visualization_format(schedule)
+
+     # Create the app
+     app = dash.Dash(__name__, title=f"Pipeline Parallelism Visualizer - {schedule_type}")
 
      app.layout = html.Div([
+         html.H1(f"Pipeline Parallelism Visualizer - {schedule_type}", style={'textAlign': 'center'}),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
          html.Div([
+             html.Div([
+                 html.H3("Schedule Configuration:"),
+                 html.Ul([
+                     html.Li(f"Number of devices: {schedule.config.num_devices}"),
+                     html.Li(f"Number of stages: {schedule.config.num_stages}"),
+                     html.Li(f"Number of batches: {schedule.config.num_batches}"),
+                 ]),
+             ], className="config-section"),
+
+             html.Button("Download Image", id="btn-download",
+                 style={
+                     'marginTop': '20px',
+                     'padding': '10px',
+                     'backgroundColor': '#007BFF',
+                     'color': 'white',
+                     'border': 'none',
+                     'borderRadius': '5px',
+                     'cursor': 'pointer'
+                 }),
+
+             dcc.Download(id="download-image"),
+         ], style={'margin': '20px'}),
+
+         html.Div(id="graph-container", children=[]),
+
+         dcc.Graph(
+             id="pipeline-graph",
+             config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
+         ),
      ])
 
      @app.callback(
 
      )
      def load_graph(_):
          # Create the figure when the app loads
+         return create_pipeline_figure(schedule_data, show_progress=True)
+
      @app.callback(
          Output("download-image", "data"),
          Input("btn-download", "n_clicks"),
          prevent_initial_call=True,
      )
      def download_image(n_clicks):
+         # Generate the figure for download
+         fig = create_pipeline_figure(schedule_data, show_progress=True)
+
+         # Convert to base64 image
+         img_bytes = fig.to_image(format="png", width=1600, height=1000, scale=2)
+         img_base64 = base64.b64encode(img_bytes).decode('ascii')
+
+         # Return the download data
          return dict(
+             content=img_base64,
+             filename=f"pipeline_visualization_{schedule_type}.png",
+             type="image/png",
+             base64=True
          )
 
      return app
 
 
  def visualize_pipeline_parallelism_dash(
+     schedule: Schedule,
 
      port: int = 8050,
      debug: bool = False
  ):
      """
+     Launch a Dash app to visualize the pipeline schedule interactively.
+
      Args:
+         schedule: Schedule object to visualize
+         port: Port to run the Dash app on
+         debug: Whether to run the Dash app in debug mode
 
      """
+     app = create_dash_app(schedule)
      print(f"Starting Dash app on http://localhost:{port}/")
      app.run_server(debug=debug, port=port)
 
 
  def save_pipeline_visualization_plotly(
+     schedule: Schedule,
 
      output_file: str = "pipeline_visualization_plotly.png",
  ):
      """
+     Save a static image of the pipeline schedule visualization.
+
      Args:
+         schedule: Schedule object to visualize
+         output_file: Path to save the image to
 
      """
+     schedule_data = convert_schedule_to_visualization_format(schedule)
+     fig = create_pipeline_figure(schedule_data, show_progress=True)
 
 
 
 
 
 
 
+     print(f"Saving visualization to {output_file}...")
+     fig.write_image(output_file, width=1600, height=400, scale=2)
      print(f"Visualization saved to {output_file}")
 
 
 
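The download callback above ships the rendered PNG through `dcc.Download` as a base64 payload; the encoding step itself is plain standard library. A small sketch of that payload (the byte string here is a hypothetical stand-in for `fig.to_image(...)` output, which always begins with the PNG signature):

```python
import base64

# Stand-in for fig.to_image(...) output: raw PNG bytes start with this signature
img_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16

# Same dict shape the download_image callback returns to dcc.Download
payload = dict(
    content=base64.b64encode(img_bytes).decode("ascii"),
    filename="pipeline_visualization_1f1b.png",
    type="image/png",
    base64=True,
)
print(payload["content"][:8])  # → "iVBORw0K"
```

The `base64=True` flag tells `dcc.Download` to decode `content` back to bytes before handing the file to the browser.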
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
visualizer.py DELETED
@@ -1,141 +0,0 @@
- import matplotlib.pyplot as plt
- import numpy as np
- from matplotlib.patches import Rectangle
- from typing import List, Dict, Literal
-
-
- def visualize_pipeline_parallelism(
-     schedule: Dict[int, List[Dict]],
-     schedule_type: Literal["simple", "1f1b"] = "1f1b",
-     output_file: str = "pipeline_visualization.png",
- ):
-     """
-     Visualize pipeline parallelism scheduling.
-
-     Args:
-         schedule: Dictionary mapping device IDs to lists of tasks.
-             Each task is a dictionary with keys:
-             - 'type': 'forward', 'backward', or 'optimizer'
-             - 'batch': batch number
-             - 'start_time': start time of the task
-             - 'duration': duration of the task
-         schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
-         output_file: Path to save the visualization
-     """
-     # Colors for task types
-     forward_color = "royalblue"
-     backward_color = "sandybrown"  # Changed to match the reference image
-     optimizer_color = "#FFEFCF"  # Light beige for optimizer steps
-     empty_color = "whitesmoke"  # Very light gray for empty cells
-
-     # Find the number of stages (devices)
-     num_stages = len(schedule)
-
-     # Find the maximum time in the schedule
-     max_time = 0
-     for device in schedule:
-         for task in schedule[device]:
-             end_time = task["start_time"] + task["duration"]
-             if end_time > max_time:
-                 max_time = end_time
-
-     # Create figure and axis
-     fig, ax = plt.subplots(figsize=(15, 4))
-
-     # Create an empty grid with light gray color
-     for device_idx in range(num_stages):
-         device_idx_reversed = num_stages - device_idx - 1  # Reverse the device index for plotting
-         for t in range(int(max_time) + 1):
-             rect = Rectangle(
-                 (t, device_idx_reversed),
-                 1.0,
-                 1.0,
-                 edgecolor="lightgray",
-                 facecolor=empty_color,
-                 linewidth=0.5,
-             )
-             ax.add_patch(rect)
-
-     # Plot the schedule
-     for device_idx, device in enumerate(schedule):
-         device_idx_reversed = num_stages - device_idx - 1  # Reverse the device index for plotting
-         for task in schedule[device]:
-             # Determine task color
-             if task["type"] == "forward":
-                 color = forward_color
-                 text_color = "white"
-             elif task["type"] == "backward":
-                 color = backward_color
-                 text_color = "black"
-             else:  # optimizer or any other type
-                 color = optimizer_color
-                 text_color = "black"
-
-             rect = Rectangle(
-                 (task["start_time"], device_idx_reversed),
-                 task["duration"],
-                 1.0,
-                 edgecolor="black",
-                 facecolor=color,
-                 linewidth=0.5,
-             )
-             ax.add_patch(rect)
-
-             # Add text (batch number)
-             ax.text(
-                 task["start_time"] + task["duration"] / 2,
-                 device_idx_reversed + 0.5,
-                 str(task["batch"]),
-                 ha="center",
-                 va="center",
-                 fontsize=10,
-                 fontweight="bold",
-                 color=text_color,
-             )
-
-     # Set axis limits and labels
-     ax.set_xlim(0, max_time + 0.5)
-     ax.set_ylim(-0.5, num_stages + 0.5)
-     ax.set_yticks(np.arange(num_stages) + 0.5)
-
-     # Reverse the order: Device 1 at the top, highest number at the bottom
-     device_labels = [f"Device {i+1}" for i in range(num_stages)]
-     device_labels.reverse()  # Reverse to put Device 1 at the top
-     ax.set_yticklabels(device_labels)
-
-     # Add "Time" label and arrow at the bottom
-     arrow_y = -0.4
-     ax.text(0.5, arrow_y, "Time", ha="right", va="center", fontsize=10)
-     ax.annotate("", xy=(2, arrow_y), xytext=(1, arrow_y),
-                 arrowprops=dict(arrowstyle="->", lw=1))
-
-     # Remove the x-axis ticks
-     ax.set_xticks([])
-
-     # Remove the outer frame/border
-     for spine in ax.spines.values():
-         spine.set_visible(False)
-
-     # Add a legend - using 3 parts like in the reference image
-     forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
-     backward_patch = Rectangle((0, 0), 1, 1, facecolor=backward_color)
-     optimizer_patch = Rectangle((0, 0), 1, 1, facecolor=optimizer_color)
-
-     legend = ax.legend(
-         [forward_patch, backward_patch, optimizer_patch],
-         ["Forward", "Backward", "Optimizer step"],
-         loc="upper center",
-         bbox_to_anchor=(0.5, -0.15),
-         ncol=3,
-         frameon=False,
-     )
-
-     # Turn off grid
-     ax.grid(False)
-
-     # Save the figure
-     plt.tight_layout()
-     plt.savefig(output_file, dpi=300, bbox_inches="tight")
-     plt.close()
-
-     print(f"Visualization saved to {output_file}")