import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
from typing import List, Dict
from tqdm import tqdm
from functools import lru_cache

from src.execution_model import Schedule


def convert_schedule_to_visualization_format(schedule: Schedule):
    """
    Converts a Schedule object to the format needed for visualization.
    
    Returns:
        Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
    """
    # Make sure all operations have start and end times
    for op in schedule.ops.values():
        if op.start_time is None or op.end_time is None:
            raise ValueError("Operations must have start and end times. Run ScheduleExecutor.execute() first.")
    
    visualization_data = {}
    
    # Organize operations by device
    for device_id, device_queue in enumerate(schedule.dev_queues):
        visualization_data[device_id] = []
        
        for op in device_queue.ops:
            visualization_data[device_id].append({
                "type": op.op_type,
                "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
                "stage": op.stage_id,
                "start_time": op.start_time,
                "duration": op.end_time - op.start_time
            })
    
    return visualization_data
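
# Example of the structure produced above (illustrative values only, not taken
# from a real schedule): device_id -> list of task dictionaries.
#
#     {
#         0: [{"type": "forward", "batch": 1, "stage": 0, "start_time": 0.0, "duration": 1.0}, ...],
#         1: [...],
#     }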


# Cache the color calculation as it's repeatedly called with the same parameters
@lru_cache(maxsize=128)
def get_color(op_type: str, stage_id: int, num_devices: int):
    # Color palettes for different virtual stages
    forward_colors = [
        "royalblue",      # Stage 0
        "cornflowerblue", # Stage 1
        "dodgerblue",     # Stage 2
        "steelblue",      # Stage 3
        "lightskyblue",   # Stage 4
        "deepskyblue",    # Stage 5
        "mediumblue",     # Stage 6
        "mediumslateblue",# Stage 7
        "slateblue",      # Stage 8
        "darkslateblue"   # Stage 9
    ]
    
    # Updated to orange/brown palette for backward operations
    backward_colors = [
        "darkorange",     # Stage 0
        "orange",         # Stage 1
        "sandybrown",     # Stage 2
        "peru",           # Stage 3
        "chocolate",      # Stage 4
        "sienna",         # Stage 5
        "saddlebrown",    # Stage 6
        "brown",          # Stage 7
        "darkgoldenrod",  # Stage 8
        "goldenrod"       # Stage 9
    ]
    
    # Updated to teal/turquoise palette for backward_D operations
    backward_d_colors = [
        "mediumaquamarine", # Stage 0
        "cadetblue",        # Stage 1
        "lightseagreen",    # Stage 2
        "cyan",             # Stage 3
        "teal",             # Stage 4
        "mediumturquoise",  # Stage 5
        "turquoise",        # Stage 6
        "aquamarine",       # Stage 7
        "darkturquoise",    # Stage 8
        "paleturquoise"     # Stage 9
    ]
    
    # Updated to green palette for backward_W operations
    backward_w_colors = [
        "limegreen",         # Stage 0
        "forestgreen",       # Stage 1
        "green",             # Stage 2
        "seagreen",          # Stage 3
        "mediumseagreen",    # Stage 4
        "springgreen",       # Stage 5
        "mediumspringgreen", # Stage 6
        "palegreen",         # Stage 7
        "lightgreen",        # Stage 8
        "darkseagreen"       # Stage 9
    ]

    virtual_stage = stage_id // num_devices

    # If virtual_stage is beyond our color list, cycle through the colors
    color_index = virtual_stage % len(forward_colors)

    if op_type == "forward":
        return forward_colors[color_index]
    elif op_type == "backward":
        return backward_colors[color_index]
    elif op_type == "backward_D":
        return backward_d_colors[color_index]
    elif op_type == "backward_W":
        return backward_w_colors[color_index]
    else:
        raise ValueError(f"Invalid operation type: {op_type}")
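
# For example, with num_devices = 4 a stage_id of 5 falls in virtual stage 1
# (5 // 4 == 1), so a "forward" op at that stage is drawn in forward_colors[1]
# ("cornflowerblue"); virtual stages beyond the palette length wrap around via
# the modulo above.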


def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
    """
    Create a Plotly figure for pipeline parallelism scheduling.

    Args:
        schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
        max_time: Optional maximum time to display
        show_progress: Whether to show a progress bar
    """
    # Find the number of devices
    num_devices = len(schedule_data)
    
    empty_color = "whitesmoke"
    
    # Find the maximum time in the schedule if not provided
    if max_time is None:
        max_time = 0
        for device in schedule_data:
            for task in schedule_data[device]:
                end_time = task["start_time"] + task["duration"]
                if end_time > max_time:
                    max_time = end_time

    # Create a figure
    fig = go.Figure()

    # Initialize progress tracking
    total_tasks = sum(len(tasks) for tasks in schedule_data.values())
    tasks_processed = 0

    if show_progress:
        progress_bar = tqdm(total=total_tasks + num_devices + 3, desc="Creating visualization")

    # Create a custom y-axis with no gaps between devices
    y_spacing = 1.0  # Use 1.0 for no gaps

    # Batch processing for increased performance
    shapes = []
    annotations = []
    hover_traces = []

    # Add rectangles for each task
    for device_idx, device in enumerate(schedule_data):
        device_idx_reversed = num_devices - device_idx - 1
        
        # Sort tasks by start time to ensure correct rendering
        sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])

        for task in sorted_tasks:
            # Determine task color and text color
            if task["type"] == "forward":
                color = get_color(task["type"], task["stage"], num_devices)
                text_color = "white"
                name = "Forward"
            elif task["type"] == "backward":
                color = get_color(task["type"], task["stage"], num_devices)
                text_color = "black"
                name = "Backward"
            elif task["type"] == "backward_D":
                color = get_color(task["type"], task["stage"], num_devices)
                text_color = "black"
                name = "Backward (Grad)"
            elif task["type"] == "backward_W":
                color = get_color(task["type"], task["stage"], num_devices)
                text_color = "black"
                name = "Backward (Weight)"
            else:
                color = empty_color
                text_color = "black"
                name = "Unknown"

            # Add rectangle for the task
            start_time = task["start_time"]
            duration = task["duration"]
            
            # Calculate y positions with no gaps
            y_pos = device_idx_reversed * y_spacing
            
            # Create rectangle using shape (batch-add later)
            shapes.append(dict(
                type="rect",
                x0=start_time,
                y0=y_pos - 0.5,
                x1=start_time + duration,
                y1=y_pos + 0.5,
                line=dict(color="black", width=0.5),
                fillcolor=color,
                layer="above",
            ))
            
            # Add batch number text (batch-add later)
            annotations.append(dict(
                x=start_time + duration / 2,
                y=y_pos,
                text=f"{task['batch']}",  
                showarrow=False,
                font=dict(color=text_color, size=12, family="Arial, bold"),
            ))
            
            # Prepare hover data (add traces in batches later)
            hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
            
            hover_traces.append(dict(
                x=[start_time + duration / 2],
                y=[y_pos],
                mode='markers',
                marker=dict(opacity=0),  # Invisible marker
                hoverinfo='text',
                text=hover_text,
                showlegend=False
            ))
            
            # Update progress
            if show_progress:
                tasks_processed += 1
                progress_bar.update(1)

    # Add all shapes at once for better performance
    fig.update_layout(shapes=shapes)
    
    # Add all annotations at once
    fig.update_layout(annotations=annotations)
    
    # Add all hover traces at once
    for trace in hover_traces:
        fig.add_trace(go.Scatter(**trace))

    # Add custom legend
    legend_items = []
    
    # Find the maximum virtual stage in the data
    max_virtual_stage = 0
    for device in schedule_data:
        for task in schedule_data[device]:
            virtual_stage = task["stage"] // num_devices
            max_virtual_stage = max(max_virtual_stage, virtual_stage)
    
    # Check once whether the schedule contains split backward operations (e.g. a zb1p schedule)
    has_split_backward = any(
        task["type"] in ("backward_D", "backward_W")
        for device in schedule_data
        for task in schedule_data[device]
    )

    # Add forward and backward items for each virtual stage
    for vs in range(max_virtual_stage + 1):
        legend_items.append(dict(
            name=f"Forward (VS {vs})",
            color=get_color("forward", vs * num_devices, num_devices)
        ))
        legend_items.append(dict(
            name=f"Backward (VS {vs})",
            color=get_color("backward", vs * num_devices, num_devices)
        ))
        # Add entries for split backward operations only when they are present
        if has_split_backward:
            legend_items.append(dict(
                name=f"Backward Grad (VS {vs})",
                color=get_color("backward_D", vs * num_devices, num_devices)
            ))
            legend_items.append(dict(
                name=f"Backward Weight (VS {vs})",
                color=get_color("backward_W", vs * num_devices, num_devices)
            ))
    
    # If no tasks found, add default legend items
    if not legend_items:
        legend_items = [
            dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
            dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
            dict(name="Backward Grad (VS 0)", color=get_color("backward_D", 0, num_devices)),
            dict(name="Backward Weight (VS 0)", color=get_color("backward_W", 0, num_devices)),
        ]
    
    for i, item in enumerate(legend_items):
        fig.add_trace(go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(size=10, color=item['color']),
            name=item['name'],
            showlegend=True
        ))
        if show_progress and i < len(legend_items) - 1:
            progress_bar.update(1)

    # Set axis properties
    device_labels = [f"Device {i}" for i in range(num_devices)]
    # Modify the ordering to put Device 1 at the top, then Device 0, then the rest
    if num_devices >= 2:
        # Move Device 1 to the top, followed by Device 0
        device_labels = [device_labels[1], device_labels[0]] + device_labels[2:]
    
    # Calculate tick positions with no gaps
    tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
    
    # Adjust the range to ensure there are no empty spaces at the end
    x_end = max_time * 1.05  # Add a small margin

    title_text = "Pipeline Parallelism Schedule"

    fig.update_layout(
        yaxis=dict(
            tickmode="array",
            tickvals=tick_positions,
            ticktext=device_labels,
            showgrid=False,
            zeroline=False,
        ),
        margin=dict(l=50, r=20, t=40, b=40),
        plot_bgcolor="white",
        title=dict(
            text=title_text,
            x=0.5,
            y=0.98,  # Move title position closer to the top
            font=dict(size=20)
        ),
        legend=dict(
            orientation="v",  # Changed from horizontal to vertical
            yanchor="top",
            y=1.02,  # Position at the top
            xanchor="right",
            x=1.20,   # Position further to the right to accommodate more items
            title=dict(text="<b>Operation Types:</b>"),
            itemsizing="constant",
            tracegroupgap=0
        ),
        width=2000,  # Increase width to accommodate the expanded legend
        height=400,  # Maintain current height
        bargap=0,
        bargroupgap=0,
    )

    if show_progress:
        progress_bar.update(1)
        progress_bar.close()

    return fig
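
# Minimal standalone sketch for create_pipeline_figure (illustrative values only;
# it accepts the same dictionary format that convert_schedule_to_visualization_format
# returns, so no Schedule object is required):
#
#     demo_data = {
#         0: [{"type": "forward", "batch": 1, "stage": 0, "start_time": 0.0, "duration": 1.0},
#             {"type": "backward", "batch": 1, "stage": 0, "start_time": 1.0, "duration": 2.0}],
#     }
#     fig = create_pipeline_figure(demo_data, show_progress=False)
#     fig.show()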


# Cache for storing processed schedule data
_schedule_data_cache = {}

def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True):
    """
    Create a Dash app to visualize the pipeline schedule.
    
    Args:
        schedule: Schedule object to visualize
        schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
        enable_caching: Whether to cache the schedule data and figure
    """
    # Process schedule data only once and cache it
    global _schedule_data_cache
    cache_key = id(schedule)
    
    if enable_caching and cache_key in _schedule_data_cache:
        schedule_data = _schedule_data_cache[cache_key]
        print("Using cached schedule data")
    else:
        schedule_data = convert_schedule_to_visualization_format(schedule)
        if enable_caching:
            _schedule_data_cache[cache_key] = schedule_data
            print("Cached schedule data")
    
    total_tasks = sum(len(tasks) for tasks in schedule_data.values())
    print(f"Total tasks in schedule: {total_tasks}")

    app = dash.Dash(__name__)
    app.title = f"Pipeline Parallelism Visualization - {schedule_type}"

    # Create a more informative layout with data size information
    app.layout = html.Div([
        html.H1(f"Pipeline Parallelism Visualization - {schedule_type}", style={"textAlign": "center"}),
        
        html.Div([
            html.P(f"Number of devices: {len(schedule_data)}", style={"display": "inline-block", "marginRight": "20px"}),
            html.P(f"Total tasks: {total_tasks}", style={"display": "inline-block", "marginRight": "20px"}),
        ], style={"marginBottom": "20px"}),
        
        html.Div(id="graph-container", children=[]),
        
        dcc.Loading(
            id="loading-graph",
            type="circle",
            children=[
                dcc.Graph(
                    id="pipeline-graph",
                    config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
                ),
            ]
        ),
    ])
    
    # Cache for storing figure to avoid regenerating it
    figure_cache = {}
    
    @app.callback(
        Output("pipeline-graph", "figure"),
        Input("graph-container", "children"),
        prevent_initial_call=False,
    )
    def load_graph(_):
        # Use cached figure if available
        cache_key = f"{id(schedule)}"
        if enable_caching and cache_key in figure_cache:
            print("Using cached figure")
            return figure_cache[cache_key]
        
        # Create the figure
        figure = create_pipeline_figure(schedule_data, show_progress=True)
        
        # Cache the figure
        if enable_caching:
            figure_cache[cache_key] = figure
            print("Cached figure")
            
        return figure

    return app
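
# Sketch: build the Dash app without serving it immediately (for example to hand
# the underlying Flask server to a WSGI container). The `schedule` below is assumed
# to be an already-executed Schedule from src.execution_model.
#
#     app = create_dash_app(schedule, schedule_type="zb1p")
#     server = app.server  # Dash exposes the underlying Flask instance here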


def visualize_pipeline_parallelism_dash(
    schedule: Schedule,
    port: int = 8050,
    debug: bool = False,
    enable_caching: bool = True,
    schedule_type="1f1b"
):
    """
    Launch a Dash app to visualize the pipeline schedule interactively.
    
    Args:
        schedule: Schedule object to visualize
        port: Port to run the Dash app on
        debug: Whether to run the Dash app in debug mode
        enable_caching: Whether to cache schedule data and figures
        schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
    """
    app = create_dash_app(schedule, schedule_type=schedule_type, enable_caching=enable_caching)
    print(f"Starting Dash app on http://localhost:{port}/")
    app.run_server(debug=debug, port=port)
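
# Typical usage (sketch): once a Schedule has been built and executed elsewhere
# (how that is done depends on src.execution_model and is not shown here),
# launch the interactive viewer:
#
#     visualize_pipeline_parallelism_dash(schedule, port=8050, schedule_type="1f1b")
#
# and open http://localhost:8050/ in a browser.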