Victarry committed on
Commit
5b28831
·
1 Parent(s): 2ae9b28

Improve visualizer performance.

Browse files
Files changed (1) hide show
  1. src/visualizer.py +143 -87
src/visualizer.py CHANGED
@@ -2,10 +2,9 @@ import dash
2
  from dash import dcc, html
3
  from dash.dependencies import Input, Output
4
  import plotly.graph_objects as go
5
- import argparse
6
- from typing import List, Dict, Literal, Optional
7
  from tqdm import tqdm
8
- import base64
9
 
10
  from src.execution_model import Schedule
11
 
@@ -40,6 +39,49 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
40
  return visualization_data
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
44
  """
45
  Create a Plotly figure for pipeline parallelism scheduling.
@@ -51,49 +93,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
51
  """
52
  # Find the number of devices
53
  num_devices = len(schedule_data)
54
-
55
  empty_color = "whitesmoke"
56
- # Colors for task types
57
- def get_color(op_type: str, stage_id: int):
58
- # Color palettes for different virtual stages
59
- forward_colors = [
60
- "royalblue", # Stage 0
61
- "lightskyblue", # Stage 1
62
- "cornflowerblue", # Stage 2
63
- "steelblue", # Stage 3
64
- "dodgerblue", # Stage 4
65
- "deepskyblue", # Stage 5
66
- "mediumblue", # Stage 6
67
- "mediumslateblue",# Stage 7
68
- "slateblue", # Stage 8
69
- "darkslateblue" # Stage 9
70
- ]
71
-
72
- backward_colors = [
73
- "lightgreen", # Stage 0
74
- "mediumseagreen", # Stage 1
75
- "seagreen", # Stage 2
76
- "lightseagreen", # Stage 3
77
- "mediumaquamarine", # Stage 4
78
- "mediumspringgreen", # Stage 5
79
- "springgreen", # Stage 6
80
- "palegreen", # Stage 7
81
- "limegreen", # Stage 8
82
- "forestgreen" # Stage 9
83
- ]
84
-
85
- virtual_stage = stage_id // num_devices
86
-
87
- # If virtual_stage is beyond our color list, cycle through the colors
88
- color_index = virtual_stage % len(forward_colors)
89
-
90
- if op_type == "forward":
91
- return forward_colors[color_index]
92
- elif op_type == "backward":
93
- return backward_colors[color_index]
94
- else:
95
- raise ValueError(f"Invalid operation type: {op_type}")
96
-
97
  # Find the maximum time in the schedule if not provided
98
  if max_time is None:
99
  max_time = 0
@@ -116,6 +118,11 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
116
  # Create a custom y-axis with no gaps between devices
117
  y_spacing = 1.0 # Use 1.0 for no gaps
118
 
 
 
 
 
 
119
  # Add rectangles for each task
120
  for device_idx, device in enumerate(schedule_data):
121
  device_idx_reversed = num_devices - device_idx - 1
@@ -126,11 +133,11 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
126
  for task in sorted_tasks:
127
  # Determine task color and text color
128
  if task["type"] == "forward":
129
- color = get_color(task["type"], task["stage"])
130
  text_color = "white"
131
  name = "Forward"
132
  elif task["type"] == "backward":
133
- color = get_color(task["type"], task["stage"])
134
  text_color = "black"
135
  name = "Backward"
136
  else:
@@ -145,8 +152,8 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
145
  # Calculate y positions with no gaps
146
  y_pos = device_idx_reversed * y_spacing
147
 
148
- # Create rectangle using shape
149
- fig.add_shape(
150
  type="rect",
151
  x0=start_time,
152
  y0=y_pos - 0.5,
@@ -155,25 +162,27 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
155
  line=dict(color="black", width=0.5),
156
  fillcolor=color,
157
  layer="above",
158
- )
159
 
160
- # Add batch number text
161
- fig.add_annotation(
162
  x=start_time + duration / 2,
163
  y=y_pos,
164
- text=f"{task['batch']}", # Only show batch ID
165
  showarrow=False,
166
- font=dict(color=text_color, size=12, family="Arial, bold"), # Increased font size
167
- )
168
 
169
- # Add hover data with additional details
170
- fig.add_trace(go.Scatter(
 
 
171
  x=[start_time + duration / 2],
172
  y=[y_pos],
173
  mode='markers',
174
  marker=dict(opacity=0), # Invisible marker
175
  hoverinfo='text',
176
- text=f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}",
177
  showlegend=False
178
  ))
179
 
@@ -182,6 +191,16 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
182
  tasks_processed += 1
183
  progress_bar.update(1)
184
 
 
 
 
 
 
 
 
 
 
 
185
  # Add custom legend
186
  legend_items = []
187
 
@@ -196,18 +215,18 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
196
  for vs in range(max_virtual_stage + 1):
197
  legend_items.append(dict(
198
  name=f"Forward (VS {vs})",
199
- color=get_color("forward", vs * num_devices)
200
  ))
201
  legend_items.append(dict(
202
  name=f"Backward (VS {vs})",
203
- color=get_color("backward", vs * num_devices)
204
  ))
205
 
206
  # If no tasks found, add default legend items
207
  if not legend_items:
208
  legend_items = [
209
- dict(name="Forward (VS 0)", color=get_color("forward", 0)),
210
- dict(name="Backward (VS 0)", color=get_color("backward", 0)),
211
  ]
212
 
213
  for i, item in enumerate(legend_items):
@@ -232,6 +251,8 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
232
  # Adjust the range to ensure there are no empty spaces at the end
233
  x_end = max_time * 1.05 # Add a small margin
234
 
 
 
235
  fig.update_layout(
236
  yaxis=dict(
237
  tickmode="array",
@@ -243,7 +264,7 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
243
  margin=dict(l=50, r=20, t=40, b=40),
244
  plot_bgcolor="white",
245
  title=dict(
246
- text="Pipeline Parallelism Schedule",
247
  x=0.5,
248
  y=0.98, # Move title position closer to the top
249
  font=dict(size=20)
@@ -271,51 +292,84 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
271
  return fig
272
 
273
 
274
- def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
 
 
 
275
  """
276
  Create a Dash app to visualize the pipeline schedule.
277
 
278
  Args:
279
  schedule: Schedule object to visualize
280
- schedule_type: Type of schedule ("1f1b" or other)
 
281
  """
282
- # Convert schedule to visualization format
283
- schedule_data = convert_schedule_to_visualization_format(schedule)
 
284
 
285
- # Create the app
286
- app = dash.Dash(__name__, title=f"Pipeline Parallelism Visualizer - {schedule_type}")
 
 
 
 
 
 
287
 
 
 
 
 
 
 
 
288
  app.layout = html.Div([
289
- html.H1(f"Pipeline Parallelism Visualizer - {schedule_type}", style={'textAlign': 'center'}),
290
 
291
  html.Div([
292
- html.Div([
293
- html.H3("Schedule Configuration:"),
294
- html.Ul([
295
- html.Li(f"Number of devices: {schedule.config.num_devices}"),
296
- html.Li(f"Number of stages: {schedule.config.num_stages}"),
297
- html.Li(f"Number of batches: {schedule.config.num_batches}"),
298
- ]),
299
- ], className="config-section"),
300
-
301
- ], style={'margin': '20px'}),
302
 
303
  html.Div(id="graph-container", children=[]),
304
 
305
- dcc.Graph(
306
- id="pipeline-graph",
307
- config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
 
 
 
 
 
 
308
  ),
309
  ])
310
 
 
 
 
311
  @app.callback(
312
  Output("pipeline-graph", "figure"),
313
  Input("graph-container", "children"),
314
  prevent_initial_call=False,
315
  )
316
  def load_graph(_):
317
- # Create the figure when the app loads
318
- return create_pipeline_figure(schedule_data, show_progress=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  return app
321
 
@@ -323,7 +377,8 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
323
  def visualize_pipeline_parallelism_dash(
324
  schedule: Schedule,
325
  port: int = 8050,
326
- debug: bool = False
 
327
  ):
328
  """
329
  Launch a Dash app to visualize the pipeline schedule interactively.
@@ -332,7 +387,8 @@ def visualize_pipeline_parallelism_dash(
332
  schedule: Schedule object to visualize
333
  port: Port to run the Dash app on
334
  debug: Whether to run the Dash app in debug mode
 
335
  """
336
- app = create_dash_app(schedule)
337
  print(f"Starting Dash app on http://localhost:{port}/")
338
  app.run_server(debug=debug, port=port)
 
2
  from dash import dcc, html
3
  from dash.dependencies import Input, Output
4
  import plotly.graph_objects as go
5
+ from typing import List, Dict
 
6
  from tqdm import tqdm
7
+ from functools import lru_cache
8
 
9
  from src.execution_model import Schedule
10
 
 
39
  return visualization_data
40
 
41
 
42
# Cache the color lookup: it is called once per task, but only with a small
# number of distinct (op_type, stage_id, num_devices) combinations.
@lru_cache(maxsize=128)
def get_color(op_type: str, stage_id: int, num_devices: int):
    """
    Return the fill color for a task of the given type and stage.

    The virtual stage is computed as ``stage_id // num_devices`` and selects
    an entry from a ten-color palette (blues for forward, greens for
    backward). Virtual stages beyond the palette length wrap around.

    Args:
        op_type: Operation type, either "forward" or "backward"
        stage_id: Stage index of the task
        num_devices: Number of devices in the schedule; a non-positive value
            is treated as a single virtual stage (index 0)

    Returns:
        A CSS color name string.

    Raises:
        ValueError: If op_type is neither "forward" nor "backward"
    """
    # Color palettes indexed by virtual stage (tuples: they are never mutated)
    forward_colors = (
        "royalblue",        # Stage 0
        "lightskyblue",     # Stage 1
        "cornflowerblue",   # Stage 2
        "steelblue",        # Stage 3
        "dodgerblue",       # Stage 4
        "deepskyblue",      # Stage 5
        "mediumblue",       # Stage 6
        "mediumslateblue",  # Stage 7
        "slateblue",        # Stage 8
        "darkslateblue",    # Stage 9
    )

    backward_colors = (
        "lightgreen",         # Stage 0
        "mediumseagreen",     # Stage 1
        "seagreen",           # Stage 2
        "lightseagreen",      # Stage 3
        "mediumaquamarine",   # Stage 4
        "mediumspringgreen",  # Stage 5
        "springgreen",        # Stage 6
        "palegreen",          # Stage 7
        "limegreen",          # Stage 8
        "forestgreen",        # Stage 9
    )

    # Validate the operation type first so a bad type is always reported as
    # ValueError, regardless of the other arguments.
    if op_type == "forward":
        palette = forward_colors
    elif op_type == "backward":
        palette = backward_colors
    else:
        raise ValueError(f"Invalid operation type: {op_type}")

    # An empty schedule yields num_devices == 0; fall back to virtual stage 0
    # instead of dividing by zero when building the default legend.
    virtual_stage = stage_id // num_devices if num_devices > 0 else 0

    # If virtual_stage is beyond our color list, cycle through the colors
    return palette[virtual_stage % len(palette)]
83
+
84
+
85
  def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
86
  """
87
  Create a Plotly figure for pipeline parallelism scheduling.
 
93
  """
94
  # Find the number of devices
95
  num_devices = len(schedule_data)
96
+
97
  empty_color = "whitesmoke"
98
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Find the maximum time in the schedule if not provided
100
  if max_time is None:
101
  max_time = 0
 
118
  # Create a custom y-axis with no gaps between devices
119
  y_spacing = 1.0 # Use 1.0 for no gaps
120
 
121
+ # Batch processing for increased performance
122
+ shapes = []
123
+ annotations = []
124
+ hover_traces = []
125
+
126
  # Add rectangles for each task
127
  for device_idx, device in enumerate(schedule_data):
128
  device_idx_reversed = num_devices - device_idx - 1
 
133
  for task in sorted_tasks:
134
  # Determine task color and text color
135
  if task["type"] == "forward":
136
+ color = get_color(task["type"], task["stage"], num_devices)
137
  text_color = "white"
138
  name = "Forward"
139
  elif task["type"] == "backward":
140
+ color = get_color(task["type"], task["stage"], num_devices)
141
  text_color = "black"
142
  name = "Backward"
143
  else:
 
152
  # Calculate y positions with no gaps
153
  y_pos = device_idx_reversed * y_spacing
154
 
155
+ # Create rectangle using shape (batch-add later)
156
+ shapes.append(dict(
157
  type="rect",
158
  x0=start_time,
159
  y0=y_pos - 0.5,
 
162
  line=dict(color="black", width=0.5),
163
  fillcolor=color,
164
  layer="above",
165
+ ))
166
 
167
+ # Add batch number text (batch-add later)
168
+ annotations.append(dict(
169
  x=start_time + duration / 2,
170
  y=y_pos,
171
+ text=f"{task['batch']}",
172
  showarrow=False,
173
+ font=dict(color=text_color, size=12, family="Arial, bold"),
174
+ ))
175
 
176
+ # Prepare hover data (add traces in batches later)
177
+ hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
178
+
179
+ hover_traces.append(dict(
180
  x=[start_time + duration / 2],
181
  y=[y_pos],
182
  mode='markers',
183
  marker=dict(opacity=0), # Invisible marker
184
  hoverinfo='text',
185
+ text=hover_text,
186
  showlegend=False
187
  ))
188
 
 
191
  tasks_processed += 1
192
  progress_bar.update(1)
193
 
194
+ # Add all shapes at once for better performance
195
+ fig.update_layout(shapes=shapes)
196
+
197
+ # Add all annotations at once
198
+ fig.update_layout(annotations=annotations)
199
+
200
+ # Add all hover traces at once
201
+ for trace in hover_traces:
202
+ fig.add_trace(go.Scatter(**trace))
203
+
204
  # Add custom legend
205
  legend_items = []
206
 
 
215
  for vs in range(max_virtual_stage + 1):
216
  legend_items.append(dict(
217
  name=f"Forward (VS {vs})",
218
+ color=get_color("forward", vs * num_devices, num_devices)
219
  ))
220
  legend_items.append(dict(
221
  name=f"Backward (VS {vs})",
222
+ color=get_color("backward", vs * num_devices, num_devices)
223
  ))
224
 
225
  # If no tasks found, add default legend items
226
  if not legend_items:
227
  legend_items = [
228
+ dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
229
+ dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
230
  ]
231
 
232
  for i, item in enumerate(legend_items):
 
251
  # Adjust the range to ensure there are no empty spaces at the end
252
  x_end = max_time * 1.05 # Add a small margin
253
 
254
+ title_text = "Pipeline Parallelism Schedule"
255
+
256
  fig.update_layout(
257
  yaxis=dict(
258
  tickmode="array",
 
264
  margin=dict(l=50, r=20, t=40, b=40),
265
  plot_bgcolor="white",
266
  title=dict(
267
+ text=title_text,
268
  x=0.5,
269
  y=0.98, # Move title position closer to the top
270
  font=dict(size=20)
 
292
  return fig
293
 
294
 
295
# Cache of processed schedule data, keyed by id(schedule). Each entry is a
# (schedule, schedule_data) tuple: keeping a reference to the schedule
# prevents its id from being recycled for a different object while the
# entry exists, and lets lookups verify identity.
_schedule_data_cache = {}


def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True):
    """
    Create a Dash app to visualize the pipeline schedule.

    Args:
        schedule: Schedule object to visualize
        schedule_type: Type of schedule ("1f1b" or custom description)
        enable_caching: Whether to cache the schedule data and figure. Cached
            data is reused for the same schedule object on later calls; pass
            False to force reprocessing.

    Returns:
        The configured Dash app (the server is not started here).
    """
    # Process schedule data only once per schedule object and cache it
    cache_key = id(schedule)
    cached_entry = _schedule_data_cache.get(cache_key) if enable_caching else None

    if cached_entry is not None and cached_entry[0] is schedule:
        schedule_data = cached_entry[1]
        print("Using cached schedule data")
    else:
        schedule_data = convert_schedule_to_visualization_format(schedule)
        if enable_caching:
            _schedule_data_cache[cache_key] = (schedule, schedule_data)
            print("Cached schedule data")

    total_tasks = sum(len(tasks) for tasks in schedule_data.values())
    print(f"Total tasks in schedule: {total_tasks}")

    app = dash.Dash(__name__)
    app.title = f"Pipeline Parallelism Visualization - {schedule_type}"

    # Layout: heading, a short summary of the data size, and the graph
    # wrapped in a loading indicator
    app.layout = html.Div([
        html.H1(f"Pipeline Parallelism Visualization - {schedule_type}", style={"textAlign": "center"}),

        html.Div([
            html.P(f"Number of devices: {len(schedule_data)}", style={"display": "inline-block", "marginRight": "20px"}),
            html.P(f"Total tasks: {total_tasks}", style={"display": "inline-block", "marginRight": "20px"}),
        ], style={"marginBottom": "20px"}),

        # Empty placeholder; it serves as the callback input below
        html.Div(id="graph-container", children=[]),

        dcc.Loading(
            id="loading-graph",
            type="circle",
            children=[
                dcc.Graph(
                    id="pipeline-graph",
                    config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
                ),
            ]
        ),
    ])

    # Per-app cache for the figure so repeated callback invocations do not
    # rebuild it. The app is bound to a single schedule, so one slot suffices.
    figure_cache = {}

    @app.callback(
        Output("pipeline-graph", "figure"),
        Input("graph-container", "children"),
        prevent_initial_call=False,
    )
    def load_graph(_):
        # Use cached figure if available
        if enable_caching and "figure" in figure_cache:
            print("Using cached figure")
            return figure_cache["figure"]

        # Create the figure
        figure = create_pipeline_figure(schedule_data, show_progress=True)

        # Cache the figure
        if enable_caching:
            figure_cache["figure"] = figure
            print("Cached figure")

        return figure

    return app
375
 
 
377
  def visualize_pipeline_parallelism_dash(
378
  schedule: Schedule,
379
  port: int = 8050,
380
+ debug: bool = False,
381
+ enable_caching: bool = True
382
  ):
383
  """
384
  Launch a Dash app to visualize the pipeline schedule interactively.
 
387
  schedule: Schedule object to visualize
388
  port: Port to run the Dash app on
389
  debug: Whether to run the Dash app in debug mode
390
+ enable_caching: Whether to cache schedule data and figures
391
  """
392
+ app = create_dash_app(schedule, enable_caching=enable_caching)
393
  print(f"Starting Dash app on http://localhost:{port}/")
394
  app.run_server(debug=debug, port=port)