Spaces:
Running
Running
Add VPP support and refactor project.
Browse files- .gitignore +9 -77
- LICENSE +21 -0
- README-dash-visualizer.md +0 -91
- README.md +80 -62
- conf/config.yaml +22 -0
- configs/standard.json +0 -8
- main.py +62 -0
- pipeline.py +0 -491
- pipeline_1f1b.png +0 -3
- pyproject.toml +67 -0
- requirements-dash.txt +0 -5
- src/__init__.py +3 -0
- src/execution_model.py +219 -0
- src/strategies.py +192 -0
- dash_visualizer.py → src/visualizer.py +195 -157
- visualizer.py +0 -141
.gitignore
CHANGED
@@ -1,78 +1,10 @@
|
|
1 |
# Python
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
eggs/
|
12 |
-
.eggs/
|
13 |
-
lib/
|
14 |
-
lib64/
|
15 |
-
parts/
|
16 |
-
sdist/
|
17 |
-
var/
|
18 |
-
wheels/
|
19 |
-
*.egg-info/
|
20 |
-
.installed.cfg
|
21 |
-
*.egg
|
22 |
-
|
23 |
-
# Virtual Environment
|
24 |
-
venv/
|
25 |
-
env/
|
26 |
-
ENV/
|
27 |
-
.env
|
28 |
-
|
29 |
-
# IDE specific files
|
30 |
-
.idea/
|
31 |
-
.vscode/
|
32 |
-
*.swp
|
33 |
-
*.swo
|
34 |
-
.DS_Store
|
35 |
-
|
36 |
-
# Jupyter Notebook
|
37 |
-
.ipynb_checkpoints
|
38 |
-
|
39 |
-
# Distribution / packaging
|
40 |
-
.Python
|
41 |
-
env/
|
42 |
-
build/
|
43 |
-
develop-eggs/
|
44 |
-
dist/
|
45 |
-
downloads/
|
46 |
-
eggs/
|
47 |
-
.eggs/
|
48 |
-
lib/
|
49 |
-
lib64/
|
50 |
-
parts/
|
51 |
-
sdist/
|
52 |
-
var/
|
53 |
-
wheels/
|
54 |
-
*.egg-info/
|
55 |
-
.installed.cfg
|
56 |
-
*.egg
|
57 |
-
|
58 |
-
# Unit test / coverage reports
|
59 |
-
htmlcov/
|
60 |
-
.tox/
|
61 |
-
.coverage
|
62 |
-
.coverage.*
|
63 |
-
.cache
|
64 |
-
nosetests.xml
|
65 |
-
coverage.xml
|
66 |
-
*.cover
|
67 |
-
.hypothesis/
|
68 |
-
|
69 |
-
# Pipeline visualization outputs
|
70 |
-
*.png
|
71 |
-
*.jpg
|
72 |
-
*.jpeg
|
73 |
-
*.pdf
|
74 |
-
*.svg
|
75 |
-
|
76 |
-
# Local configuration
|
77 |
-
config.ini
|
78 |
-
secrets.json
|
|
|
1 |
# Python
|
2 |
+
./venv
|
3 |
+
uv.lock
|
4 |
+
outputs/
|
5 |
+
|
6 |
+
# Uncomment below if you want to include these files
|
7 |
+
# !assets/*.png
|
8 |
+
# !assets/*.jpg
|
9 |
+
# !docs/*.png
|
10 |
+
# !docs/*.jpg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README-dash-visualizer.md
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
# Pipeline Parallelism Dash Visualizer
|
2 |
-
|
3 |
-
This is an interactive Dash-based visualizer for pipeline parallelism scheduling, complementing the existing Matplotlib-based visualization.
|
4 |
-
|
5 |
-
## Features
|
6 |
-
|
7 |
-
- **Static image generation** similar to the Matplotlib version
|
8 |
-
- **Interactive web-based visualization** with Dash
|
9 |
-
- **Download functionality** to save the visualization as PNG
|
10 |
-
- **Progress indication** during figure creation and image generation
|
11 |
-
- **Compatible API** with the existing visualizer
|
12 |
-
|
13 |
-
## Installation
|
14 |
-
|
15 |
-
Install the required dependencies:
|
16 |
-
|
17 |
-
```bash
|
18 |
-
pip install -r requirements-dash.txt
|
19 |
-
```
|
20 |
-
|
21 |
-
## Usage
|
22 |
-
|
23 |
-
### From Python
|
24 |
-
|
25 |
-
```python
|
26 |
-
from pipeline import create_1f1b_schedule
|
27 |
-
from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
|
28 |
-
|
29 |
-
# Create a schedule
|
30 |
-
schedule = create_1f1b_schedule(
|
31 |
-
num_stages=4,
|
32 |
-
num_batches=8,
|
33 |
-
forward_times=[1.0, 1.0, 1.0, 1.0],
|
34 |
-
backward_times=[2.0, 2.0, 2.0, 2.0],
|
35 |
-
)
|
36 |
-
|
37 |
-
# Generate a static image
|
38 |
-
save_pipeline_visualization_plotly(
|
39 |
-
schedule=schedule,
|
40 |
-
schedule_type="1f1b",
|
41 |
-
output_file="pipeline_plotly.png"
|
42 |
-
)
|
43 |
-
|
44 |
-
# OR launch an interactive Dash app
|
45 |
-
visualize_pipeline_parallelism_dash(
|
46 |
-
schedule=schedule,
|
47 |
-
schedule_type="1f1b",
|
48 |
-
port=8050,
|
49 |
-
debug=False
|
50 |
-
)
|
51 |
-
```
|
52 |
-
|
53 |
-
### Using the Command Line
|
54 |
-
|
55 |
-
You can use the updated command line interface:
|
56 |
-
|
57 |
-
```bash
|
58 |
-
# Generate a static image with Dash/Plotly
|
59 |
-
python pipeline.py --visualizer dash --output-file pipeline_viz.png
|
60 |
-
|
61 |
-
# Launch an interactive Dash app
|
62 |
-
python pipeline.py --visualizer dash-interactive
|
63 |
-
|
64 |
-
# Use the original Matplotlib visualizer
|
65 |
-
python pipeline.py --visualizer matplotlib
|
66 |
-
```
|
67 |
-
|
68 |
-
You can also use the dash_visualizer.py script directly for testing:
|
69 |
-
|
70 |
-
```bash
|
71 |
-
# Generate a static image
|
72 |
-
python dash_visualizer.py --output test_viz.png
|
73 |
-
|
74 |
-
# Launch an interactive app
|
75 |
-
python dash_visualizer.py --interactive
|
76 |
-
```
|
77 |
-
|
78 |
-
## Differences from Matplotlib Visualizer
|
79 |
-
|
80 |
-
The Dash-based visualizer provides all the same visual elements as the Matplotlib version:
|
81 |
-
- Color-coded rectangles for forward, backward, and optimizer operations
|
82 |
-
- Batch numbers displayed inside each rectangle
|
83 |
-
- Device labels on the y-axis
|
84 |
-
- Clear legend
|
85 |
-
|
86 |
-
Additional features:
|
87 |
-
- Interactive web interface
|
88 |
-
- Hovering over elements to see details
|
89 |
-
- Download button to save the visualization
|
90 |
-
- Progress bars for tracking visualization creation
|
91 |
-
- Responsive layout that works well on different screen sizes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,77 +1,95 @@
|
|
1 |
-
# Pipeline Parallelism
|
2 |
|
3 |
-
This
|
4 |
|
5 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
|
|
9 |
```bash
|
10 |
-
python
|
11 |
```
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
| Option | Short | Description |
|
17 |
-
|--------|-------|-------------|
|
18 |
-
| `--config` | `-c` | Path to config file (JSON or YAML) |
|
19 |
-
| `--num-stages` | `-s` | Number of pipeline stages (devices) |
|
20 |
-
| `--num-batches` | `-b` | Number of micro-batches |
|
21 |
-
| `--forward-times` | `-f` | Time for forward pass at each stage (space-separated list) |
|
22 |
-
| `--backward-times` | `-bw` | Time for backward pass at each stage (space-separated list) |
|
23 |
-
| `--output` | `-o` | Output file path for visualization |
|
24 |
-
| `--no-visualization` | | Skip visualization generation |
|
25 |
-
| `--p2p-time`| | P2P communication time of PP |
|
26 |
-
|
27 |
-
### Using Configuration Files
|
28 |
-
|
29 |
-
You can use either JSON or YAML configuration files:
|
30 |
-
|
31 |
-
Example JSON configuration (sample_config.json):
|
32 |
-
```json
|
33 |
-
{
|
34 |
-
"num_stages": 6,
|
35 |
-
"num_batches": 12,
|
36 |
-
"forward_times": [0.8, 1.0, 1.2, 1.0, 0.9, 1.1],
|
37 |
-
"backward_times": [1.6, 2.0, 2.4, 2.0, 1.8, 2.2],
|
38 |
-
"output_file": "pipeline_1f1b_custom.png"
|
39 |
-
}
|
40 |
```
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
```
|
61 |
|
62 |
-
##
|
63 |
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
1. **Warmup Phase**: Forward passes for the first several micro-batches
|
68 |
-
2. **Steady State**: Each device alternates between forward and backward passes
|
69 |
-
3. **Cooldown Phase**: Backward passes to complete the computation for remaining micro-batches
|
70 |
|
71 |
-
|
72 |
|
73 |
-
##
|
74 |
|
75 |
-
|
76 |
-
- GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (NeurIPS'19)
|
77 |
-
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
|
|
|
1 |
+
# Pipeline Parallelism Emulation
|
2 |
|
3 |
+
This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
|
4 |
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
Pipeline parallelism is a technique used to train large models by partitioning the model across multiple devices and processing data in a pipelined fashion. This project allows you to:
|
8 |
+
|
9 |
+
- Simulate different pipeline parallelism strategies (1F1B, Interleaved)
|
10 |
+
- Visualize the execution schedule on multiple devices
|
11 |
+
- Compare different strategies for efficiency
|
12 |
+
|
13 |
+
## Features
|
14 |
+
- Supported Pipeline Stragegies:
|
15 |
+
- 1F1B
|
16 |
+
- Interleaved 1F1B
|
17 |
+
- Visualization:
|
18 |
+
- Interactive visualization dashboard using Plotly/Dash
|
19 |
+
- Config:
|
20 |
+
- Configurable simulation parameters through Hydra
|
21 |
+
- Each stage
|
22 |
+
|
23 |
+
## Installation
|
24 |
+
|
25 |
+
This project uses [uv](https://github.com/astral-sh/uv) for dependency management.
|
26 |
|
27 |
+
Setup `uv` if not installed in your computer:
|
28 |
+
```
|
29 |
+
# On macOS and Linux.
|
30 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
31 |
+
```
|
32 |
+
|
33 |
+
## Usage
|
34 |
|
35 |
+
Running for 1F1B strategy:
|
36 |
```bash
|
37 |
+
uv run python main.py strategy=1f1b num_devices=4 num_stages=4 num_batches=8
|
38 |
```
|
39 |
+
|
40 |
+
```bash
|
41 |
+
uv run python main.py strategy=interleave num_devices=4 num_stages=8 num_batches=8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
```
|
43 |
|
44 |
+
## Configuration
|
45 |
+
|
46 |
+
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
47 |
+
|
48 |
+
### Using Different Configuration Files
|
49 |
+
|
50 |
+
You can use different configuration files with Hydra in several ways:
|
51 |
+
|
52 |
+
#### Recommended Approach
|
53 |
+
|
54 |
+
1. Create multiple configuration files in the `conf` directory for different use cases:
|
55 |
+
```
|
56 |
+
conf/
|
57 |
+
├── config.yaml # Default configuration
|
58 |
+
└── model_A.yaml # Create your own config with stage-specific latency for performance projection.
|
59 |
+
```
|
60 |
+
|
61 |
+
2. Run with your desired configuration using the `--config-name` flag:
|
62 |
+
```bash
|
63 |
+
uv run python main.py --config-name=model_A
|
64 |
+
```
|
65 |
+
|
66 |
+
#### Override Specific Parameters
|
67 |
+
|
68 |
+
You can also override specific parameters at runtime:
|
69 |
+
```bash
|
70 |
+
uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
|
71 |
```
|
72 |
|
73 |
+
## Project Structure
|
74 |
|
75 |
+
```
|
76 |
+
PP-Emulation/
|
77 |
+
├── conf/ # Hydra configuration files
|
78 |
+
│ └── config.yaml # Default configuration
|
79 |
+
├── src/ # Source code
|
80 |
+
│ ├── __init__.py # Package initialization
|
81 |
+
│ ├── execution_model.py # Schedule execution models
|
82 |
+
│ ├── strategies.py # Pipeline parallelism strategies
|
83 |
+
│ └── visualizer.py # Visualization utilities
|
84 |
+
├── main.py # Main entry point
|
85 |
+
├── pyproject.toml # Project metadata and dependencies
|
86 |
+
└── README.md # This file
|
87 |
+
```
|
88 |
|
89 |
+
## License
|
|
|
|
|
|
|
90 |
|
91 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
92 |
|
93 |
+
## Contributing
|
94 |
|
95 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
|
|
|
conf/config.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default configuration for Pipeline Parallelism Emulation
|
2 |
+
num_devices: 4
|
3 |
+
num_stages: 4
|
4 |
+
num_batches: 12
|
5 |
+
visualization_port: 8050
|
6 |
+
strategy: "1f1b" # Options: "1f1b", "interleave"
|
7 |
+
p2p_latency: 0.0
|
8 |
+
|
9 |
+
# Operation time configurations
|
10 |
+
op_times:
|
11 |
+
# Option 1: Simple configuration (same time for all stages)
|
12 |
+
forward: 1.0
|
13 |
+
backward: 2.0
|
14 |
+
|
15 |
+
# Option 2: Commented example of stage-specific configuration
|
16 |
+
# forward:
|
17 |
+
# 0: 0.8 # Stage 0 forward time
|
18 |
+
# 1: 1.2 # Stage 1 forward time
|
19 |
+
# 2: 1.5 # Stage 2 forward time
|
20 |
+
# 3: 0.9 # Stage 3 forward time
|
21 |
+
# backward:
|
22 |
+
# 0: 1.0 # Stage 0 backward time
|
configs/standard.json
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"num_stages": 4,
|
3 |
-
"num_batches": 8,
|
4 |
-
"forward_times": [1.0, 1.0, 1.0, 1.0],
|
5 |
-
"backward_times": [2.0, 2.0, 2.0, 2.0],
|
6 |
-
"output_file": "pipeline_1f1b.png",
|
7 |
-
"p2p_time": 0.0
|
8 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.execution_model import ScheduleConfig, ScheduleExecutor
|
2 |
+
from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
|
3 |
+
from src.visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
|
4 |
+
import hydra
|
5 |
+
from omegaconf import DictConfig, OmegaConf
|
6 |
+
|
7 |
+
|
8 |
+
@hydra.main(config_path="conf", config_name="config", version_base=None)
|
9 |
+
def main(cfg: DictConfig) -> None:
|
10 |
+
"""Run pipeline parallelism simulation with the specified configuration."""
|
11 |
+
print(f"Running with configuration: {cfg}")
|
12 |
+
|
13 |
+
if cfg.strategy == "1f1b":
|
14 |
+
run_1f1b(cfg)
|
15 |
+
elif cfg.strategy == "interleave":
|
16 |
+
run_interleave(cfg)
|
17 |
+
else:
|
18 |
+
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
19 |
+
|
20 |
+
|
21 |
+
def run_1f1b(cfg: DictConfig) -> None:
|
22 |
+
"""Run 1F1B pipeline parallelism simulation."""
|
23 |
+
# Convert OmegaConf to dict for op_times if it exists
|
24 |
+
op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
|
25 |
+
|
26 |
+
schedule_config = ScheduleConfig(
|
27 |
+
num_devices=cfg.num_devices,
|
28 |
+
num_stages=cfg.num_stages,
|
29 |
+
num_batches=cfg.num_batches,
|
30 |
+
p2p_latency=cfg.p2p_latency,
|
31 |
+
op_times=op_times,
|
32 |
+
placement_strategy="1f1b"
|
33 |
+
)
|
34 |
+
schedule = generate_1f1b_schedule(schedule_config)
|
35 |
+
executor = ScheduleExecutor(schedule)
|
36 |
+
executor.execute()
|
37 |
+
|
38 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
39 |
+
|
40 |
+
|
41 |
+
def run_interleave(cfg: DictConfig) -> None:
|
42 |
+
"""Run interleaved pipeline parallelism simulation."""
|
43 |
+
# Convert OmegaConf to dict for op_times if it exists
|
44 |
+
op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
|
45 |
+
|
46 |
+
schedule_config = ScheduleConfig(
|
47 |
+
num_devices=cfg.num_devices,
|
48 |
+
num_stages=cfg.num_stages,
|
49 |
+
num_batches=cfg.num_batches,
|
50 |
+
p2p_latency=cfg.p2p_latency,
|
51 |
+
placement_strategy="interleave",
|
52 |
+
op_times=op_times
|
53 |
+
)
|
54 |
+
schedule = generate_1f1b_interleave_schedule(schedule_config)
|
55 |
+
executor = ScheduleExecutor(schedule)
|
56 |
+
executor.execute()
|
57 |
+
|
58 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == "__main__":
|
62 |
+
main()
|
pipeline.py
DELETED
@@ -1,491 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import json
|
3 |
-
import yaml
|
4 |
-
import os
|
5 |
-
from typing import List, Dict
|
6 |
-
|
7 |
-
# Import visualization function from the new module
|
8 |
-
from visualizer import visualize_pipeline_parallelism
|
9 |
-
try:
|
10 |
-
from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
|
11 |
-
DASH_AVAILABLE = True
|
12 |
-
except ImportError:
|
13 |
-
DASH_AVAILABLE = False
|
14 |
-
|
15 |
-
|
16 |
-
def create_1f1b_schedule(
|
17 |
-
num_stages: int,
|
18 |
-
num_batches: int,
|
19 |
-
forward_times: List[float],
|
20 |
-
backward_times: List[float],
|
21 |
-
p2p_time: float = 0.0,
|
22 |
-
) -> Dict[int, List[Dict]]:
|
23 |
-
"""
|
24 |
-
Create a 1F1B (One-Forward-One-Backward) schedule for pipeline parallelism.
|
25 |
-
|
26 |
-
This implementation takes a data-centric approach:
|
27 |
-
1. First determine the operation sequence for each pipeline stage (which microbatch to process when)
|
28 |
-
2. Then calculate timing based on dependencies between operations
|
29 |
-
|
30 |
-
The 1F1B pattern has three phases:
|
31 |
-
- Warmup: Forward passes for first num_stages microbatches
|
32 |
-
- Steady state: Alternating between forward and backward passes
|
33 |
-
- Cooldown: Backward passes for remaining microbatches
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
A dictionary mapping device IDs to lists of tasks.
|
37 |
-
Each task is a dictionary with keys:
|
38 |
-
- 'type': 'forward' or 'backward'
|
39 |
-
- 'batch': batch number
|
40 |
-
- 'start_time': start time of the task
|
41 |
-
- 'duration': duration of the task
|
42 |
-
"""
|
43 |
-
# Initialize empty schedule
|
44 |
-
schedule = {stage: [] for stage in range(num_stages)}
|
45 |
-
|
46 |
-
# Step 1: Determine operation sequence for each stage
|
47 |
-
# This will generate the sequence of operations (forward/backward on which microbatch)
|
48 |
-
# that each stage should perform, without timing information yet
|
49 |
-
operation_sequence = determine_1f1b_operation_sequence(num_stages, num_batches)
|
50 |
-
|
51 |
-
# Step 2: Convert operation sequence to schedule with timing
|
52 |
-
# Taking into account dependencies between operations
|
53 |
-
schedule = calculate_operation_timing(
|
54 |
-
operation_sequence, num_stages, forward_times, backward_times, p2p_time
|
55 |
-
)
|
56 |
-
|
57 |
-
return schedule
|
58 |
-
|
59 |
-
|
60 |
-
def determine_1f1b_operation_sequence(
|
61 |
-
num_stages: int, num_batches: int
|
62 |
-
) -> Dict[int, List[Dict]]:
|
63 |
-
"""
|
64 |
-
Determine the sequence of operations (forward/backward) for each stage in 1F1B scheduling.
|
65 |
-
|
66 |
-
Args:
|
67 |
-
num_stages: Number of pipeline stages
|
68 |
-
num_batches: Number of micro-batches
|
69 |
-
|
70 |
-
Returns:
|
71 |
-
Dictionary mapping stage ID to a list of operations in sequence.
|
72 |
-
Each operation is a dict with keys 'type' ('forward' or 'backward') and 'batch'.
|
73 |
-
"""
|
74 |
-
operation_sequence = {i: [] for i in range(num_stages)}
|
75 |
-
for current_stage in range(num_stages):
|
76 |
-
warmup_batches = num_stages - current_stage
|
77 |
-
for j in range(1, warmup_batches + 1):
|
78 |
-
operation_sequence[current_stage].append({"type": "forward", "batch": j})
|
79 |
-
steady_batches = num_batches - warmup_batches
|
80 |
-
for j in range(warmup_batches + 1, warmup_batches + steady_batches + 1):
|
81 |
-
operation_sequence[current_stage].append(
|
82 |
-
{"type": "backward", "batch": j - warmup_batches}
|
83 |
-
)
|
84 |
-
operation_sequence[current_stage].append({"type": "forward", "batch": j})
|
85 |
-
for j in range(warmup_batches):
|
86 |
-
operation_sequence[current_stage].append(
|
87 |
-
{"type": "backward", "batch": j + steady_batches + 1}
|
88 |
-
)
|
89 |
-
|
90 |
-
return operation_sequence
|
91 |
-
|
92 |
-
|
93 |
-
def calculate_operation_timing(
|
94 |
-
operation_sequence: Dict[int, List[Dict]],
|
95 |
-
num_stages: int,
|
96 |
-
forward_times: List[float],
|
97 |
-
backward_times: List[float],
|
98 |
-
p2p_time: float = 0.0,
|
99 |
-
) -> Dict[int, List[Dict]]:
|
100 |
-
"""
|
101 |
-
Recursively calculate the specific timing of each operation in a 1F1B schedule.
|
102 |
-
|
103 |
-
When encountering an operation that depends on a previous operation that hasn't been calculated yet,
|
104 |
-
it will recursively calculate the timing of those operations.
|
105 |
-
|
106 |
-
Args:
|
107 |
-
operation_sequence: Operation sequence for each stage
|
108 |
-
num_stages: Number of pipeline stages
|
109 |
-
forward_times: Forward propagation time for each stage
|
110 |
-
backward_times: Backward propagation time for each stage
|
111 |
-
p2p_time: Point-to-point communication time between stages
|
112 |
-
|
113 |
-
Returns:
|
114 |
-
Complete schedule with timing information, each operation includes start_time and duration
|
115 |
-
"""
|
116 |
-
# Initialize schedule with timing information
|
117 |
-
schedule = {i: [] for i in range(num_stages)}
|
118 |
-
|
119 |
-
# For recording already computed operation end times
|
120 |
-
# Format: {(stage, batch, op_type): (start_time, end_time)}
|
121 |
-
computed_ops = {}
|
122 |
-
|
123 |
-
# For recording the end time of the last operation for each stage
|
124 |
-
stage_last_end_time = [0.0] * num_stages
|
125 |
-
|
126 |
-
# Helper function: recursively calculate the time for an operation
|
127 |
-
def compute_op_time(stage, batch, op_type):
|
128 |
-
# Check if this operation has already been calculated
|
129 |
-
key = (stage, batch, op_type)
|
130 |
-
if key in computed_ops:
|
131 |
-
return computed_ops[key]
|
132 |
-
|
133 |
-
# Get operation duration
|
134 |
-
duration = (
|
135 |
-
forward_times[stage] if op_type == "forward" else backward_times[stage]
|
136 |
-
)
|
137 |
-
|
138 |
-
# Determine start time (dependent on other operations)
|
139 |
-
# 1. Consider sequential dependencies on the stage (must wait for previous operation to complete)
|
140 |
-
start_time = stage_last_end_time[stage]
|
141 |
-
|
142 |
-
# 2. Forward pass also depends on forward pass of previous stage (if not the first stage)
|
143 |
-
if op_type == "forward" and stage > 0:
|
144 |
-
# Recursively calculate the time for the forward pass of the previous stage (if not calculated yet)
|
145 |
-
prev_stage_key = (stage - 1, batch, "forward")
|
146 |
-
if prev_stage_key not in computed_ops:
|
147 |
-
prev_start, prev_end = compute_op_time(stage - 1, batch, "forward")
|
148 |
-
else:
|
149 |
-
_, prev_end = computed_ops[prev_stage_key]
|
150 |
-
# Update start time
|
151 |
-
start_time = max(start_time, prev_end + p2p_time)
|
152 |
-
|
153 |
-
# 3. Backward pass depends on:
|
154 |
-
elif op_type == "backward":
|
155 |
-
# a. Forward pass of the same stage
|
156 |
-
same_stage_forward_key = (stage, batch, "forward")
|
157 |
-
if same_stage_forward_key not in computed_ops:
|
158 |
-
_, forward_end = compute_op_time(stage, batch, "forward")
|
159 |
-
else:
|
160 |
-
_, forward_end = computed_ops[same_stage_forward_key]
|
161 |
-
|
162 |
-
start_time = max(start_time, forward_end)
|
163 |
-
|
164 |
-
# b. Backward pass of the next stage (if not the last stage)
|
165 |
-
if stage < num_stages - 1:
|
166 |
-
next_stage_backward_key = (stage + 1, batch, "backward")
|
167 |
-
if next_stage_backward_key not in computed_ops:
|
168 |
-
_, next_backward_end = compute_op_time(stage + 1, batch, "backward")
|
169 |
-
else:
|
170 |
-
_, next_backward_end = computed_ops[next_stage_backward_key]
|
171 |
-
|
172 |
-
start_time = max(start_time, next_backward_end + p2p_time)
|
173 |
-
|
174 |
-
# Calculate end time
|
175 |
-
end_time = start_time + duration
|
176 |
-
|
177 |
-
# Store calculation results
|
178 |
-
computed_ops[key] = (start_time, end_time)
|
179 |
-
|
180 |
-
# Update the end time of the last operation for this stage
|
181 |
-
stage_last_end_time[stage] = end_time
|
182 |
-
|
183 |
-
return start_time, end_time
|
184 |
-
|
185 |
-
# Calculate time for each operation in the operation_sequence
|
186 |
-
for i in range(len(operation_sequence[0])):
|
187 |
-
for stage in range(num_stages):
|
188 |
-
batch = operation_sequence[stage][i]["batch"]
|
189 |
-
op_type = operation_sequence[stage][i]["type"]
|
190 |
-
|
191 |
-
# Recursively calculate the time for this operation
|
192 |
-
start_time, _ = compute_op_time(stage, batch, op_type)
|
193 |
-
|
194 |
-
# Fill in scheduling information
|
195 |
-
op_with_timing = operation_sequence[stage][i].copy()
|
196 |
-
op_with_timing["start_time"] = start_time
|
197 |
-
op_with_timing["duration"] = (
|
198 |
-
forward_times[stage] if op_type == "forward" else backward_times[stage]
|
199 |
-
)
|
200 |
-
schedule[stage].append(op_with_timing)
|
201 |
-
|
202 |
-
return schedule
|
203 |
-
|
204 |
-
|
205 |
-
def get_schedule_info(schedule: Dict[int, List[Dict]]):
|
206 |
-
num_stages = len(schedule)
|
207 |
-
|
208 |
-
max_time = 0
|
209 |
-
for device in schedule:
|
210 |
-
for task in schedule[device]:
|
211 |
-
end_time = task["start_time"] + task["duration"]
|
212 |
-
if end_time > max_time:
|
213 |
-
max_time = end_time
|
214 |
-
|
215 |
-
total_execution_time = max_time * num_stages
|
216 |
-
|
217 |
-
total_computation_time = 0
|
218 |
-
device_computation_times = {}
|
219 |
-
|
220 |
-
for device in schedule:
|
221 |
-
device_computation_time = 0
|
222 |
-
for task in schedule[device]:
|
223 |
-
device_computation_time += task["duration"]
|
224 |
-
device_computation_times[device] = device_computation_time
|
225 |
-
total_computation_time += device_computation_time
|
226 |
-
|
227 |
-
bubble_rate = (
|
228 |
-
total_execution_time - total_computation_time
|
229 |
-
) / total_computation_time
|
230 |
-
|
231 |
-
return {
|
232 |
-
"bubble_rate": f"{bubble_rate*100:.2f}%",
|
233 |
-
"execution_time": f"{max_time / 1000:.2f} s",
|
234 |
-
}
|
235 |
-
|
236 |
-
|
237 |
-
def read_config_file(config_path):
|
238 |
-
"""
|
239 |
-
Read configuration from a JSON or YAML file.
|
240 |
-
|
241 |
-
Args:
|
242 |
-
config_path: Path to the config file (JSON or YAML)
|
243 |
-
|
244 |
-
Returns:
|
245 |
-
Dictionary containing configuration parameters
|
246 |
-
"""
|
247 |
-
if not os.path.exists(config_path):
|
248 |
-
raise FileNotFoundError(f"Config file not found: {config_path}")
|
249 |
-
|
250 |
-
file_ext = os.path.splitext(config_path)[1].lower()
|
251 |
-
|
252 |
-
try:
|
253 |
-
with open(config_path, "r") as f:
|
254 |
-
if file_ext == ".json":
|
255 |
-
config = json.load(f)
|
256 |
-
elif file_ext in (".yaml", ".yml"):
|
257 |
-
config = yaml.safe_load(f)
|
258 |
-
else:
|
259 |
-
raise ValueError(
|
260 |
-
f"Unsupported config file format: {file_ext}. Use .json, .yaml, or .yml"
|
261 |
-
)
|
262 |
-
return config
|
263 |
-
except Exception as e:
|
264 |
-
raise ValueError(f"Error reading config file: {str(e)}")
|
265 |
-
|
266 |
-
|
267 |
-
def parse_args():
|
268 |
-
"""
|
269 |
-
Parse command-line arguments for the pipeline parallelism tool.
|
270 |
-
|
271 |
-
Returns:
|
272 |
-
Parsed arguments namespace
|
273 |
-
"""
|
274 |
-
parser = argparse.ArgumentParser(
|
275 |
-
description="Pipeline Parallelism Scheduler and Visualizer",
|
276 |
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
277 |
-
)
|
278 |
-
|
279 |
-
# Config file option
|
280 |
-
parser.add_argument(
|
281 |
-
"--config", "-c", type=str, help="Path to config file (JSON or YAML)"
|
282 |
-
)
|
283 |
-
|
284 |
-
# Main parameters
|
285 |
-
parser.add_argument(
|
286 |
-
"--num-stages",
|
287 |
-
"-s",
|
288 |
-
type=int,
|
289 |
-
default=0,
|
290 |
-
help="Number of pipeline stages (devices)",
|
291 |
-
)
|
292 |
-
|
293 |
-
parser.add_argument(
|
294 |
-
"--num-batches", "-b", type=int, default=0, help="Number of micro-batches"
|
295 |
-
)
|
296 |
-
|
297 |
-
# Forward and backward times
|
298 |
-
parser.add_argument(
|
299 |
-
"--forward-times",
|
300 |
-
"-f",
|
301 |
-
type=float,
|
302 |
-
nargs="+",
|
303 |
-
help="Time for forward pass at each stage (space-separated list)",
|
304 |
-
)
|
305 |
-
|
306 |
-
parser.add_argument(
|
307 |
-
"--backward-times",
|
308 |
-
"-bw",
|
309 |
-
type=float,
|
310 |
-
nargs="+",
|
311 |
-
help="Time for backward pass at each stage (space-separated list)",
|
312 |
-
)
|
313 |
-
|
314 |
-
# Output options
|
315 |
-
parser.add_argument(
|
316 |
-
"--output",
|
317 |
-
"-o",
|
318 |
-
type=str,
|
319 |
-
default="pipeline_1f1b.png",
|
320 |
-
help="Output file path for visualization",
|
321 |
-
)
|
322 |
-
|
323 |
-
parser.add_argument(
|
324 |
-
"--no-visualization", action="store_true", help="Skip visualization generation"
|
325 |
-
)
|
326 |
-
|
327 |
-
parser.add_argument(
|
328 |
-
"--p2p-time",
|
329 |
-
type=float,
|
330 |
-
default=0.0,
|
331 |
-
help="Time for point-to-point communication between stages",
|
332 |
-
)
|
333 |
-
|
334 |
-
parser.add_argument("--visualizer", choices=["matplotlib", "dash", "dash-interactive"],
|
335 |
-
default="matplotlib", help="Visualization library to use")
|
336 |
-
|
337 |
-
return parser.parse_args()
|
338 |
-
|
339 |
-
|
340 |
-
def example_usage():
|
341 |
-
"""Example usage of the visualization function and testing the scheduling algorithms."""
|
342 |
-
# Example parameters
|
343 |
-
num_stages = 4 # Number of pipeline stages (devices)
|
344 |
-
num_batches = 10 # Number of micro-batches
|
345 |
-
|
346 |
-
# Example times for forward and backward passes for each stage
|
347 |
-
forward_times = [1.0, 1.0, 1.0, 1.0] # Time for forward pass at each stage
|
348 |
-
backward_times = [2.0, 2.0, 2.0, 2.0] # Time for backward pass at each stage
|
349 |
-
|
350 |
-
# Create 1F1B schedule
|
351 |
-
schedule = create_1f1b_schedule(
|
352 |
-
num_stages=num_stages,
|
353 |
-
num_batches=num_batches,
|
354 |
-
forward_times=forward_times,
|
355 |
-
backward_times=backward_times,
|
356 |
-
)
|
357 |
-
|
358 |
-
# Create visualization with the schedule
|
359 |
-
visualize_pipeline_parallelism(
|
360 |
-
schedule=schedule, schedule_type="1f1b", output_file="pipeline_1f1b.png"
|
361 |
-
)
|
362 |
-
|
363 |
-
# Analyze the schedule
|
364 |
-
schedule_info = get_schedule_info(schedule)
|
365 |
-
print(schedule_info)
|
366 |
-
|
367 |
-
|
368 |
-
def main():
|
369 |
-
"""
|
370 |
-
Main function that parses arguments and runs the pipeline parallelism analysis.
|
371 |
-
"""
|
372 |
-
args = parse_args()
|
373 |
-
|
374 |
-
# Initialize with default values
|
375 |
-
num_stages = 4
|
376 |
-
num_batches = 10
|
377 |
-
forward_times = None
|
378 |
-
backward_times = None
|
379 |
-
output_file = "pipeline_1f1b.png"
|
380 |
-
p2p_time = 0.0
|
381 |
-
|
382 |
-
# Command line arguments override config file
|
383 |
-
num_stages = args.num_stages
|
384 |
-
num_batches = args.num_batches
|
385 |
-
forward_times = args.forward_times
|
386 |
-
backward_times = args.backward_times
|
387 |
-
output_file = args.output
|
388 |
-
p2p_time = args.p2p_time
|
389 |
-
|
390 |
-
# Read from config file if provided
|
391 |
-
if args.config:
|
392 |
-
try:
|
393 |
-
print(f"Reading configuration from {args.config}")
|
394 |
-
config = read_config_file(args.config)
|
395 |
-
|
396 |
-
# Update parameters from config
|
397 |
-
num_stages = config.get("num_stages", num_stages)
|
398 |
-
num_batches = config.get("num_batches", num_batches)
|
399 |
-
forward_times = config.get("forward_times")
|
400 |
-
backward_times = config.get("backward_times")
|
401 |
-
output_file = config.get("output_file", output_file)
|
402 |
-
p2p_time = config.get("p2p_time", 0.0)
|
403 |
-
|
404 |
-
except Exception as e:
|
405 |
-
print(f"Error reading config file: {str(e)}")
|
406 |
-
print("Falling back to command line arguments or defaults")
|
407 |
-
|
408 |
-
# Validate inputs
|
409 |
-
if forward_times is None:
|
410 |
-
forward_times = [1.0] * num_stages
|
411 |
-
elif len(forward_times) != num_stages:
|
412 |
-
print(
|
413 |
-
f"Warning: forward_times length ({len(forward_times)}) doesn't match num_stages ({num_stages})"
|
414 |
-
)
|
415 |
-
if len(forward_times) < num_stages:
|
416 |
-
# Extend with repeats of the last value
|
417 |
-
forward_times = list(forward_times) + [forward_times[-1]] * (
|
418 |
-
num_stages - len(forward_times)
|
419 |
-
)
|
420 |
-
else:
|
421 |
-
# Truncate
|
422 |
-
forward_times = forward_times[:num_stages]
|
423 |
-
print(f"Adjusted forward_times: {forward_times}")
|
424 |
-
|
425 |
-
if backward_times is None:
|
426 |
-
backward_times = [2.0] * num_stages
|
427 |
-
elif len(backward_times) != num_stages:
|
428 |
-
print(
|
429 |
-
f"Warning: backward_times length ({len(backward_times)}) doesn't match num_stages ({num_stages})"
|
430 |
-
)
|
431 |
-
if len(backward_times) < num_stages:
|
432 |
-
# Extend with repeats of the last value
|
433 |
-
backward_times = list(backward_times) + [backward_times[-1]] * (
|
434 |
-
num_stages - len(backward_times)
|
435 |
-
)
|
436 |
-
else:
|
437 |
-
# Truncate
|
438 |
-
backward_times = backward_times[:num_stages]
|
439 |
-
print(f"Adjusted backward_times: {backward_times}")
|
440 |
-
|
441 |
-
print(f"Running with parameters:")
|
442 |
-
print(f" num_stages: {num_stages}")
|
443 |
-
print(f" num_batches: {num_batches}")
|
444 |
-
print(f" forward_times: {forward_times}")
|
445 |
-
print(f" backward_times: {backward_times}")
|
446 |
-
print(f" output_file: {output_file}")
|
447 |
-
|
448 |
-
# Create 1F1B schedule
|
449 |
-
schedule = create_1f1b_schedule(
|
450 |
-
num_stages=num_stages,
|
451 |
-
num_batches=num_batches,
|
452 |
-
forward_times=forward_times,
|
453 |
-
backward_times=backward_times,
|
454 |
-
p2p_time=p2p_time,
|
455 |
-
)
|
456 |
-
|
457 |
-
# Create visualization unless --no-visualization is specified
|
458 |
-
if not args.no_visualization:
|
459 |
-
if args.visualizer == "matplotlib" or not DASH_AVAILABLE:
|
460 |
-
if not DASH_AVAILABLE and args.visualizer in ["dash", "dash-interactive"]:
|
461 |
-
print("Warning: Dash not available. Falling back to matplotlib.")
|
462 |
-
visualize_pipeline_parallelism(
|
463 |
-
schedule=schedule, schedule_type="1f1b", output_file=output_file
|
464 |
-
)
|
465 |
-
elif args.visualizer == "dash":
|
466 |
-
# Get output file name without extension to use the appropriate extension
|
467 |
-
output_base = os.path.splitext(output_file)[0]
|
468 |
-
output_dash = f"{output_base}_plotly.png"
|
469 |
-
save_pipeline_visualization_plotly(
|
470 |
-
schedule=schedule, schedule_type="1f1b", output_file=output_dash
|
471 |
-
)
|
472 |
-
elif args.visualizer == "dash-interactive":
|
473 |
-
print("Using Dash interactive visualization")
|
474 |
-
visualize_pipeline_parallelism_dash(
|
475 |
-
schedule=schedule, schedule_type="1f1b", port=8050, debug=False
|
476 |
-
)
|
477 |
-
|
478 |
-
# Analyze the schedule
|
479 |
-
schedule_info = get_schedule_info(schedule)
|
480 |
-
print(schedule_info)
|
481 |
-
|
482 |
-
return {
|
483 |
-
"schedule": schedule,
|
484 |
-
"schedule_info": schedule_info,
|
485 |
-
"num_stages": num_stages,
|
486 |
-
"num_batches": num_batches,
|
487 |
-
}
|
488 |
-
|
489 |
-
|
490 |
-
if __name__ == "__main__":
|
491 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline_1f1b.png
DELETED
Git LFS Details
|
pyproject.toml
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["hatchling"]
|
3 |
+
build-backend = "hatchling.build"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "pp-emulation"
|
7 |
+
version = "0.1.0"
|
8 |
+
description = "Pipeline Parallelism Emulation and Visualization"
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">=3.10"
|
11 |
+
authors = [
|
12 |
+
{name = "Project Author"}
|
13 |
+
]
|
14 |
+
classifiers = [
|
15 |
+
"Programming Language :: Python :: 3",
|
16 |
+
"License :: OSI Approved :: MIT License",
|
17 |
+
"Operating System :: OS Independent",
|
18 |
+
]
|
19 |
+
dependencies = [
|
20 |
+
"dash>=2.14.0",
|
21 |
+
"hydra-core>=1.3.2",
|
22 |
+
"omegaconf>=2.3.0",
|
23 |
+
"plotly>=5.18.0",
|
24 |
+
"pandas>=2.1.0",
|
25 |
+
"numpy>=1.26.0",
|
26 |
+
"tqdm>=4.67.0",
|
27 |
+
]
|
28 |
+
|
29 |
+
[project.optional-dependencies]
|
30 |
+
dev = [
|
31 |
+
"pytest>=7.4.0",
|
32 |
+
"black>=23.7.0",
|
33 |
+
"isort>=5.12.0",
|
34 |
+
"mypy>=1.5.1",
|
35 |
+
]
|
36 |
+
|
37 |
+
# Add Hatch configuration to explicitly define where source code is located
|
38 |
+
[tool.hatch.build.targets.wheel]
|
39 |
+
packages = ["src"]
|
40 |
+
|
41 |
+
[tool.hatch.build.targets.sdist]
|
42 |
+
include = [
|
43 |
+
"src",
|
44 |
+
"main.py",
|
45 |
+
"conf",
|
46 |
+
"LICENSE",
|
47 |
+
"README.md",
|
48 |
+
]
|
49 |
+
|
50 |
+
[tool.black]
|
51 |
+
line-length = 88
|
52 |
+
target-version = ["py310"]
|
53 |
+
|
54 |
+
[tool.isort]
|
55 |
+
profile = "black"
|
56 |
+
line_length = 88
|
57 |
+
|
58 |
+
[tool.mypy]
|
59 |
+
python_version = "3.10"
|
60 |
+
warn_return_any = true
|
61 |
+
warn_unused_configs = true
|
62 |
+
disallow_untyped_defs = true
|
63 |
+
disallow_incomplete_defs = true
|
64 |
+
|
65 |
+
[tool.pytest]
|
66 |
+
testpaths = ["tests"]
|
67 |
+
pythonpath = ["."]
|
requirements-dash.txt
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
dash==2.13.0
|
2 |
-
plotly==5.18.0
|
3 |
-
numpy
|
4 |
-
kaleido # For static image export
|
5 |
-
tqdm # For progress bars
|
|
|
|
|
|
|
|
|
|
|
|
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""Pipeline Parallelism Emulation and Visualization package."""
|
2 |
+
|
3 |
+
__version__ = "0.1.0"
|
src/execution_model.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from typing import Dict, List, Optional, Union
|
3 |
+
|
4 |
+
|
5 |
+
class Operation:
    """A single forward or backward pass of one micro-batch at one stage.

    Placement (``device_id``) and timing (``start_time``/``end_time``) are
    not known at construction; they start as ``None`` and are filled in
    later by ``DeviceQueue.add_operation`` and the schedule executor.
    """

    def __init__(self, batch_id: int, stage_id: int, op_type: str):
        self.batch_id = batch_id  # micro-batch index
        self.stage_id = stage_id  # pipeline stage index
        self.op_type = op_type    # "forward" or "backward"
        self.device_id = None     # set when the op is placed on a device

        # Set by the executor once the schedule is simulated.
        self.start_time = None
        self.end_time = None

    def __repr__(self) -> str:
        # Added for debuggability; does not affect any caller-visible behavior.
        return (
            f"{type(self).__name__}(batch_id={self.batch_id}, "
            f"stage_id={self.stage_id}, op_type={self.op_type!r})"
        )
17 |
+
|
18 |
+
class DeviceQueue:
    """Execution queue of operations assigned to a single device.

    ``stages`` is the list of pipeline stage ids hosted on this device;
    only operations belonging to one of those stages may be enqueued.
    """

    def __init__(self, stages: List[int], device_id: int):
        self.stages = stages
        self.device_id = device_id
        self.ops = []  # List of operations

    def add_operation(self, op: Operation):
        # Takes ownership of `op`: its stage must live on this device and it
        # must not already have been placed on another device.
        assert op.stage_id in self.stages
        self.ops.append(op)
        assert op.device_id is None
        op.device_id = self.device_id
30 |
+
|
31 |
+
class ScheduleConfig:
    """Static description of a pipeline-parallel run.

    Args:
        num_devices: Number of physical devices.
        num_stages: Total number of pipeline stages; must be divisible by
            ``num_devices``.
        num_batches: Number of micro-batches.
        p2p_latency: Extra latency added for cross-stage communication.
        placement_strategy: ``"normal"`` (contiguous stage blocks per
            device) or ``"interleave"`` (round-robin stages across devices).
        op_times: Optional per-operation costs. Maps op type (``"forward"``/
            ``"backward"``) to either a single float (same cost at every
            stage) or a dict ``stage_id -> cost`` that overrides the default
            only for the listed stages.
    """

    def __init__(
        self,
        num_devices: int,
        num_stages: int,
        num_batches: int,
        p2p_latency: float = 0.0,
        placement_strategy: str = "normal",
        op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
    ):
        self.num_devices = num_devices
        self.num_stages = num_stages
        self.num_batches = num_batches
        self.p2p_latency = p2p_latency
        self.placement_strategy = placement_strategy

        # Default per-operation costs; overridden below by `op_times`.
        self.op_times = {
            "forward": 1.0,
            "backward": 2.0,
        }

        if op_times:
            for op_type, times in op_times.items():
                if isinstance(times, dict):
                    # Stage-specific overrides: expand any existing scalar
                    # default into a per-stage dict, then merge the provided
                    # entries on top. (The original code repeated the
                    # scalar->dict conversion inside the per-stage loop; that
                    # branch was dead after the first conversion.)
                    current = self.op_times.get(op_type)
                    if not isinstance(current, dict):
                        current = (
                            {}
                            if current is None
                            else {i: current for i in range(num_stages)}
                        )
                    current.update(times)
                    self.op_times[op_type] = current
                else:
                    # Scalar: same cost for every stage.
                    self.op_times[op_type] = times

        assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
        self.num_stages_per_device = num_stages // num_devices

        self.init_device_to_stages()
        # Every stage must be placed on exactly one device.
        assert (
            sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
        )

    def init_device_to_stages(self):
        """Compute ``self.device_to_stages``: device_id -> list of stage ids."""
        if self.placement_strategy == "normal":
            # Contiguous blocks of stages per device, e.g. dev0 -> [0, 1].
            stages_per_device = self.num_stages // self.num_devices
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i // stages_per_device].append(i)
        elif self.placement_strategy == "interleave":
            # Round-robin stages across devices, e.g. dev0 -> [0, 2].
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i % self.num_devices].append(i)
        else:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")

    def get_op_time(self, op_type: str, stage_id: int):
        """Return the cost of `op_type` at `stage_id`.

        Raises:
            ValueError: If the op type is unknown, or a stage-specific table
                exists but has no entry for ``stage_id``.
        """
        if op_type not in self.op_times:
            raise ValueError(f"Invalid operation type: {op_type}")

        times = self.op_times[op_type]
        if isinstance(times, dict):
            # Stage-specific table: the stage must be listed explicitly.
            if stage_id not in times:
                raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
            return times[stage_id]
        # Single scalar: same cost for all stages.
        return times
110 |
+
|
111 |
+
|
112 |
+
class Schedule:
    """Ties a ScheduleConfig to its operations and per-device queues."""

    def __init__(self, config: ScheduleConfig):
        self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
        self.dev_queues: List[DeviceQueue] = [
            DeviceQueue(config.device_to_stages[dev_id], dev_id)
            for dev_id in range(config.num_devices)
        ]
        self.config = config

        self.init_operations()

    def init_operations(self, op_types: Optional[List[str]] = None):
        """Register one Operation per (batch, stage, op_type) combination."""
        if op_types is None:
            op_types = ["forward", "backward"]
        self.ops.update(
            {
                (batch, stage, kind): Operation(batch, stage, kind)
                for batch in range(self.config.num_batches)
                for stage in range(self.config.num_stages)
                for kind in op_types
            }
        )

    def get_op(self, batch_id: int, stage_id: int, op_type: str):
        """Look up the Operation for the given (batch, stage, type) key."""
        return self.ops[(batch_id, stage_id, op_type)]

    def get_dependencies(self, op: Operation):
        """Return [(dep_op, extra_latency), ...] that must finish before *op*.

        Forward ops wait on the previous stage's forward pass; backward ops
        wait on the next stage's backward pass; every op also waits on its
        predecessor in the same device queue.
        """
        deps = []
        if op.op_type == "forward" and op.stage_id > 0:
            upstream = self.get_op(op.batch_id, op.stage_id - 1, "forward")
            deps.append((upstream, self.config.p2p_latency))
        elif op.op_type == "backward" and op.stage_id < self.config.num_stages - 1:
            downstream = self.get_op(op.batch_id, op.stage_id + 1, "backward")
            deps.append((downstream, self.config.p2p_latency))

        queue_ops = self.dev_queues[op.device_id].ops
        position = queue_ops.index(op)
        if position > 0:
            # Same-device predecessor has no communication gap.
            deps.append((queue_ops[position - 1], 0.0))
        return deps

    def show(self):
        """Display detailed information about the schedule for debugging purposes."""
        cfg = self.config
        print("\n=== SCHEDULE DETAILS ===")
        print(f"Devices: {cfg.num_devices}, Stages: {cfg.num_stages}, Batches: {cfg.num_batches}")
        print(f"Placement Strategy: {cfg.placement_strategy}")
        print("\n=== DEVICE QUEUES ===")

        for queue in self.dev_queues:
            print(f"\nDEVICE {queue.device_id} (Stages: {queue.stages}):")
            print("-" * 80)
            print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
            print("-" * 80)

            for op in queue.ops:
                label = "Forward" if op.op_type == "forward" else "Backward"
                start = "N/A" if op.start_time is None else f"{op.start_time:.2f}"
                end = "N/A" if op.end_time is None else f"{op.end_time:.2f}"
                if op.start_time is None or op.end_time is None:
                    duration = "N/A"
                else:
                    duration = f"{op.end_time - op.start_time:.2f}"
                print(f"{op.batch_id:^6} | {op.stage_id:^6} | {label:^10} | {start:^10} | {end:^10} | {duration:^10}")

        # Total time is only meaningful once every op has been timed.
        if all(op.end_time is not None for op in self.ops.values()):
            total_time = max(op.end_time for op in self.ops.values())
            print(f"\nTotal execution time: {total_time:.2f}")
187 |
+
|
188 |
+
|
189 |
+
class ScheduleExecutor:
    """Simulates a Schedule by assigning start/end times to every operation."""

    def __init__(self, schedule: Schedule):
        self.schedule = schedule

    def execute(self):
        """Resolve timing for every operation in the schedule.

        Walks queue positions in lockstep across devices and recursively
        times each op after all of its dependencies. Finishes by asserting
        that no operation was left untimed.
        """
        def execute_op(op: Operation):
            # Recursively ensure all dependencies are timed first
            # (untimed deps are detected by their start/end being None).
            deps = self.schedule.get_dependencies(op)
            if len(deps) == 0:
                op.start_time = 0.0
            else:
                for dep, gap in deps:
                    if dep.end_time is None or dep.start_time is None:
                        execute_op(dep)
                # Start as soon as the latest dependency (plus its
                # communication gap) has finished.
                op.start_time = max(dep.end_time + gap for dep, gap in deps)
            op.end_time = op.start_time + self.schedule.config.get_op_time(
                op.op_type, op.stage_id
            )

        # Iterate by queue index across all devices. Ops already timed via
        # recursion are re-executed here, which recomputes the same values
        # (the computation is deterministic for fixed dependency times).
        op_num = len(self.schedule.dev_queues[0].ops)
        for i in range(op_num):
            for dev_id in range(self.schedule.config.num_devices):
                op = self.schedule.dev_queues[dev_id].ops[i]
                execute_op(op)

        # Sanity check: every operation must have been timed.
        # NOTE(review): assumes all device queues have equal length — confirm
        # against the schedule generators.
        for op in self.schedule.ops.values():
            assert (
                op.start_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
            assert (
                op.end_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
|
src/strategies.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from src.execution_model import Schedule, ScheduleConfig
|
3 |
+
|
4 |
+
|
5 |
+
def generate_1f1b_schedule(config: ScheduleConfig):
    """Build the classic 1F1B pipeline schedule for *config*.

    Each device runs a warmup phase (forward-only), a steady phase
    alternating one forward with one backward, and a cooldown phase
    (backward-only). Earlier devices get more warmup micro-batches.
    """
    schedule = Schedule(config)

    for dev_id in range(config.num_devices):
        queue = schedule.dev_queues[dev_id]
        stages = queue.stages

        def enqueue_forward(batch_id):
            # Forward passes run through this device's stages in order.
            for stage_id in stages:
                queue.add_operation(schedule.get_op(batch_id, stage_id, "forward"))

        def enqueue_backward(batch_id):
            # Backward passes run through the stages in reverse order.
            for stage_id in reversed(stages):
                queue.add_operation(schedule.get_op(batch_id, stage_id, "backward"))

        fwd_batch_id = 0
        bwd_batch_id = 0
        cooldown_batches = warmup_batches = config.num_devices - dev_id - 1
        steady_batches = config.num_batches - warmup_batches

        for _ in range(warmup_batches):
            enqueue_forward(fwd_batch_id)
            fwd_batch_id += 1

        for _ in range(steady_batches):
            enqueue_forward(fwd_batch_id)
            fwd_batch_id += 1
            enqueue_backward(bwd_batch_id)
            bwd_batch_id += 1

        for _ in range(cooldown_batches):
            enqueue_backward(bwd_batch_id)
            bwd_batch_id += 1

    return schedule
41 |
+
|
42 |
+
|
43 |
+
# Some codes are copied from Megatron-LM
|
44 |
+
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
    """Build an interleaved (virtual-pipeline) 1F1B schedule.

    Each device hosts ``config.num_stages_per_device`` model chunks; the
    per-device execution order is derived the same way Megatron-LM builds
    its interleaved 1F1B order (warmup forwards, then alternating 1F1B,
    then remaining backwards). NOTE(review): assumes
    ``dev_queues[d].stages[chunk]`` maps chunk index -> stage id, which
    matches the "interleave" placement strategy — confirm against callers.
    """
    schedule = Schedule(config)

    def get_pp_rank_microbatches(
        num_microbatches,
        num_devices,
        device_id,
        num_stages_per_device,
        microbatch_group_size_per_vp_stage,
    ):
        """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
        total_num_microbatches = num_microbatches * num_stages_per_device
        are_all_microbatches_in_warmup = False

        if num_devices > 1:
            if num_stages_per_device is None:
                # forward_backward_pipelining_without_interleaving
                num_warmup_microbatches = num_devices - device_id - 1
            else:
                # forward_backward_pipelining_with_interleaving
                # Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
                # all workers, followed by more microbatches after depending on
                # stage ID (more forward passes for earlier stages, later stages can
                # immediately start with 1F1B).
                num_warmup_microbatches = (num_devices - device_id - 1) * 2
                num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
        else:
            # forward_backward_no_pipelining
            num_warmup_microbatches = 1

        # Clamp: with few microbatches the whole run may be warmup.
        if num_warmup_microbatches >= total_num_microbatches:
            num_warmup_microbatches = total_num_microbatches
            are_all_microbatches_in_warmup = True
        num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

        return (
            total_num_microbatches,
            are_all_microbatches_in_warmup,
            num_warmup_microbatches,
            num_microbatches_remaining,
        )

    def get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
        """Get the schedule table for PP scheduling.

        Create a tunable schedule lookup table.
        The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
        For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
        virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
        microbatch_id         | 0 1 2 0 1 2 3 4 3 4
        model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
        """
        schedule_table = []
        for min_microbatch_id_in_group in range(
            0, num_microbatches, microbatch_group_size_per_vp_stage
        ):
            if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
                # Construct schedule for the last microbatch group
                schedule_table.extend(
                    [
                        (microbatch_id, model_chunk_id)
                        for model_chunk_id in range(num_model_chunks)
                        for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
                    ]
                )
            else:
                # Construct schedule for other microbatch groups
                schedule_table.extend(
                    [
                        (microbatch_id, model_chunk_id)
                        for model_chunk_id in range(num_model_chunks)
                        for microbatch_id in range(
                            min_microbatch_id_in_group,
                            min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
                        )
                    ]
                )
        return schedule_table

    def convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
        """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
        order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
        virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
        microbatch_id         | 0 1 2 0 1 2 3 4 3 4
        model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

        Then the forward backward separated order is:
        forward               | 1 1 1 2 2 2 1 1 2 2
        backward              | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1

        If num_warmup_microbatches is 5, the output order is:
        1 1 1 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
        """
        _, model_chunk_id_table = zip(*schedule_table)
        # Positive entries = forward for chunk k (1-based); negative =
        # backward for chunk k (as -(num_model_chunks - k + 1) style codes).
        forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
        backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
        order = forward_order[:num_warmup_microbatches]
        # Steady state: interleave one forward with one (offset) backward.
        for i in range(num_warmup_microbatches, len(forward_order)):
            order.append(forward_order[i])
            order.append(backward_order[i - num_warmup_microbatches])
        if num_warmup_microbatches > 0:
            order.extend(backward_order[-num_warmup_microbatches:])
        return order

    for device_id in range(config.num_devices):
        # Group size used by the tunable schedule table (Megatron default).
        microbatch_group_size_per_vp_stage = config.num_devices
        total_num_microbatches, are_all_microbatches_in_warmup, num_warmup_microbatches, num_microbatches_remaining = get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        schedule_table = get_schedule_table(
            config.num_batches,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        order = convert_schedule_table_to_order(
            num_warmup_microbatches,
            num_model_chunks=config.num_stages_per_device,
            schedule_table=schedule_table,
        )

        # Track the next micro-batch id per chunk: positive keys for forward
        # chunks, negative keys for backward chunks.
        cur_stage_microbatch_id = {}
        for i in range(1, config.num_stages_per_device+1):
            cur_stage_microbatch_id[i] = 0
            cur_stage_microbatch_id[-i] = 0
        for order_item in order:
            # order_item encodes (direction, chunk): sign = fwd/bwd,
            # magnitude-1 = chunk index on this device.
            stage_id = schedule.dev_queues[device_id].stages[abs(order_item)-1]

            if order_item > 0:
                op_type = "forward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
            elif order_item < 0:
                op_type = "backward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
            else:
                raise ValueError(f"Invalid order item: {order_item}")
            schedule.dev_queues[device_id].add_operation(
                schedule.get_op(micro_batch_id, stage_id, op_type)
            )
    return schedule
|
dash_visualizer.py → src/visualizer.py
RENAMED
@@ -1,41 +1,86 @@
|
|
1 |
import dash
|
2 |
from dash import dcc, html
|
3 |
-
from dash.dependencies import Input, Output
|
4 |
import plotly.graph_objects as go
|
5 |
-
import
|
6 |
-
from typing import List, Dict, Literal
|
7 |
from tqdm import tqdm
|
8 |
-
import
|
9 |
|
|
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"""
|
13 |
Create a Plotly figure for pipeline parallelism scheduling.
|
14 |
|
15 |
Args:
|
16 |
-
|
17 |
-
Each task is a dictionary with keys:
|
18 |
-
- 'type': 'forward', 'backward', or 'optimizer'
|
19 |
-
- 'batch': batch number
|
20 |
-
- 'start_time': start time of the task
|
21 |
-
- 'duration': duration of the task
|
22 |
max_time: Optional maximum time to display
|
23 |
show_progress: Whether to show a progress bar
|
24 |
"""
|
25 |
-
#
|
26 |
-
|
27 |
-
|
28 |
-
optimizer_color = "#FFEFCF"
|
29 |
empty_color = "whitesmoke"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
# Find the maximum time in the schedule if not provided
|
35 |
if max_time is None:
|
36 |
max_time = 0
|
37 |
-
for device in
|
38 |
-
for task in
|
39 |
end_time = task["start_time"] + task["duration"]
|
40 |
if end_time > max_time:
|
41 |
max_time = end_time
|
@@ -44,56 +89,51 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
44 |
fig = go.Figure()
|
45 |
|
46 |
# Initialize progress tracking
|
47 |
-
total_tasks = sum(len(tasks) for tasks in
|
48 |
tasks_processed = 0
|
49 |
|
50 |
if show_progress:
|
51 |
-
progress_bar = tqdm(total=total_tasks +
|
52 |
|
53 |
-
#
|
54 |
-
for
|
55 |
-
device_idx_reversed = num_stages - device_idx - 1 # Reverse for plotting
|
56 |
-
fig.add_trace(go.Scatter(
|
57 |
-
x=[0, max_time],
|
58 |
-
y=[device_idx_reversed, device_idx_reversed],
|
59 |
-
mode='lines',
|
60 |
-
line=dict(color='lightgray', width=0.5),
|
61 |
-
showlegend=False,
|
62 |
-
hoverinfo='none'
|
63 |
-
))
|
64 |
-
if show_progress:
|
65 |
-
progress_bar.update(1)
|
66 |
|
67 |
# Add rectangles for each task
|
68 |
-
for device_idx, device in enumerate(
|
69 |
-
device_idx_reversed =
|
|
|
|
|
|
|
70 |
|
71 |
-
for task in
|
72 |
# Determine task color and text color
|
73 |
if task["type"] == "forward":
|
74 |
-
color =
|
75 |
text_color = "white"
|
76 |
name = "Forward"
|
77 |
elif task["type"] == "backward":
|
78 |
-
color =
|
79 |
text_color = "black"
|
80 |
name = "Backward"
|
81 |
-
else:
|
82 |
-
color =
|
83 |
text_color = "black"
|
84 |
-
name = "
|
85 |
-
|
86 |
# Add rectangle for the task
|
87 |
start_time = task["start_time"]
|
88 |
duration = task["duration"]
|
89 |
|
|
|
|
|
|
|
90 |
# Create rectangle using shape
|
91 |
fig.add_shape(
|
92 |
type="rect",
|
93 |
x0=start_time,
|
94 |
-
y0=
|
95 |
x1=start_time + duration,
|
96 |
-
y1=
|
97 |
line=dict(color="black", width=0.5),
|
98 |
fillcolor=color,
|
99 |
layer="above",
|
@@ -102,12 +142,23 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
102 |
# Add batch number text
|
103 |
fig.add_annotation(
|
104 |
x=start_time + duration / 2,
|
105 |
-
y=
|
106 |
-
text=
|
107 |
showarrow=False,
|
108 |
-
font=dict(color=text_color, size=
|
109 |
)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
# Update progress
|
112 |
if show_progress:
|
113 |
tasks_processed += 1
|
@@ -115,9 +166,8 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
115 |
|
116 |
# Add custom legend
|
117 |
legend_items = [
|
118 |
-
dict(name="Forward", color=
|
119 |
-
dict(name="Backward", color=
|
120 |
-
dict(name="Optimizer step", color=optimizer_color)
|
121 |
]
|
122 |
|
123 |
for i, item in enumerate(legend_items):
|
@@ -133,77 +183,98 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
133 |
progress_bar.update(1)
|
134 |
|
135 |
# Set axis properties
|
136 |
-
device_labels = [f"Device {i
|
137 |
-
device_labels.reverse() # Reverse to put Device
|
|
|
|
|
|
|
138 |
|
|
|
|
|
|
|
139 |
fig.update_layout(
|
140 |
-
xaxis=dict(
|
141 |
-
showticklabels=False,
|
142 |
-
showgrid=False,
|
143 |
-
zeroline=False,
|
144 |
-
title="Time →",
|
145 |
-
range=[0, max_time + 0.5]
|
146 |
-
),
|
147 |
yaxis=dict(
|
148 |
tickmode="array",
|
149 |
-
tickvals=
|
150 |
ticktext=device_labels,
|
151 |
showgrid=False,
|
152 |
zeroline=False,
|
153 |
-
range=[-0.5, num_stages - 0.5]
|
154 |
),
|
155 |
-
margin=dict(l=50, r=
|
156 |
plot_bgcolor="white",
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
legend=dict(
|
158 |
orientation="h",
|
159 |
-
yanchor="
|
160 |
-
y=-0.
|
161 |
xanchor="center",
|
162 |
x=0.5
|
163 |
-
)
|
|
|
|
|
|
|
|
|
164 |
)
|
165 |
|
166 |
if show_progress:
|
167 |
-
progress_bar.update(1)
|
168 |
progress_bar.close()
|
169 |
|
170 |
return fig
|
171 |
|
172 |
|
173 |
-
def create_dash_app(schedule:
|
174 |
"""
|
175 |
-
Create a Dash app
|
176 |
-
|
177 |
Args:
|
178 |
-
schedule:
|
179 |
-
schedule_type: Type of
|
180 |
"""
|
181 |
-
|
|
|
|
|
|
|
|
|
182 |
|
183 |
app.layout = html.Div([
|
184 |
-
html.H1(f"Pipeline Parallelism
|
185 |
-
style={'textAlign': 'center'}),
|
186 |
-
|
187 |
-
html.Div(id="loading-container", children=[
|
188 |
-
dcc.Loading(
|
189 |
-
id="loading-graph",
|
190 |
-
type="circle",
|
191 |
-
children=[
|
192 |
-
html.Div(id="graph-container", children=[
|
193 |
-
dcc.Graph(
|
194 |
-
id='pipeline-graph',
|
195 |
-
style={'height': '600px'}
|
196 |
-
)
|
197 |
-
])
|
198 |
-
]
|
199 |
-
)
|
200 |
-
]),
|
201 |
|
202 |
html.Div([
|
203 |
-
html.
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
])
|
208 |
|
209 |
@app.callback(
|
@@ -213,98 +284,65 @@ def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
|
|
213 |
)
|
214 |
def load_graph(_):
|
215 |
# Create the figure when the app loads
|
216 |
-
return create_pipeline_figure(
|
217 |
-
|
218 |
@app.callback(
|
219 |
Output("download-image", "data"),
|
220 |
Input("btn-download", "n_clicks"),
|
221 |
prevent_initial_call=True,
|
222 |
)
|
223 |
def download_image(n_clicks):
|
224 |
-
#
|
225 |
-
fig = create_pipeline_figure(
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
227 |
return dict(
|
228 |
-
content=
|
229 |
-
filename="
|
|
|
|
|
230 |
)
|
231 |
|
232 |
return app
|
233 |
|
234 |
|
235 |
def visualize_pipeline_parallelism_dash(
|
236 |
-
schedule:
|
237 |
-
schedule_type: Literal["simple", "1f1b"] = "1f1b",
|
238 |
port: int = 8050,
|
239 |
debug: bool = False
|
240 |
):
|
241 |
"""
|
242 |
-
|
243 |
-
|
244 |
Args:
|
245 |
-
schedule:
|
246 |
-
|
247 |
-
|
248 |
-
debug: Whether to run the app in debug mode
|
249 |
"""
|
250 |
-
app = create_dash_app(schedule
|
251 |
print(f"Starting Dash app on http://localhost:{port}/")
|
252 |
app.run_server(debug=debug, port=port)
|
253 |
|
254 |
|
255 |
def save_pipeline_visualization_plotly(
|
256 |
-
schedule:
|
257 |
-
schedule_type: Literal["simple", "1f1b"] = "1f1b",
|
258 |
output_file: str = "pipeline_visualization_plotly.png",
|
259 |
):
|
260 |
"""
|
261 |
-
Save a static
|
262 |
-
|
263 |
Args:
|
264 |
-
schedule:
|
265 |
-
|
266 |
-
output_file: Path to save the visualization
|
267 |
"""
|
268 |
-
|
269 |
-
fig = create_pipeline_figure(
|
270 |
-
|
271 |
-
# Update layout for static image
|
272 |
-
fig.update_layout(
|
273 |
-
title=f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
|
274 |
-
title_x=0.5
|
275 |
-
)
|
276 |
|
277 |
-
print(f"Saving
|
278 |
-
|
279 |
-
fig.write_image(output_file, scale=3)
|
280 |
print(f"Visualization saved to {output_file}")
|
281 |
|
282 |
-
|
283 |
-
if __name__ == "__main__":
|
284 |
-
# Example usage
|
285 |
-
import argparse
|
286 |
-
from pipeline import create_1f1b_schedule
|
287 |
-
|
288 |
-
parser = argparse.ArgumentParser(description="Pipeline Parallelism Visualizer")
|
289 |
-
parser.add_argument("--num-stages", type=int, default=4, help="Number of pipeline stages")
|
290 |
-
parser.add_argument("--num-batches", type=int, default=8, help="Number of microbatches")
|
291 |
-
parser.add_argument("--interactive", action="store_true", help="Run interactive Dash app")
|
292 |
-
parser.add_argument("--port", type=int, default=8050, help="Port for Dash app")
|
293 |
-
parser.add_argument("--output", type=str, default="pipeline_visualization_plotly.png", help="Output file for static image")
|
294 |
-
args = parser.parse_args()
|
295 |
-
|
296 |
-
# Create an example schedule
|
297 |
-
forward_times = [1.0] * args.num_stages
|
298 |
-
backward_times = [2.0] * args.num_stages
|
299 |
-
|
300 |
-
schedule = create_1f1b_schedule(
|
301 |
-
num_stages=args.num_stages,
|
302 |
-
num_batches=args.num_batches,
|
303 |
-
forward_times=forward_times,
|
304 |
-
backward_times=backward_times,
|
305 |
-
)
|
306 |
-
|
307 |
-
if args.interactive:
|
308 |
-
visualize_pipeline_parallelism_dash(schedule, port=args.port)
|
309 |
-
else:
|
310 |
-
save_pipeline_visualization_plotly(schedule, output_file=args.output)
|
|
|
1 |
import dash
|
2 |
from dash import dcc, html
|
3 |
+
from dash.dependencies import Input, Output
|
4 |
import plotly.graph_objects as go
|
5 |
+
import argparse
|
6 |
+
from typing import List, Dict, Literal, Optional
|
7 |
from tqdm import tqdm
|
8 |
+
import base64
|
9 |
|
10 |
+
from src.execution_model import Schedule
|
11 |
|
12 |
+
|
13 |
+
def convert_schedule_to_visualization_format(schedule: Schedule):
|
14 |
+
"""
|
15 |
+
Converts a Schedule object to the format needed for visualization.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
|
19 |
+
"""
|
20 |
+
# Make sure all operations have start and end times
|
21 |
+
for op in schedule.ops.values():
|
22 |
+
if op.start_time is None or op.end_time is None:
|
23 |
+
raise ValueError("Operations must have start and end times. Run ScheduleExecutor.execute() first.")
|
24 |
+
|
25 |
+
visualization_data = {}
|
26 |
+
|
27 |
+
# Organize operations by device
|
28 |
+
for device_id, device_queue in enumerate(schedule.dev_queues):
|
29 |
+
visualization_data[device_id] = []
|
30 |
+
|
31 |
+
for op in device_queue.ops:
|
32 |
+
visualization_data[device_id].append({
|
33 |
+
"type": op.op_type,
|
34 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
35 |
+
"stage": op.stage_id,
|
36 |
+
"start_time": op.start_time,
|
37 |
+
"duration": op.end_time - op.start_time
|
38 |
+
})
|
39 |
+
|
40 |
+
return visualization_data
|
41 |
+
|
42 |
+
|
43 |
+
def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
|
44 |
"""
|
45 |
Create a Plotly figure for pipeline parallelism scheduling.
|
46 |
|
47 |
Args:
|
48 |
+
schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
|
|
|
|
|
|
|
|
|
|
|
49 |
max_time: Optional maximum time to display
|
50 |
show_progress: Whether to show a progress bar
|
51 |
"""
|
52 |
+
# Find the number of devices
|
53 |
+
num_devices = len(schedule_data)
|
54 |
+
|
|
|
55 |
empty_color = "whitesmoke"
|
56 |
+
# Colors for task types
|
57 |
+
def get_color(op_type: str, stage_id: int):
|
58 |
+
# Base colors
|
59 |
+
forward_base_color = "royalblue"
|
60 |
+
backward_base_color = "lightgreen" # Changed from sandybrown to match your visualization
|
61 |
+
|
62 |
+
virtual_stage = stage_id // num_devices
|
63 |
|
64 |
+
if op_type == "forward":
|
65 |
+
if virtual_stage == 0:
|
66 |
+
return forward_base_color
|
67 |
+
else:
|
68 |
+
# Lighter shade for virtual_stage > 0
|
69 |
+
return "lightskyblue"
|
70 |
+
elif op_type == "backward":
|
71 |
+
if virtual_stage == 0:
|
72 |
+
return backward_base_color
|
73 |
+
else:
|
74 |
+
# Lighter shade for virtual_stage > 0
|
75 |
+
return "lightseagreen"
|
76 |
+
else:
|
77 |
+
raise ValueError(f"Invalid operation type: {op_type}")
|
78 |
|
79 |
# Find the maximum time in the schedule if not provided
|
80 |
if max_time is None:
|
81 |
max_time = 0
|
82 |
+
for device in schedule_data:
|
83 |
+
for task in schedule_data[device]:
|
84 |
end_time = task["start_time"] + task["duration"]
|
85 |
if end_time > max_time:
|
86 |
max_time = end_time
|
|
|
89 |
fig = go.Figure()
|
90 |
|
91 |
# Initialize progress tracking
|
92 |
+
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
|
93 |
tasks_processed = 0
|
94 |
|
95 |
if show_progress:
|
96 |
+
progress_bar = tqdm(total=total_tasks + num_devices + 3, desc="Creating visualization")
|
97 |
|
98 |
+
# Create a custom y-axis with no gaps between devices
|
99 |
+
y_spacing = 1.0 # Use 1.0 for no gaps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# Add rectangles for each task
|
102 |
+
for device_idx, device in enumerate(schedule_data):
|
103 |
+
device_idx_reversed = num_devices - device_idx - 1
|
104 |
+
|
105 |
+
# Sort tasks by start time to ensure correct rendering
|
106 |
+
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
107 |
|
108 |
+
for task in sorted_tasks:
|
109 |
# Determine task color and text color
|
110 |
if task["type"] == "forward":
|
111 |
+
color = get_color(task["type"], task["stage"])
|
112 |
text_color = "white"
|
113 |
name = "Forward"
|
114 |
elif task["type"] == "backward":
|
115 |
+
color = get_color(task["type"], task["stage"])
|
116 |
text_color = "black"
|
117 |
name = "Backward"
|
118 |
+
else:
|
119 |
+
color = empty_color
|
120 |
text_color = "black"
|
121 |
+
name = "Unknown"
|
122 |
+
|
123 |
# Add rectangle for the task
|
124 |
start_time = task["start_time"]
|
125 |
duration = task["duration"]
|
126 |
|
127 |
+
# Calculate y positions with no gaps
|
128 |
+
y_pos = device_idx_reversed * y_spacing
|
129 |
+
|
130 |
# Create rectangle using shape
|
131 |
fig.add_shape(
|
132 |
type="rect",
|
133 |
x0=start_time,
|
134 |
+
y0=y_pos - 0.5,
|
135 |
x1=start_time + duration,
|
136 |
+
y1=y_pos + 0.5,
|
137 |
line=dict(color="black", width=0.5),
|
138 |
fillcolor=color,
|
139 |
layer="above",
|
|
|
142 |
# Add batch number text
|
143 |
fig.add_annotation(
|
144 |
x=start_time + duration / 2,
|
145 |
+
y=y_pos,
|
146 |
+
text=f"{task['batch']}", # Only show batch ID
|
147 |
showarrow=False,
|
148 |
+
font=dict(color=text_color, size=12, family="Arial, bold"), # Increased font size
|
149 |
)
|
150 |
|
151 |
+
# Add hover data with additional details
|
152 |
+
fig.add_trace(go.Scatter(
|
153 |
+
x=[start_time + duration / 2],
|
154 |
+
y=[y_pos],
|
155 |
+
mode='markers',
|
156 |
+
marker=dict(opacity=0), # Invisible marker
|
157 |
+
hoverinfo='text',
|
158 |
+
text=f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}",
|
159 |
+
showlegend=False
|
160 |
+
))
|
161 |
+
|
162 |
# Update progress
|
163 |
if show_progress:
|
164 |
tasks_processed += 1
|
|
|
166 |
|
167 |
# Add custom legend
|
168 |
legend_items = [
|
169 |
+
dict(name="Forward", color=get_color("forward", 0)),
|
170 |
+
dict(name="Backward", color=get_color("backward", 0)),
|
|
|
171 |
]
|
172 |
|
173 |
for i, item in enumerate(legend_items):
|
|
|
183 |
progress_bar.update(1)
|
184 |
|
185 |
# Set axis properties
|
186 |
+
device_labels = [f"Device {i}" for i in range(num_devices)]
|
187 |
+
device_labels.reverse() # Reverse to put Device 0 at the top
|
188 |
+
|
189 |
+
# Calculate tick positions with no gaps
|
190 |
+
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
|
191 |
|
192 |
+
# Adjust the range to ensure there are no empty spaces at the end
|
193 |
+
x_end = max_time * 1.05 # Add a small margin
|
194 |
+
|
195 |
fig.update_layout(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
yaxis=dict(
|
197 |
tickmode="array",
|
198 |
+
tickvals=tick_positions,
|
199 |
ticktext=device_labels,
|
200 |
showgrid=False,
|
201 |
zeroline=False,
|
|
|
202 |
),
|
203 |
+
margin=dict(l=50, r=20, t=40, b=40),
|
204 |
plot_bgcolor="white",
|
205 |
+
title=dict(
|
206 |
+
text="Pipeline Parallelism Schedule",
|
207 |
+
x=0.5,
|
208 |
+
y=0.98, # Move title position closer to the top
|
209 |
+
font=dict(size=20)
|
210 |
+
),
|
211 |
legend=dict(
|
212 |
orientation="h",
|
213 |
+
yanchor="top",
|
214 |
+
y=-0.1, # Position below the plot
|
215 |
xanchor="center",
|
216 |
x=0.5
|
217 |
+
),
|
218 |
+
width=1600,
|
219 |
+
height=400, # Reduce height to make the visualization more compact
|
220 |
+
bargap=0,
|
221 |
+
bargroupgap=0,
|
222 |
)
|
223 |
|
224 |
if show_progress:
|
225 |
+
progress_bar.update(1)
|
226 |
progress_bar.close()
|
227 |
|
228 |
return fig
|
229 |
|
230 |
|
231 |
+
def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
|
232 |
"""
|
233 |
+
Create a Dash app to visualize the pipeline schedule.
|
234 |
+
|
235 |
Args:
|
236 |
+
schedule: Schedule object to visualize
|
237 |
+
schedule_type: Type of schedule ("1f1b" or other)
|
238 |
"""
|
239 |
+
# Convert schedule to visualization format
|
240 |
+
schedule_data = convert_schedule_to_visualization_format(schedule)
|
241 |
+
|
242 |
+
# Create the app
|
243 |
+
app = dash.Dash(__name__, title=f"Pipeline Parallelism Visualizer - {schedule_type}")
|
244 |
|
245 |
app.layout = html.Div([
|
246 |
+
html.H1(f"Pipeline Parallelism Visualizer - {schedule_type}", style={'textAlign': 'center'}),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
|
248 |
html.Div([
|
249 |
+
html.Div([
|
250 |
+
html.H3("Schedule Configuration:"),
|
251 |
+
html.Ul([
|
252 |
+
html.Li(f"Number of devices: {schedule.config.num_devices}"),
|
253 |
+
html.Li(f"Number of stages: {schedule.config.num_stages}"),
|
254 |
+
html.Li(f"Number of batches: {schedule.config.num_batches}"),
|
255 |
+
]),
|
256 |
+
], className="config-section"),
|
257 |
+
|
258 |
+
html.Button("Download Image", id="btn-download",
|
259 |
+
style={
|
260 |
+
'marginTop': '20px',
|
261 |
+
'padding': '10px',
|
262 |
+
'backgroundColor': '#007BFF',
|
263 |
+
'color': 'white',
|
264 |
+
'border': 'none',
|
265 |
+
'borderRadius': '5px',
|
266 |
+
'cursor': 'pointer'
|
267 |
+
}),
|
268 |
+
|
269 |
+
dcc.Download(id="download-image"),
|
270 |
+
], style={'margin': '20px'}),
|
271 |
+
|
272 |
+
html.Div(id="graph-container", children=[]),
|
273 |
+
|
274 |
+
dcc.Graph(
|
275 |
+
id="pipeline-graph",
|
276 |
+
config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
|
277 |
+
),
|
278 |
])
|
279 |
|
280 |
@app.callback(
|
|
|
284 |
)
|
285 |
def load_graph(_):
|
286 |
# Create the figure when the app loads
|
287 |
+
return create_pipeline_figure(schedule_data, show_progress=True)
|
288 |
+
|
289 |
@app.callback(
|
290 |
Output("download-image", "data"),
|
291 |
Input("btn-download", "n_clicks"),
|
292 |
prevent_initial_call=True,
|
293 |
)
|
294 |
def download_image(n_clicks):
|
295 |
+
# Generate the figure for download
|
296 |
+
fig = create_pipeline_figure(schedule_data, show_progress=True)
|
297 |
+
|
298 |
+
# Convert to base64 image
|
299 |
+
img_bytes = fig.to_image(format="png", width=1600, height=1000, scale=2)
|
300 |
+
img_base64 = base64.b64encode(img_bytes).decode('ascii')
|
301 |
+
|
302 |
+
# Return the download data
|
303 |
return dict(
|
304 |
+
content=img_base64,
|
305 |
+
filename=f"pipeline_visualization_{schedule_type}.png",
|
306 |
+
type="image/png",
|
307 |
+
base64=True
|
308 |
)
|
309 |
|
310 |
return app
|
311 |
|
312 |
|
313 |
def visualize_pipeline_parallelism_dash(
|
314 |
+
schedule: Schedule,
|
|
|
315 |
port: int = 8050,
|
316 |
debug: bool = False
|
317 |
):
|
318 |
"""
|
319 |
+
Launch a Dash app to visualize the pipeline schedule interactively.
|
320 |
+
|
321 |
Args:
|
322 |
+
schedule: Schedule object to visualize
|
323 |
+
port: Port to run the Dash app on
|
324 |
+
debug: Whether to run the Dash app in debug mode
|
|
|
325 |
"""
|
326 |
+
app = create_dash_app(schedule)
|
327 |
print(f"Starting Dash app on http://localhost:{port}/")
|
328 |
app.run_server(debug=debug, port=port)
|
329 |
|
330 |
|
331 |
def save_pipeline_visualization_plotly(
|
332 |
+
schedule: Schedule,
|
|
|
333 |
output_file: str = "pipeline_visualization_plotly.png",
|
334 |
):
|
335 |
"""
|
336 |
+
Save a static image of the pipeline schedule visualization.
|
337 |
+
|
338 |
Args:
|
339 |
+
schedule: Schedule object to visualize
|
340 |
+
output_file: Path to save the image to
|
|
|
341 |
"""
|
342 |
+
schedule_data = convert_schedule_to_visualization_format(schedule)
|
343 |
+
fig = create_pipeline_figure(schedule_data, show_progress=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
+
print(f"Saving visualization to {output_file}...")
|
346 |
+
fig.write_image(output_file, width=1600, height=400, scale=2)
|
|
|
347 |
print(f"Visualization saved to {output_file}")
|
348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
visualizer.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
import matplotlib.pyplot as plt
|
2 |
-
import numpy as np
|
3 |
-
from matplotlib.patches import Rectangle
|
4 |
-
from typing import List, Dict, Literal
|
5 |
-
|
6 |
-
|
7 |
-
def visualize_pipeline_parallelism(
|
8 |
-
schedule: Dict[int, List[Dict]],
|
9 |
-
schedule_type: Literal["simple", "1f1b"] = "1f1b",
|
10 |
-
output_file: str = "pipeline_visualization.png",
|
11 |
-
):
|
12 |
-
"""
|
13 |
-
Visualize pipeline parallelism scheduling.
|
14 |
-
|
15 |
-
Args:
|
16 |
-
schedule: Dictionary mapping device IDs to lists of tasks.
|
17 |
-
Each task is a dictionary with keys:
|
18 |
-
- 'type': 'forward', 'backward', or 'optimizer'
|
19 |
-
- 'batch': batch number
|
20 |
-
- 'start_time': start time of the task
|
21 |
-
- 'duration': duration of the task
|
22 |
-
schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
|
23 |
-
output_file: Path to save the visualization
|
24 |
-
"""
|
25 |
-
# Colors for task types
|
26 |
-
forward_color = "royalblue"
|
27 |
-
backward_color = "sandybrown" # Changed to match the reference image
|
28 |
-
optimizer_color = "#FFEFCF" # Light beige for optimizer steps
|
29 |
-
empty_color = "whitesmoke" # Very light gray for empty cells
|
30 |
-
|
31 |
-
# Find the number of stages (devices)
|
32 |
-
num_stages = len(schedule)
|
33 |
-
|
34 |
-
# Find the maximum time in the schedule
|
35 |
-
max_time = 0
|
36 |
-
for device in schedule:
|
37 |
-
for task in schedule[device]:
|
38 |
-
end_time = task["start_time"] + task["duration"]
|
39 |
-
if end_time > max_time:
|
40 |
-
max_time = end_time
|
41 |
-
|
42 |
-
# Create figure and axis
|
43 |
-
fig, ax = plt.subplots(figsize=(15, 4))
|
44 |
-
|
45 |
-
# Create an empty grid with light gray color
|
46 |
-
for device_idx in range(num_stages):
|
47 |
-
device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
|
48 |
-
for t in range(int(max_time) + 1):
|
49 |
-
rect = Rectangle(
|
50 |
-
(t, device_idx_reversed),
|
51 |
-
1.0,
|
52 |
-
1.0,
|
53 |
-
edgecolor="lightgray",
|
54 |
-
facecolor=empty_color,
|
55 |
-
linewidth=0.5,
|
56 |
-
)
|
57 |
-
ax.add_patch(rect)
|
58 |
-
|
59 |
-
# Plot the schedule
|
60 |
-
for device_idx, device in enumerate(schedule):
|
61 |
-
device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
|
62 |
-
for task in schedule[device]:
|
63 |
-
# Determine task color
|
64 |
-
if task["type"] == "forward":
|
65 |
-
color = forward_color
|
66 |
-
text_color = "white"
|
67 |
-
elif task["type"] == "backward":
|
68 |
-
color = backward_color
|
69 |
-
text_color = "black"
|
70 |
-
else: # optimizer or any other type
|
71 |
-
color = optimizer_color
|
72 |
-
text_color = "black"
|
73 |
-
|
74 |
-
rect = Rectangle(
|
75 |
-
(task["start_time"], device_idx_reversed),
|
76 |
-
task["duration"],
|
77 |
-
1.0,
|
78 |
-
edgecolor="black",
|
79 |
-
facecolor=color,
|
80 |
-
linewidth=0.5,
|
81 |
-
)
|
82 |
-
ax.add_patch(rect)
|
83 |
-
|
84 |
-
# Add text (batch number)
|
85 |
-
ax.text(
|
86 |
-
task["start_time"] + task["duration"] / 2,
|
87 |
-
device_idx_reversed + 0.5,
|
88 |
-
str(task["batch"]),
|
89 |
-
ha="center",
|
90 |
-
va="center",
|
91 |
-
fontsize=10,
|
92 |
-
fontweight="bold",
|
93 |
-
color=text_color,
|
94 |
-
)
|
95 |
-
|
96 |
-
# Set axis limits and labels
|
97 |
-
ax.set_xlim(0, max_time + 0.5)
|
98 |
-
ax.set_ylim(-0.5, num_stages + 0.5)
|
99 |
-
ax.set_yticks(np.arange(num_stages) + 0.5)
|
100 |
-
|
101 |
-
# Reverse the order: Device 1 at the top, highest number at the bottom
|
102 |
-
device_labels = [f"Device {i+1}" for i in range(num_stages)]
|
103 |
-
device_labels.reverse() # Reverse to put Device 1 at the top
|
104 |
-
ax.set_yticklabels(device_labels)
|
105 |
-
|
106 |
-
# Add "Time" label and arrow at the bottom
|
107 |
-
arrow_y = -0.4
|
108 |
-
ax.text(0.5, arrow_y, "Time", ha="right", va="center", fontsize=10)
|
109 |
-
ax.annotate("", xy=(2, arrow_y), xytext=(1, arrow_y),
|
110 |
-
arrowprops=dict(arrowstyle="->", lw=1))
|
111 |
-
|
112 |
-
# Remove the x-axis ticks
|
113 |
-
ax.set_xticks([])
|
114 |
-
|
115 |
-
# Remove the outer frame/border
|
116 |
-
for spine in ax.spines.values():
|
117 |
-
spine.set_visible(False)
|
118 |
-
|
119 |
-
# Add a legend - using 3 parts like in the reference image
|
120 |
-
forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
|
121 |
-
backward_patch = Rectangle((0, 0), 1, 1, facecolor=backward_color)
|
122 |
-
optimizer_patch = Rectangle((0, 0), 1, 1, facecolor=optimizer_color)
|
123 |
-
|
124 |
-
legend = ax.legend(
|
125 |
-
[forward_patch, backward_patch, optimizer_patch],
|
126 |
-
["Forward", "Backward", "Optimizer step"],
|
127 |
-
loc="upper center",
|
128 |
-
bbox_to_anchor=(0.5, -0.15),
|
129 |
-
ncol=3,
|
130 |
-
frameon=False,
|
131 |
-
)
|
132 |
-
|
133 |
-
# Turn off grid
|
134 |
-
ax.grid(False)
|
135 |
-
|
136 |
-
# Save the figure
|
137 |
-
plt.tight_layout()
|
138 |
-
plt.savefig(output_file, dpi=300, bbox_inches="tight")
|
139 |
-
plt.close()
|
140 |
-
|
141 |
-
print(f"Visualization saved to {output_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|