Spaces:
Running
Running
Add interactive server.
Browse files- README.md +15 -2
- pyproject.toml +1 -1
- src/server.py +224 -0
README.md
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# Pipeline Parallelism Emulation
|
2 |
|
3 |
This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
|
4 |
|
@@ -37,7 +37,20 @@ Setup `uv` if not installed on your computer:
|
|
37 |
curl -LsSf https://astral.sh/uv/install.sh | sh
|
38 |
```
|
39 |
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
### Running for 1F1B strategy:
|
43 |
```bash
|
|
|
1 |
+
# Pipeline Parallelism Emulation and Visualization
|
2 |
|
3 |
This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
|
4 |
|
|
|
37 |
curl -LsSf https://astral.sh/uv/install.sh | sh
|
38 |
```
|
39 |
|
40 |
+
|
41 |
+
## Running the Interactive Server
|
42 |
+
|
43 |
+
To visualize schedules interactively:
|
44 |
+
|
45 |
+
```bash
|
46 |
+
uv run src/server.py
|
47 |
+
```
|
48 |
+
|
49 |
+
This will start a Dash server (usually on `http://127.0.0.1:8050/`). Open this URL in your web browser.
|
50 |
+
|
51 |
+
You can then adjust parameters like the number of devices, stages, batches, operation times, and select different scheduling strategies to see the resulting pipeline visualization.
|
52 |
+
|
53 |
+
## Running from Command Line
|
54 |
|
55 |
### Running for 1F1B strategy:
|
56 |
```bash
|
pyproject.toml
CHANGED
@@ -9,7 +9,7 @@ description = "Pipeline Parallelism Emulation and Visualization"
|
|
9 |
readme = "README.md"
|
10 |
requires-python = ">=3.10"
|
11 |
authors = [
|
12 |
-
{name = "
|
13 |
]
|
14 |
classifiers = [
|
15 |
"Programming Language :: Python :: 3",
|
|
|
9 |
readme = "README.md"
|
10 |
requires-python = ">=3.10"
|
11 |
authors = [
|
12 |
+
{name = "Zhenhuan Liu"}
|
13 |
]
|
14 |
classifiers = [
|
15 |
"Programming Language :: Python :: 3",
|
src/server.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dash
|
2 |
+
from dash import dcc, html, Input, Output, State, callback_context
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
import webbrowser
|
5 |
+
from threading import Timer
|
6 |
+
|
7 |
+
from src.execution_model import ScheduleConfig, Schedule
|
8 |
+
from src.strategies import (
|
9 |
+
generate_1f1b_schedule,
|
10 |
+
generate_zero_bubble_1p_schedule,
|
11 |
+
generate_1f1b_overlap_schedule,
|
12 |
+
generate_1f1b_interleave_schedule,
|
13 |
+
generate_1f1b_interleave_overlap_schedule,
|
14 |
+
generate_dualpipe_schedule
|
15 |
+
)
|
16 |
+
from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
|
17 |
+
|
18 |
+
def open_browser(port):
|
19 |
+
webbrowser.open_new(f"http://127.0.0.1:{port}")
|
20 |
+
|
21 |
+
STRATEGIES = {
|
22 |
+
"1f1b": generate_1f1b_schedule,
|
23 |
+
"zb1p": generate_zero_bubble_1p_schedule,
|
24 |
+
"1f1b_overlap": generate_1f1b_overlap_schedule,
|
25 |
+
"1f1b_interleave": generate_1f1b_interleave_schedule,
|
26 |
+
"1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
|
27 |
+
"dualpipe": generate_dualpipe_schedule,
|
28 |
+
}
|
29 |
+
|
30 |
+
app = dash.Dash(__name__, suppress_callback_exceptions=True)
|
31 |
+
app.title = "Pipeline Parallelism Visualizer"
|
32 |
+
|
33 |
+
# Initial default values
|
34 |
+
default_values = {
|
35 |
+
"num_devices": 4,
|
36 |
+
"num_stages": 8,
|
37 |
+
"num_batches": 16,
|
38 |
+
"p2p_latency": 0.1,
|
39 |
+
"op_time_forward": 1.0,
|
40 |
+
"op_time_backward_d": 1.0,
|
41 |
+
"op_time_backward_w": 1.0,
|
42 |
+
"op_time_backward": 2.0,
|
43 |
+
"strategy": "1f1b_interleave",
|
44 |
+
"split_backward": False,
|
45 |
+
"placement_strategy": "interleave"
|
46 |
+
}
|
47 |
+
|
48 |
+
app.layout = html.Div([
|
49 |
+
html.H1("Pipeline Parallelism Schedule Visualizer", style={'textAlign': 'center'}),
|
50 |
+
|
51 |
+
html.Div([
|
52 |
+
html.Div([
|
53 |
+
html.Label("Number of Devices (GPUs):"),
|
54 |
+
dcc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, style={'width': '100%'}),
|
55 |
+
|
56 |
+
html.Label("Number of Stages (Model Chunks):"),
|
57 |
+
dcc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, style={'width': '100%'}),
|
58 |
+
|
59 |
+
html.Label("Number of Microbatches:"),
|
60 |
+
dcc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, style={'width': '100%'}),
|
61 |
+
|
62 |
+
html.Label("P2P Latency (ms):"),
|
63 |
+
dcc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, style={'width': '100%'}),
|
64 |
+
|
65 |
+
], style={'padding': 10, 'flex': 1}),
|
66 |
+
|
67 |
+
html.Div([
|
68 |
+
html.Label("Scheduling Strategy:"),
|
69 |
+
dcc.Dropdown(
|
70 |
+
id='strategy',
|
71 |
+
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
72 |
+
value=default_values["strategy"],
|
73 |
+
clearable=False,
|
74 |
+
style={'width': '100%'}
|
75 |
+
),
|
76 |
+
|
77 |
+
html.Label("Placement Strategy:"),
|
78 |
+
dcc.Dropdown(
|
79 |
+
id='placement_strategy',
|
80 |
+
options=[
|
81 |
+
{'label': 'Standard', 'value': 'standard'},
|
82 |
+
{'label': 'Interleave', 'value': 'interleave'},
|
83 |
+
{'label': 'DualPipe', 'value': 'dualpipe'}
|
84 |
+
],
|
85 |
+
value=default_values["placement_strategy"],
|
86 |
+
clearable=False,
|
87 |
+
style={'width': '100%'}
|
88 |
+
),
|
89 |
+
|
90 |
+
html.Div([ # Wrap checkbox and label
|
91 |
+
dcc.Checklist(
|
92 |
+
id='split_backward',
|
93 |
+
options=[{'label': ' Split Backward Pass (for ZB-1P, DualPipe)', 'value': 'True'}],
|
94 |
+
value=['True'] if default_values["split_backward"] else [],
|
95 |
+
style={'display': 'inline-block'}
|
96 |
+
),
|
97 |
+
], style={'marginTop': '20px'}),
|
98 |
+
|
99 |
+
], style={'padding': 10, 'flex': 1}),
|
100 |
+
|
101 |
+
html.Div([
|
102 |
+
html.Label("Operation Time - Forward (ms):"),
|
103 |
+
dcc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, style={'width': '100%'}),
|
104 |
+
|
105 |
+
html.Label("Operation Time - Backward (ms):"),
|
106 |
+
dcc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01, style={'width': '100%'}),
|
107 |
+
|
108 |
+
html.Label("Operation Time - Backward D (Data Grad) (ms):"),
|
109 |
+
dcc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01, style={'width': '100%'}),
|
110 |
+
|
111 |
+
html.Label("Operation Time - Backward W (Weight Grad) (ms):"),
|
112 |
+
dcc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01, style={'width': '100%'}),
|
113 |
+
], style={'padding': 10, 'flex': 1}),
|
114 |
+
|
115 |
+
], style={'display': 'flex', 'flexDirection': 'row'}),
|
116 |
+
|
117 |
+
html.Div([
|
118 |
+
html.Button('Generate Schedule', id='generate-button', n_clicks=0, style={'margin': '20px auto', 'display': 'block'}),
|
119 |
+
]),
|
120 |
+
|
121 |
+
html.Div(id='error-message', style={'color': 'red', 'textAlign': 'center', 'marginTop': '10px'}),
|
122 |
+
|
123 |
+
dcc.Loading(
|
124 |
+
id="loading-graph",
|
125 |
+
type="circle",
|
126 |
+
children=dcc.Graph(id='pipeline-graph', figure=go.Figure())
|
127 |
+
)
|
128 |
+
])
|
129 |
+
|
130 |
+
@app.callback(
|
131 |
+
Output('pipeline-graph', 'figure'),
|
132 |
+
Output('error-message', 'children'),
|
133 |
+
Input('generate-button', 'n_clicks'),
|
134 |
+
State('num_devices', 'value'),
|
135 |
+
State('num_stages', 'value'),
|
136 |
+
State('num_batches', 'value'),
|
137 |
+
State('p2p_latency', 'value'),
|
138 |
+
State('op_time_forward', 'value'),
|
139 |
+
State('op_time_backward', 'value'),
|
140 |
+
State('op_time_backward_d', 'value'),
|
141 |
+
State('op_time_backward_w', 'value'),
|
142 |
+
State('strategy', 'value'),
|
143 |
+
State('split_backward', 'value'),
|
144 |
+
State('placement_strategy', 'value'),
|
145 |
+
prevent_initial_call=True
|
146 |
+
)
|
147 |
+
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
148 |
+
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
149 |
+
strategy, split_backward_list, placement_strategy):
|
150 |
+
|
151 |
+
error_message = ""
|
152 |
+
fig = go.Figure()
|
153 |
+
|
154 |
+
split_backward = 'True' in split_backward_list
|
155 |
+
|
156 |
+
# Basic Validations
|
157 |
+
if not all([num_devices, num_stages, num_batches, op_time_forward]):
|
158 |
+
return fig, "Missing required input values."
|
159 |
+
if split_backward and not all([op_time_backward_d, op_time_backward_w]):
|
160 |
+
return fig, "Backward D and Backward W times are required when 'Split Backward' is checked."
|
161 |
+
if not split_backward and not op_time_backward:
|
162 |
+
return fig, "Backward time is required when 'Split Backward' is unchecked."
|
163 |
+
if num_stages % num_devices != 0 and placement_strategy != 'dualpipe':
|
164 |
+
return fig, "Number of Stages must be divisible by Number of Devices for standard/interleave placement."
|
165 |
+
if placement_strategy == 'dualpipe' and num_stages % 2 != 0:
|
166 |
+
return fig, "DualPipe requires an even number of stages."
|
167 |
+
if placement_strategy == 'dualpipe' and num_stages != num_devices:
|
168 |
+
return fig, "DualPipe requires Number of Stages to be equal to Number of Devices."
|
169 |
+
if strategy == 'dualpipe' and not split_backward:
|
170 |
+
return fig, "DualPipe strategy currently requires 'Split Backward' to be checked."
|
171 |
+
if strategy == 'dualpipe' and placement_strategy != 'dualpipe':
|
172 |
+
return fig, "DualPipe strategy requires 'DualPipe' placement strategy."
|
173 |
+
if strategy == 'zb1p' and not split_backward:
|
174 |
+
return fig, "ZB-1P strategy requires 'Split Backward' to be checked."
|
175 |
+
|
176 |
+
try:
|
177 |
+
op_times = {
|
178 |
+
"forward": float(op_time_forward),
|
179 |
+
}
|
180 |
+
if split_backward:
|
181 |
+
op_times["backward_D"] = float(op_time_backward_d)
|
182 |
+
op_times["backward_W"] = float(op_time_backward_w)
|
183 |
+
# Add combined backward time for compatibility if needed by some visualization or calculation
|
184 |
+
op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
|
185 |
+
else:
|
186 |
+
op_times["backward"] = float(op_time_backward)
|
187 |
+
|
188 |
+
config = ScheduleConfig(
|
189 |
+
num_devices=int(num_devices),
|
190 |
+
num_stages=int(num_stages),
|
191 |
+
num_batches=int(num_batches),
|
192 |
+
p2p_latency=float(p2p_latency),
|
193 |
+
placement_strategy=placement_strategy,
|
194 |
+
split_backward=split_backward,
|
195 |
+
op_times=op_times,
|
196 |
+
)
|
197 |
+
|
198 |
+
schedule_func = STRATEGIES.get(strategy)
|
199 |
+
if not schedule_func:
|
200 |
+
raise ValueError(f"Invalid strategy selected: {strategy}")
|
201 |
+
|
202 |
+
schedule = schedule_func(config)
|
203 |
+
schedule.execute() # Calculate start/end times
|
204 |
+
|
205 |
+
vis_data = convert_schedule_to_visualization_format(schedule)
|
206 |
+
fig = create_pipeline_figure(vis_data, show_progress=False) # Disable progress bar in server mode
|
207 |
+
|
208 |
+
except AssertionError as e:
|
209 |
+
error_message = f"Configuration Error: {e}"
|
210 |
+
fig = go.Figure() # Return empty figure on error
|
211 |
+
except ValueError as e:
|
212 |
+
error_message = f"Input Error: {e}"
|
213 |
+
fig = go.Figure()
|
214 |
+
except Exception as e:
|
215 |
+
error_message = f"An unexpected error occurred: {e}"
|
216 |
+
fig = go.Figure()
|
217 |
+
|
218 |
+
return fig, error_message
|
219 |
+
|
220 |
+
if __name__ == '__main__':
|
221 |
+
port = 8050
|
222 |
+
# Timer(1, open_browser, args=(port,)).start() # Optional: automatically open browser
|
223 |
+
print(f"Dash server running on http://127.0.0.1:{port}")
|
224 |
+
app.run_server(debug=True, port=port)
|