Victarry commited on
Commit
5126943
·
1 Parent(s): e1ba92f

Add interactive server.

Browse files
Files changed (3) hide show
  1. README.md +15 -2
  2. pyproject.toml +1 -1
  3. src/server.py +224 -0
README.md CHANGED
@@ -1,4 +1,4 @@
1
- # Pipeline Parallelism Emulation
2
 
3
  This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
4
 
@@ -37,7 +37,20 @@ Setup `uv` if not installed on your computer:
37
  curl -LsSf https://astral.sh/uv/install.sh | sh
38
  ```
39
 
40
- ## Usage
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  ### Running for 1F1B strategy:
43
  ```bash
 
1
+ # Pipeline Parallelism Emulation and Visualization
2
 
3
  This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
4
 
 
37
  curl -LsSf https://astral.sh/uv/install.sh | sh
38
  ```
39
 
40
+
41
+ ## Running the Interactive Server
42
+
43
+ To visualize schedules interactively:
44
+
45
+ ```bash
46
+ uv run src/server.py
47
+ ```
48
+
49
+ This will start a Dash server (usually on `http://127.0.0.1:8050/`). Open this URL in your web browser.
50
+
51
+ You can then adjust parameters like the number of devices, stages, batches, operation times, and select different scheduling strategies to see the resulting pipeline visualization.
52
+
53
+ ## Running from Command Line
54
 
55
  ### Running for 1F1B strategy:
56
  ```bash
pyproject.toml CHANGED
@@ -9,7 +9,7 @@ description = "Pipeline Parallelism Emulation and Visualization"
9
  readme = "README.md"
10
  requires-python = ">=3.10"
11
  authors = [
12
- {name = "Project Author"}
13
  ]
14
  classifiers = [
15
  "Programming Language :: Python :: 3",
 
9
  readme = "README.md"
10
  requires-python = ">=3.10"
11
  authors = [
12
+ {name = "Zhenhuan Liu"}
13
  ]
14
  classifiers = [
15
  "Programming Language :: Python :: 3",
src/server.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dash
2
+ from dash import dcc, html, Input, Output, State, callback_context
3
+ import plotly.graph_objects as go
4
+ import webbrowser
5
+ from threading import Timer
6
+
7
+ from src.execution_model import ScheduleConfig, Schedule
8
+ from src.strategies import (
9
+ generate_1f1b_schedule,
10
+ generate_zero_bubble_1p_schedule,
11
+ generate_1f1b_overlap_schedule,
12
+ generate_1f1b_interleave_schedule,
13
+ generate_1f1b_interleave_overlap_schedule,
14
+ generate_dualpipe_schedule
15
+ )
16
+ from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure
17
+
18
+ def open_browser(port):
19
+ webbrowser.open_new(f"http://127.0.0.1:{port}")
20
+
21
+ STRATEGIES = {
22
+ "1f1b": generate_1f1b_schedule,
23
+ "zb1p": generate_zero_bubble_1p_schedule,
24
+ "1f1b_overlap": generate_1f1b_overlap_schedule,
25
+ "1f1b_interleave": generate_1f1b_interleave_schedule,
26
+ "1f1b_interleave_overlap": generate_1f1b_interleave_overlap_schedule,
27
+ "dualpipe": generate_dualpipe_schedule,
28
+ }
29
+
30
+ app = dash.Dash(__name__, suppress_callback_exceptions=True)
31
+ app.title = "Pipeline Parallelism Visualizer"
32
+
33
+ # Initial default values
34
+ default_values = {
35
+ "num_devices": 4,
36
+ "num_stages": 8,
37
+ "num_batches": 16,
38
+ "p2p_latency": 0.1,
39
+ "op_time_forward": 1.0,
40
+ "op_time_backward_d": 1.0,
41
+ "op_time_backward_w": 1.0,
42
+ "op_time_backward": 2.0,
43
+ "strategy": "1f1b_interleave",
44
+ "split_backward": False,
45
+ "placement_strategy": "interleave"
46
+ }
47
+
48
+ app.layout = html.Div([
49
+ html.H1("Pipeline Parallelism Schedule Visualizer", style={'textAlign': 'center'}),
50
+
51
+ html.Div([
52
+ html.Div([
53
+ html.Label("Number of Devices (GPUs):"),
54
+ dcc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, style={'width': '100%'}),
55
+
56
+ html.Label("Number of Stages (Model Chunks):"),
57
+ dcc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, style={'width': '100%'}),
58
+
59
+ html.Label("Number of Microbatches:"),
60
+ dcc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, style={'width': '100%'}),
61
+
62
+ html.Label("P2P Latency (ms):"),
63
+ dcc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, style={'width': '100%'}),
64
+
65
+ ], style={'padding': 10, 'flex': 1}),
66
+
67
+ html.Div([
68
+ html.Label("Scheduling Strategy:"),
69
+ dcc.Dropdown(
70
+ id='strategy',
71
+ options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
72
+ value=default_values["strategy"],
73
+ clearable=False,
74
+ style={'width': '100%'}
75
+ ),
76
+
77
+ html.Label("Placement Strategy:"),
78
+ dcc.Dropdown(
79
+ id='placement_strategy',
80
+ options=[
81
+ {'label': 'Standard', 'value': 'standard'},
82
+ {'label': 'Interleave', 'value': 'interleave'},
83
+ {'label': 'DualPipe', 'value': 'dualpipe'}
84
+ ],
85
+ value=default_values["placement_strategy"],
86
+ clearable=False,
87
+ style={'width': '100%'}
88
+ ),
89
+
90
+ html.Div([ # Wrap checkbox and label
91
+ dcc.Checklist(
92
+ id='split_backward',
93
+ options=[{'label': ' Split Backward Pass (for ZB-1P, DualPipe)', 'value': 'True'}],
94
+ value=['True'] if default_values["split_backward"] else [],
95
+ style={'display': 'inline-block'}
96
+ ),
97
+ ], style={'marginTop': '20px'}),
98
+
99
+ ], style={'padding': 10, 'flex': 1}),
100
+
101
+ html.Div([
102
+ html.Label("Operation Time - Forward (ms):"),
103
+ dcc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, style={'width': '100%'}),
104
+
105
+ html.Label("Operation Time - Backward (ms):"),
106
+ dcc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01, style={'width': '100%'}),
107
+
108
+ html.Label("Operation Time - Backward D (Data Grad) (ms):"),
109
+ dcc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01, style={'width': '100%'}),
110
+
111
+ html.Label("Operation Time - Backward W (Weight Grad) (ms):"),
112
+ dcc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01, style={'width': '100%'}),
113
+ ], style={'padding': 10, 'flex': 1}),
114
+
115
+ ], style={'display': 'flex', 'flexDirection': 'row'}),
116
+
117
+ html.Div([
118
+ html.Button('Generate Schedule', id='generate-button', n_clicks=0, style={'margin': '20px auto', 'display': 'block'}),
119
+ ]),
120
+
121
+ html.Div(id='error-message', style={'color': 'red', 'textAlign': 'center', 'marginTop': '10px'}),
122
+
123
+ dcc.Loading(
124
+ id="loading-graph",
125
+ type="circle",
126
+ children=dcc.Graph(id='pipeline-graph', figure=go.Figure())
127
+ )
128
+ ])
129
+
130
+ @app.callback(
131
+ Output('pipeline-graph', 'figure'),
132
+ Output('error-message', 'children'),
133
+ Input('generate-button', 'n_clicks'),
134
+ State('num_devices', 'value'),
135
+ State('num_stages', 'value'),
136
+ State('num_batches', 'value'),
137
+ State('p2p_latency', 'value'),
138
+ State('op_time_forward', 'value'),
139
+ State('op_time_backward', 'value'),
140
+ State('op_time_backward_d', 'value'),
141
+ State('op_time_backward_w', 'value'),
142
+ State('strategy', 'value'),
143
+ State('split_backward', 'value'),
144
+ State('placement_strategy', 'value'),
145
+ prevent_initial_call=True
146
+ )
147
+ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
148
+ op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
149
+ strategy, split_backward_list, placement_strategy):
150
+
151
+ error_message = ""
152
+ fig = go.Figure()
153
+
154
+ split_backward = 'True' in split_backward_list
155
+
156
+ # Basic Validations
157
+ if not all([num_devices, num_stages, num_batches, op_time_forward]):
158
+ return fig, "Missing required input values."
159
+ if split_backward and not all([op_time_backward_d, op_time_backward_w]):
160
+ return fig, "Backward D and Backward W times are required when 'Split Backward' is checked."
161
+ if not split_backward and not op_time_backward:
162
+ return fig, "Backward time is required when 'Split Backward' is unchecked."
163
+ if num_stages % num_devices != 0 and placement_strategy != 'dualpipe':
164
+ return fig, "Number of Stages must be divisible by Number of Devices for standard/interleave placement."
165
+ if placement_strategy == 'dualpipe' and num_stages % 2 != 0:
166
+ return fig, "DualPipe requires an even number of stages."
167
+ if placement_strategy == 'dualpipe' and num_stages != num_devices:
168
+ return fig, "DualPipe requires Number of Stages to be equal to Number of Devices."
169
+ if strategy == 'dualpipe' and not split_backward:
170
+ return fig, "DualPipe strategy currently requires 'Split Backward' to be checked."
171
+ if strategy == 'dualpipe' and placement_strategy != 'dualpipe':
172
+ return fig, "DualPipe strategy requires 'DualPipe' placement strategy."
173
+ if strategy == 'zb1p' and not split_backward:
174
+ return fig, "ZB-1P strategy requires 'Split Backward' to be checked."
175
+
176
+ try:
177
+ op_times = {
178
+ "forward": float(op_time_forward),
179
+ }
180
+ if split_backward:
181
+ op_times["backward_D"] = float(op_time_backward_d)
182
+ op_times["backward_W"] = float(op_time_backward_w)
183
+ # Add combined backward time for compatibility if needed by some visualization or calculation
184
+ op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
185
+ else:
186
+ op_times["backward"] = float(op_time_backward)
187
+
188
+ config = ScheduleConfig(
189
+ num_devices=int(num_devices),
190
+ num_stages=int(num_stages),
191
+ num_batches=int(num_batches),
192
+ p2p_latency=float(p2p_latency),
193
+ placement_strategy=placement_strategy,
194
+ split_backward=split_backward,
195
+ op_times=op_times,
196
+ )
197
+
198
+ schedule_func = STRATEGIES.get(strategy)
199
+ if not schedule_func:
200
+ raise ValueError(f"Invalid strategy selected: {strategy}")
201
+
202
+ schedule = schedule_func(config)
203
+ schedule.execute() # Calculate start/end times
204
+
205
+ vis_data = convert_schedule_to_visualization_format(schedule)
206
+ fig = create_pipeline_figure(vis_data, show_progress=False) # Disable progress bar in server mode
207
+
208
+ except AssertionError as e:
209
+ error_message = f"Configuration Error: {e}"
210
+ fig = go.Figure() # Return empty figure on error
211
+ except ValueError as e:
212
+ error_message = f"Input Error: {e}"
213
+ fig = go.Figure()
214
+ except Exception as e:
215
+ error_message = f"An unexpected error occurred: {e}"
216
+ fig = go.Figure()
217
+
218
+ return fig, error_message
219
+
220
+ if __name__ == '__main__':
221
+ port = 8050
222
+ # Timer(1, open_browser, args=(port,)).start() # Optional: automatically open browser
223
+ print(f"Dash server running on http://127.0.0.1:{port}")
224
+ app.run_server(debug=True, port=port)