Victarry commited on
Commit
912972f
·
1 Parent(s): 5126943

Improve UI components.

Browse files
Files changed (2) hide show
  1. pyproject.toml +2 -1
  2. src/server.py +165 -141
pyproject.toml CHANGED
@@ -24,6 +24,7 @@ dependencies = [
24
  "pandas>=2.1.0",
25
  "numpy>=1.26.0",
26
  "tqdm>=4.67.0",
 
27
  ]
28
 
29
  [project.optional-dependencies]
@@ -64,4 +65,4 @@ disallow_incomplete_defs = true
64
 
65
  [tool.pytest]
66
  testpaths = ["tests"]
67
- pythonpath = ["."]
 
24
  "pandas>=2.1.0",
25
  "numpy>=1.26.0",
26
  "tqdm>=4.67.0",
27
+ "dash-bootstrap-components>=1.7.1",
28
  ]
29
 
30
  [project.optional-dependencies]
 
65
 
66
  [tool.pytest]
67
  testpaths = ["tests"]
68
+ pythonpath = ["."]
src/server.py CHANGED
@@ -1,4 +1,5 @@
1
  import dash
 
2
  from dash import dcc, html, Input, Output, State, callback_context
3
  import plotly.graph_objects as go
4
  import webbrowser
@@ -27,8 +28,8 @@ STRATEGIES = {
27
  "dualpipe": generate_dualpipe_schedule,
28
  }
29
 
30
- app = dash.Dash(__name__, suppress_callback_exceptions=True)
31
- app.title = "Pipeline Parallelism Visualizer"
32
 
33
  # Initial default values
34
  default_values = {
@@ -45,91 +46,98 @@ default_values = {
45
  "placement_strategy": "interleave"
46
  }
47
 
48
- app.layout = html.Div([
49
- html.H1("Pipeline Parallelism Schedule Visualizer", style={'textAlign': 'center'}),
50
-
51
- html.Div([
52
  html.Div([
53
- html.Label("Number of Devices (GPUs):"),
54
- dcc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, style={'width': '100%'}),
55
-
56
- html.Label("Number of Stages (Model Chunks):"),
57
- dcc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, style={'width': '100%'}),
58
-
59
- html.Label("Number of Microbatches:"),
60
- dcc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, style={'width': '100%'}),
61
-
62
- html.Label("P2P Latency (ms):"),
63
- dcc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, style={'width': '100%'}),
64
-
65
- ], style={'padding': 10, 'flex': 1}),
 
 
 
 
66
 
 
 
 
67
  html.Div([
68
- html.Label("Scheduling Strategy:"),
69
- dcc.Dropdown(
70
- id='strategy',
71
  options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
72
- value=default_values["strategy"],
73
- clearable=False,
74
- style={'width': '100%'}
75
- ),
76
-
77
- html.Label("Placement Strategy:"),
78
- dcc.Dropdown(
79
- id='placement_strategy',
80
- options=[
81
- {'label': 'Standard', 'value': 'standard'},
82
- {'label': 'Interleave', 'value': 'interleave'},
83
- {'label': 'DualPipe', 'value': 'dualpipe'}
84
- ],
85
- value=default_values["placement_strategy"],
86
- clearable=False,
87
- style={'width': '100%'}
88
  ),
 
 
 
89
 
90
- html.Div([ # Wrap checkbox and label
91
- dcc.Checklist(
92
- id='split_backward',
93
- options=[{'label': ' Split Backward Pass (for ZB-1P, DualPipe)', 'value': 'True'}],
94
- value=['True'] if default_values["split_backward"] else [],
95
- style={'display': 'inline-block'}
96
- ),
97
- ], style={'marginTop': '20px'}),
98
-
99
- ], style={'padding': 10, 'flex': 1}),
100
-
101
  html.Div([
102
- html.Label("Operation Time - Forward (ms):"),
103
- dcc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, style={'width': '100%'}),
104
-
105
- html.Label("Operation Time - Backward (ms):"),
106
- dcc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01, style={'width': '100%'}),
107
-
108
- html.Label("Operation Time - Backward D (Data Grad) (ms):"),
109
- dcc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01, style={'width': '100%'}),
110
-
111
- html.Label("Operation Time - Backward W (Weight Grad) (ms):"),
112
- dcc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01, style={'width': '100%'}),
113
- ], style={'padding': 10, 'flex': 1}),
 
 
 
 
 
 
 
 
114
 
115
- ], style={'display': 'flex', 'flexDirection': 'row'}),
 
 
116
 
117
- html.Div([
118
- html.Button('Generate Schedule', id='generate-button', n_clicks=0, style={'margin': '20px auto', 'display': 'block'}),
 
 
119
  ]),
120
 
121
- html.Div(id='error-message', style={'color': 'red', 'textAlign': 'center', 'marginTop': '10px'}),
 
 
 
 
122
 
123
- dcc.Loading(
124
- id="loading-graph",
125
- type="circle",
126
- children=dcc.Graph(id='pipeline-graph', figure=go.Figure())
127
- )
128
- ])
 
 
 
 
129
 
130
  @app.callback(
131
- Output('pipeline-graph', 'figure'),
132
- Output('error-message', 'children'),
133
  Input('generate-button', 'n_clicks'),
134
  State('num_devices', 'value'),
135
  State('num_stages', 'value'),
@@ -139,83 +147,99 @@ app.layout = html.Div([
139
  State('op_time_backward', 'value'),
140
  State('op_time_backward_d', 'value'),
141
  State('op_time_backward_w', 'value'),
142
- State('strategy', 'value'),
143
- State('split_backward', 'value'),
144
- State('placement_strategy', 'value'),
145
  prevent_initial_call=True
146
  )
147
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
148
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
149
- strategy, split_backward_list, placement_strategy):
150
 
151
- error_message = ""
152
- fig = go.Figure()
153
 
154
- split_backward = 'True' in split_backward_list
 
155
 
156
- # Basic Validations
157
  if not all([num_devices, num_stages, num_batches, op_time_forward]):
158
- return fig, "Missing required input values."
159
- if split_backward and not all([op_time_backward_d, op_time_backward_w]):
160
- return fig, "Backward D and Backward W times are required when 'Split Backward' is checked."
161
- if not split_backward and not op_time_backward:
162
- return fig, "Backward time is required when 'Split Backward' is unchecked."
163
- if num_stages % num_devices != 0 and placement_strategy != 'dualpipe':
164
- return fig, "Number of Stages must be divisible by Number of Devices for standard/interleave placement."
165
- if placement_strategy == 'dualpipe' and num_stages % 2 != 0:
166
- return fig, "DualPipe requires an even number of stages."
167
- if placement_strategy == 'dualpipe' and num_stages != num_devices:
168
- return fig, "DualPipe requires Number of Stages to be equal to Number of Devices."
169
- if strategy == 'dualpipe' and not split_backward:
170
- return fig, "DualPipe strategy currently requires 'Split Backward' to be checked."
171
- if strategy == 'dualpipe' and placement_strategy != 'dualpipe':
172
- return fig, "DualPipe strategy requires 'DualPipe' placement strategy."
173
- if strategy == 'zb1p' and not split_backward:
174
- return fig, "ZB-1P strategy requires 'Split Backward' to be checked."
175
-
176
- try:
177
- op_times = {
178
- "forward": float(op_time_forward),
179
- }
180
- if split_backward:
181
- op_times["backward_D"] = float(op_time_backward_d)
182
- op_times["backward_W"] = float(op_time_backward_w)
183
- # Add combined backward time for compatibility if needed by some visualization or calculation
184
- op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
185
- else:
186
- op_times["backward"] = float(op_time_backward)
187
-
188
- config = ScheduleConfig(
189
- num_devices=int(num_devices),
190
- num_stages=int(num_stages),
191
- num_batches=int(num_batches),
192
- p2p_latency=float(p2p_latency),
193
- placement_strategy=placement_strategy,
194
- split_backward=split_backward,
195
- op_times=op_times,
196
- )
197
-
198
- schedule_func = STRATEGIES.get(strategy)
199
- if not schedule_func:
200
- raise ValueError(f"Invalid strategy selected: {strategy}")
201
-
202
- schedule = schedule_func(config)
203
- schedule.execute() # Calculate start/end times
204
-
205
- vis_data = convert_schedule_to_visualization_format(schedule)
206
- fig = create_pipeline_figure(vis_data, show_progress=False) # Disable progress bar in server mode
207
-
208
- except AssertionError as e:
209
- error_message = f"Configuration Error: {e}"
210
- fig = go.Figure() # Return empty figure on error
211
- except ValueError as e:
212
- error_message = f"Input Error: {e}"
213
- fig = go.Figure()
214
- except Exception as e:
215
- error_message = f"An unexpected error occurred: {e}"
216
- fig = go.Figure()
217
 
218
- return fig, error_message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  if __name__ == '__main__':
221
  port = 8050
 
1
  import dash
2
+ import dash_bootstrap_components as dbc
3
  from dash import dcc, html, Input, Output, State, callback_context
4
  import plotly.graph_objects as go
5
  import webbrowser
 
28
  "dualpipe": generate_dualpipe_schedule,
29
  }
30
 
31
+ app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
32
+ app.title = "Pipeline Parallelism Schedule Visualizer"
33
 
34
  # Initial default values
35
  default_values = {
 
46
  "placement_strategy": "interleave"
47
  }
48
 
49
+ # Define input groups using dbc components
50
+ basic_params_card = dbc.Card(
51
+ dbc.CardBody([
52
+ html.H5("Basic Parameters", className="card-title"),
53
  html.Div([
54
+ dbc.Label("Number of Devices (GPUs):"),
55
+ dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1),
56
+ ], className="mb-3"),
57
+ html.Div([
58
+ dbc.Label("Number of Stages (Model Chunks):"),
59
+ dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1),
60
+ ], className="mb-3"),
61
+ html.Div([
62
+ dbc.Label("Number of Microbatches:"),
63
+ dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1),
64
+ ], className="mb-3"),
65
+ html.Div([
66
+ dbc.Label("P2P Latency (ms):"),
67
+ dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01),
68
+ ], className="mb-3"),
69
+ ])
70
+ )
71
 
72
+ scheduling_params_card = dbc.Card(
73
+ dbc.CardBody([
74
+ html.H5("Scheduling Parameters", className="card-title"),
75
  html.Div([
76
+ dbc.Label("Scheduling Strategies:"),
77
+ dbc.Checklist(
78
+ id='strategy-checklist',
79
  options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
80
+ value=[default_values["strategy"]],
81
+ inline=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  ),
83
+ ], className="mb-3"),
84
+ ])
85
+ )
86
 
87
+ timing_params_card = dbc.Card(
88
+ dbc.CardBody([
89
+ html.H5("Operation Timing (ms)", className="card-title"),
 
 
 
 
 
 
 
 
90
  html.Div([
91
+ dbc.Label("Forward:"),
92
+ dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01),
93
+ ], className="mb-3"),
94
+ html.Div([
95
+ dbc.Label("Backward (Combined):"),
96
+ dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
97
+ dbc.FormText("Used when strategy does NOT require split backward."),
98
+ ], className="mb-3"),
99
+ html.Div([
100
+ dbc.Label("Backward D (Data Grad):"),
101
+ dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
102
+ dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
103
+ ], className="mb-3"),
104
+ html.Div([
105
+ dbc.Label("Backward W (Weight Grad):"),
106
+ dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
107
+ dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
108
+ ], className="mb-3"),
109
+ ])
110
+ )
111
 
112
+ # Updated app layout using dbc components and structure
113
+ app.layout = dbc.Container([
114
+ html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
115
 
116
+ dbc.Row([
117
+ dbc.Col(basic_params_card, md=4),
118
+ dbc.Col(scheduling_params_card, md=4),
119
+ dbc.Col(timing_params_card, md=4),
120
  ]),
121
 
122
+ dbc.Row([
123
+ dbc.Col([
124
+ dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4"),
125
+ ], className="text-center")
126
+ ]),
127
 
128
+ dbc.Row([
129
+ dbc.Col([
130
+ dcc.Loading(
131
+ id="loading-graph-area",
132
+ type="circle",
133
+ children=html.Div(id='graph-output-container', className="mt-4")
134
+ )
135
+ ])
136
+ ])
137
+ ], fluid=True)
138
 
139
  @app.callback(
140
+ Output('graph-output-container', 'children'),
 
141
  Input('generate-button', 'n_clicks'),
142
  State('num_devices', 'value'),
143
  State('num_stages', 'value'),
 
147
  State('op_time_backward', 'value'),
148
  State('op_time_backward_d', 'value'),
149
  State('op_time_backward_w', 'value'),
150
+ State('strategy-checklist', 'value'),
 
 
151
  prevent_initial_call=True
152
  )
153
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
154
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
155
+ selected_strategies):
156
 
157
+ output_components = []
 
158
 
159
+ if not selected_strategies:
160
+ return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
161
 
 
162
  if not all([num_devices, num_stages, num_batches, op_time_forward]):
163
+ return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ for strategy in selected_strategies:
166
+ error_message = ""
167
+ fig = go.Figure()
168
+ placement_strategy = ""
169
+
170
+ split_backward = strategy in ["zb1p", "dualpipe"]
171
+
172
+ if split_backward and not all([op_time_backward_d, op_time_backward_w]):
173
+ error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
174
+ elif not split_backward and not op_time_backward:
175
+ error_message = f"Strategy '{strategy}': Combined Backward time is required."
176
+
177
+ if not error_message:
178
+ if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
179
+ placement_strategy = "standard"
180
+ if num_devices != num_stages:
181
+ error_message = f"Strategy '{strategy}': Requires Number of Stages == Number of Devices."
182
+ elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
183
+ placement_strategy = "interleave"
184
+ if num_stages % num_devices != 0:
185
+ error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
186
+ elif strategy == "dualpipe":
187
+ placement_strategy = "dualpipe"
188
+ if num_stages % 2 != 0:
189
+ error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
190
+ elif num_stages != num_devices:
191
+ error_message = f"Strategy '{strategy}' (DualPipe): Requires Number of Stages == Number of Devices."
192
+
193
+ if not error_message:
194
+ try:
195
+ op_times = { "forward": float(op_time_forward) }
196
+ if split_backward:
197
+ op_times["backward_D"] = float(op_time_backward_d)
198
+ op_times["backward_W"] = float(op_time_backward_w)
199
+ op_times["backward"] = float(op_time_backward_d) + float(op_time_backward_w)
200
+ else:
201
+ op_times["backward"] = float(op_time_backward)
202
+
203
+ config = ScheduleConfig(
204
+ num_devices=int(num_devices),
205
+ num_stages=int(num_stages),
206
+ num_batches=int(num_batches),
207
+ p2p_latency=float(p2p_latency),
208
+ placement_strategy=placement_strategy,
209
+ split_backward=split_backward,
210
+ op_times=op_times,
211
+ )
212
+
213
+ schedule_func = STRATEGIES.get(strategy)
214
+ if not schedule_func:
215
+ raise ValueError(f"Invalid strategy function for: {strategy}")
216
+
217
+ schedule = schedule_func(config)
218
+ schedule.execute()
219
+
220
+ vis_data = convert_schedule_to_visualization_format(schedule)
221
+ fig = create_pipeline_figure(vis_data, show_progress=False)
222
+
223
+ output_components.append(html.Div([
224
+ html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
225
+ dcc.Graph(figure=fig)
226
+ ]))
227
+
228
+ except (AssertionError, ValueError, TypeError) as e:
229
+ error_message = f"Error generating schedule for '{strategy}': {e}"
230
+ import traceback
231
+ traceback.print_exc()
232
+ except Exception as e:
233
+ error_message = f"An unexpected error occurred for '{strategy}': {e}"
234
+ import traceback
235
+ traceback.print_exc()
236
+
237
+ if error_message:
238
+ output_components.append(
239
+ dbc.Alert(error_message, color="danger", className="mt-3")
240
+ )
241
+
242
+ return output_components
243
 
244
  if __name__ == '__main__':
245
  port = 8050