Victarry commited on
Commit
e178784
·
1 Parent(s): a5a3887

Add dash backend visualizer.

Browse files
README-dash-visualizer.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pipeline Parallelism Dash Visualizer
2
+
3
+ This is an interactive Dash-based visualizer for pipeline parallelism scheduling, complementing the existing Matplotlib-based visualization.
4
+
5
+ ## Features
6
+
7
+ - **Static image generation** similar to the Matplotlib version
8
+ - **Interactive web-based visualization** with Dash
9
+ - **Download functionality** to save the visualization as PNG
10
+ - **Progress indication** during figure creation and image generation
11
+ - **Compatible API** with the existing visualizer
12
+
13
+ ## Installation
14
+
15
+ Install the required dependencies:
16
+
17
+ ```bash
18
+ pip install -r requirements-dash.txt
19
+ ```
20
+
21
+ ## Usage
22
+
23
+ ### From Python
24
+
25
+ ```python
26
+ from pipeline import create_1f1b_schedule
27
+ from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
28
+
29
+ # Create a schedule
30
+ schedule = create_1f1b_schedule(
31
+ num_stages=4,
32
+ num_batches=8,
33
+ forward_times=[1.0, 1.0, 1.0, 1.0],
34
+ backward_times=[2.0, 2.0, 2.0, 2.0],
35
+ )
36
+
37
+ # Generate a static image
38
+ save_pipeline_visualization_plotly(
39
+ schedule=schedule,
40
+ schedule_type="1f1b",
41
+ output_file="pipeline_plotly.png"
42
+ )
43
+
44
+ # OR launch an interactive Dash app
45
+ visualize_pipeline_parallelism_dash(
46
+ schedule=schedule,
47
+ schedule_type="1f1b",
48
+ port=8050,
49
+ debug=False
50
+ )
51
+ ```
52
+
53
+ ### Using the Command Line
54
+
55
+ You can use the updated command line interface:
56
+
57
+ ```bash
58
+ # Generate a static image with Dash/Plotly
59
+ python pipeline.py --visualizer dash --output-file pipeline_viz.png
60
+
61
+ # Launch an interactive Dash app
62
+ python pipeline.py --visualizer dash-interactive
63
+
64
+ # Use the original Matplotlib visualizer
65
+ python pipeline.py --visualizer matplotlib
66
+ ```
67
+
68
+ You can also use the dash_visualizer.py script directly for testing:
69
+
70
+ ```bash
71
+ # Generate a static image
72
+ python dash_visualizer.py --output test_viz.png
73
+
74
+ # Launch an interactive app
75
+ python dash_visualizer.py --interactive
76
+ ```
77
+
78
+ ## Differences from Matplotlib Visualizer
79
+
80
+ The Dash-based visualizer provides all the same visual elements as the Matplotlib version:
81
+ - Color-coded rectangles for forward, backward, and optimizer operations
82
+ - Batch numbers displayed inside each rectangle
83
+ - Device labels on the y-axis
84
+ - Clear legend
85
+
86
+ Additional features:
87
+ - Interactive web interface
88
+ - Hovering over elements to see details
89
+ - Download button to save the visualization
90
+ - Progress bars for tracking visualization creation
91
+ - Responsive layout that works well on different screen sizes
dash_visualizer.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dash
2
+ from dash import dcc, html
3
+ from dash.dependencies import Input, Output, State
4
+ import plotly.graph_objects as go
5
+ import numpy as np
6
+ from typing import List, Dict, Literal
7
+ from tqdm import tqdm
8
+ import time
9
+
10
+
11
+ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_progress=True):
12
+ """
13
+ Create a Plotly figure for pipeline parallelism scheduling.
14
+
15
+ Args:
16
+ schedule: Dictionary mapping device IDs to lists of tasks.
17
+ Each task is a dictionary with keys:
18
+ - 'type': 'forward', 'backward', or 'optimizer'
19
+ - 'batch': batch number
20
+ - 'start_time': start time of the task
21
+ - 'duration': duration of the task
22
+ max_time: Optional maximum time to display
23
+ show_progress: Whether to show a progress bar
24
+ """
25
+ # Colors for task types
26
+ forward_color = "royalblue"
27
+ backward_color = "sandybrown"
28
+ optimizer_color = "#FFEFCF"
29
+ empty_color = "whitesmoke"
30
+
31
+ # Find the number of stages (devices)
32
+ num_stages = len(schedule)
33
+
34
+ # Find the maximum time in the schedule if not provided
35
+ if max_time is None:
36
+ max_time = 0
37
+ for device in schedule:
38
+ for task in schedule[device]:
39
+ end_time = task["start_time"] + task["duration"]
40
+ if end_time > max_time:
41
+ max_time = end_time
42
+
43
+ # Create a figure
44
+ fig = go.Figure()
45
+
46
+ # Initialize progress tracking
47
+ total_tasks = sum(len(tasks) for tasks in schedule.values())
48
+ tasks_processed = 0
49
+
50
+ if show_progress:
51
+ progress_bar = tqdm(total=total_tasks + num_stages + 3, desc="Creating visualization")
52
+
53
+ # Add background for empty cells
54
+ for device_idx in range(num_stages):
55
+ device_idx_reversed = num_stages - device_idx - 1 # Reverse for plotting
56
+ fig.add_trace(go.Scatter(
57
+ x=[0, max_time],
58
+ y=[device_idx_reversed, device_idx_reversed],
59
+ mode='lines',
60
+ line=dict(color='lightgray', width=0.5),
61
+ showlegend=False,
62
+ hoverinfo='none'
63
+ ))
64
+ if show_progress:
65
+ progress_bar.update(1)
66
+
67
+ # Add rectangles for each task
68
+ for device_idx, device in enumerate(schedule):
69
+ device_idx_reversed = num_stages - device_idx - 1
70
+
71
+ for task in schedule[device]:
72
+ # Determine task color and text color
73
+ if task["type"] == "forward":
74
+ color = forward_color
75
+ text_color = "white"
76
+ name = "Forward"
77
+ elif task["type"] == "backward":
78
+ color = backward_color
79
+ text_color = "black"
80
+ name = "Backward"
81
+ else: # optimizer or any other type
82
+ color = optimizer_color
83
+ text_color = "black"
84
+ name = "Optimizer step"
85
+
86
+ # Add rectangle for the task
87
+ start_time = task["start_time"]
88
+ duration = task["duration"]
89
+
90
+ # Create rectangle using shape
91
+ fig.add_shape(
92
+ type="rect",
93
+ x0=start_time,
94
+ y0=device_idx_reversed - 0.4,
95
+ x1=start_time + duration,
96
+ y1=device_idx_reversed + 0.4,
97
+ line=dict(color="black", width=0.5),
98
+ fillcolor=color,
99
+ layer="above",
100
+ )
101
+
102
+ # Add batch number text
103
+ fig.add_annotation(
104
+ x=start_time + duration / 2,
105
+ y=device_idx_reversed,
106
+ text=str(task["batch"]),
107
+ showarrow=False,
108
+ font=dict(color=text_color, size=10, family="Arial, bold"),
109
+ )
110
+
111
+ # Update progress
112
+ if show_progress:
113
+ tasks_processed += 1
114
+ progress_bar.update(1)
115
+
116
+ # Add custom legend
117
+ legend_items = [
118
+ dict(name="Forward", color=forward_color),
119
+ dict(name="Backward", color=backward_color),
120
+ dict(name="Optimizer step", color=optimizer_color)
121
+ ]
122
+
123
+ for i, item in enumerate(legend_items):
124
+ fig.add_trace(go.Scatter(
125
+ x=[None],
126
+ y=[None],
127
+ mode='markers',
128
+ marker=dict(size=10, color=item['color']),
129
+ name=item['name'],
130
+ showlegend=True
131
+ ))
132
+ if show_progress and i < len(legend_items) - 1:
133
+ progress_bar.update(1)
134
+
135
+ # Set axis properties
136
+ device_labels = [f"Device {i+1}" for i in range(num_stages)]
137
+ device_labels.reverse() # Reverse to put Device 1 at the top
138
+
139
+ fig.update_layout(
140
+ xaxis=dict(
141
+ showticklabels=False,
142
+ showgrid=False,
143
+ zeroline=False,
144
+ title="Time →",
145
+ range=[0, max_time + 0.5]
146
+ ),
147
+ yaxis=dict(
148
+ tickmode="array",
149
+ tickvals=list(range(num_stages)),
150
+ ticktext=device_labels,
151
+ showgrid=False,
152
+ zeroline=False,
153
+ range=[-0.5, num_stages - 0.5]
154
+ ),
155
+ margin=dict(l=50, r=50, t=50, b=50),
156
+ plot_bgcolor="white",
157
+ legend=dict(
158
+ orientation="h",
159
+ yanchor="bottom",
160
+ y=-0.2,
161
+ xanchor="center",
162
+ x=0.5
163
+ )
164
+ )
165
+
166
+ if show_progress:
167
+ progress_bar.update(1) # Final update for layout
168
+ progress_bar.close()
169
+
170
+ return fig
171
+
172
+
173
+ def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
174
+ """
175
+ Create a Dash app for interactive visualization of pipeline scheduling.
176
+
177
+ Args:
178
+ schedule: Dictionary mapping device IDs to lists of tasks
179
+ schedule_type: Type of scheduling algorithm used
180
+ """
181
+ app = dash.Dash(__name__, title="Pipeline Parallelism Visualization")
182
+
183
+ app.layout = html.Div([
184
+ html.H1(f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
185
+ style={'textAlign': 'center'}),
186
+
187
+ html.Div(id="loading-container", children=[
188
+ dcc.Loading(
189
+ id="loading-graph",
190
+ type="circle",
191
+ children=[
192
+ html.Div(id="graph-container", children=[
193
+ dcc.Graph(
194
+ id='pipeline-graph',
195
+ style={'height': '600px'}
196
+ )
197
+ ])
198
+ ]
199
+ )
200
+ ]),
201
+
202
+ html.Div([
203
+ html.Button("Download PNG", id="btn-download",
204
+ style={'margin': '10px'}),
205
+ dcc.Download(id="download-image")
206
+ ], style={'textAlign': 'center', 'marginTop': '20px'})
207
+ ])
208
+
209
+ @app.callback(
210
+ Output("pipeline-graph", "figure"),
211
+ Input("graph-container", "children"),
212
+ prevent_initial_call=False,
213
+ )
214
+ def load_graph(_):
215
+ # Create the figure when the app loads
216
+ return create_pipeline_figure(schedule, show_progress=True)
217
+
218
+ @app.callback(
219
+ Output("download-image", "data"),
220
+ Input("btn-download", "n_clicks"),
221
+ prevent_initial_call=True,
222
+ )
223
+ def download_image(n_clicks):
224
+ # Show progress in terminal for downloads
225
+ fig = create_pipeline_figure(schedule, show_progress=True)
226
+ img_bytes = fig.to_image(format="png", scale=3)
227
+ return dict(
228
+ content=img_bytes,
229
+ filename="pipeline_visualization.png"
230
+ )
231
+
232
+ return app
233
+
234
+
235
+ def visualize_pipeline_parallelism_dash(
236
+ schedule: Dict[int, List[Dict]],
237
+ schedule_type: Literal["simple", "1f1b"] = "1f1b",
238
+ port: int = 8050,
239
+ debug: bool = False
240
+ ):
241
+ """
242
+ Create an interactive Dash visualization for pipeline parallelism scheduling.
243
+
244
+ Args:
245
+ schedule: Dictionary mapping device IDs to lists of tasks
246
+ schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
247
+ port: Port number to run the Dash app
248
+ debug: Whether to run the app in debug mode
249
+ """
250
+ app = create_dash_app(schedule, schedule_type)
251
+ print(f"Starting Dash app on http://localhost:{port}/")
252
+ app.run_server(debug=debug, port=port)
253
+
254
+
255
+ def save_pipeline_visualization_plotly(
256
+ schedule: Dict[int, List[Dict]],
257
+ schedule_type: Literal["simple", "1f1b"] = "1f1b",
258
+ output_file: str = "pipeline_visualization_plotly.png",
259
+ ):
260
+ """
261
+ Save a static Plotly visualization of pipeline parallelism scheduling.
262
+
263
+ Args:
264
+ schedule: Dictionary mapping device IDs to lists of tasks
265
+ schedule_type: Type of scheduling algorithm used
266
+ output_file: Path to save the visualization
267
+ """
268
+ print(f"Creating visualization for {len(schedule)} devices...")
269
+ fig = create_pipeline_figure(schedule, show_progress=True)
270
+
271
+ # Update layout for static image
272
+ fig.update_layout(
273
+ title=f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
274
+ title_x=0.5
275
+ )
276
+
277
+ print(f"Saving image to {output_file}...")
278
+ # Save as image
279
+ fig.write_image(output_file, scale=3)
280
+ print(f"Visualization saved to {output_file}")
281
+
282
+
283
+ if __name__ == "__main__":
284
+ # Example usage
285
+ import argparse
286
+ from pipeline import create_1f1b_schedule
287
+
288
+ parser = argparse.ArgumentParser(description="Pipeline Parallelism Visualizer")
289
+ parser.add_argument("--num-stages", type=int, default=4, help="Number of pipeline stages")
290
+ parser.add_argument("--num-batches", type=int, default=8, help="Number of microbatches")
291
+ parser.add_argument("--interactive", action="store_true", help="Run interactive Dash app")
292
+ parser.add_argument("--port", type=int, default=8050, help="Port for Dash app")
293
+ parser.add_argument("--output", type=str, default="pipeline_visualization_plotly.png", help="Output file for static image")
294
+ args = parser.parse_args()
295
+
296
+ # Create an example schedule
297
+ forward_times = [1.0] * args.num_stages
298
+ backward_times = [2.0] * args.num_stages
299
+
300
+ schedule = create_1f1b_schedule(
301
+ num_stages=args.num_stages,
302
+ num_batches=args.num_batches,
303
+ forward_times=forward_times,
304
+ backward_times=backward_times,
305
+ )
306
+
307
+ if args.interactive:
308
+ visualize_pipeline_parallelism_dash(schedule, port=args.port)
309
+ else:
310
+ save_pipeline_visualization_plotly(schedule, output_file=args.output)
pipeline.py CHANGED
@@ -9,6 +9,11 @@ from typing import List, Tuple, Dict, Literal
9
 
10
  # Import visualization function from the new module
11
  from visualizer import visualize_pipeline_parallelism
 
 
 
 
 
12
 
13
 
14
  def create_1f1b_schedule(
@@ -210,6 +215,7 @@ def get_bubble_rate(schedule: Dict[int, List[Dict]]):
210
  if end_time > max_time:
211
  max_time = end_time
212
 
 
213
  total_execution_time = max_time * num_stages
214
 
215
  total_computation_time = 0
@@ -325,6 +331,9 @@ def parse_args():
325
  help="Time for point-to-point communication between stages",
326
  )
327
 
 
 
 
328
  return parser.parse_args()
329
 
330
 
@@ -447,9 +456,24 @@ def main():
447
 
448
  # Create visualization unless --no-visualization is specified
449
  if not args.no_visualization:
450
- visualize_pipeline_parallelism(
451
- schedule=schedule, schedule_type="1f1b", output_file=output_file
452
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
 
454
  # Analyze the schedule
455
  bubble_rate = get_bubble_rate(schedule)
 
9
 
10
  # Import visualization function from the new module
11
  from visualizer import visualize_pipeline_parallelism
12
+ try:
13
+ from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
14
+ DASH_AVAILABLE = True
15
+ except ImportError:
16
+ DASH_AVAILABLE = False
17
 
18
 
19
  def create_1f1b_schedule(
 
215
  if end_time > max_time:
216
  max_time = end_time
217
 
218
+ print(f"Max time: {max_time}")
219
  total_execution_time = max_time * num_stages
220
 
221
  total_computation_time = 0
 
331
  help="Time for point-to-point communication between stages",
332
  )
333
 
334
+ parser.add_argument("--visualizer", choices=["matplotlib", "dash", "dash-interactive"],
335
+ default="matplotlib", help="Visualization library to use")
336
+
337
  return parser.parse_args()
338
 
339
 
 
456
 
457
  # Create visualization unless --no-visualization is specified
458
  if not args.no_visualization:
459
+ if args.visualizer == "matplotlib" or not DASH_AVAILABLE:
460
+ if not DASH_AVAILABLE and args.visualizer in ["dash", "dash-interactive"]:
461
+ print("Warning: Dash not available. Falling back to matplotlib.")
462
+ visualize_pipeline_parallelism(
463
+ schedule=schedule, schedule_type="1f1b", output_file=output_file
464
+ )
465
+ elif args.visualizer == "dash":
466
+ # Get output file name without extension to use the appropriate extension
467
+ output_base = os.path.splitext(output_file)[0]
468
+ output_dash = f"{output_base}_plotly.png"
469
+ save_pipeline_visualization_plotly(
470
+ schedule=schedule, schedule_type="1f1b", output_file=output_dash
471
+ )
472
+ elif args.visualizer == "dash-interactive":
473
+ print("Using Dash interactive visualization")
474
+ visualize_pipeline_parallelism_dash(
475
+ schedule=schedule, schedule_type="1f1b", port=8050, debug=False
476
+ )
477
 
478
  # Analyze the schedule
479
  bubble_rate = get_bubble_rate(schedule)
requirements-dash.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ dash==2.13.0
2
+ plotly==5.18.0
3
+ numpy
4
+ kaleido # For static image export
5
+ tqdm # For progress bars