Victarry commited on
Commit
16ed969
·
1 Parent(s): a49be3b

Update visualizer.

Browse files
Files changed (2) hide show
  1. main.py +1 -1
  2. src/visualizer.py +67 -44
main.py CHANGED
@@ -1,6 +1,6 @@
1
  from src.execution_model import ScheduleConfig, ScheduleExecutor
2
  from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
3
- from src.visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
4
  import hydra
5
  from omegaconf import DictConfig, OmegaConf
6
 
 
1
  from src.execution_model import ScheduleConfig, ScheduleExecutor
2
  from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
3
+ from src.visualizer import visualize_pipeline_parallelism_dash
4
  import hydra
5
  from omegaconf import DictConfig, OmegaConf
6
 
src/visualizer.py CHANGED
@@ -55,24 +55,42 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
55
  empty_color = "whitesmoke"
56
  # Colors for task types
57
  def get_color(op_type: str, stage_id: int):
58
- # Base colors
59
- forward_base_color = "royalblue"
60
- backward_base_color = "lightgreen" # Changed from sandybrown to match your visualization
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  virtual_stage = stage_id // num_devices
63
 
 
 
 
64
  if op_type == "forward":
65
- if virtual_stage == 0:
66
- return forward_base_color
67
- else:
68
- # Lighter shade for virtual_stage > 0
69
- return "lightskyblue"
70
  elif op_type == "backward":
71
- if virtual_stage == 0:
72
- return backward_base_color
73
- else:
74
- # Lighter shade for virtual_stage > 0
75
- return "lightseagreen"
76
  else:
77
  raise ValueError(f"Invalid operation type: {op_type}")
78
 
@@ -165,10 +183,32 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
165
  progress_bar.update(1)
166
 
167
  # Add custom legend
168
- legend_items = [
169
- dict(name="Forward", color=get_color("forward", 0)),
170
- dict(name="Backward", color=get_color("backward", 0)),
171
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  for i, item in enumerate(legend_items):
174
  fig.add_trace(go.Scatter(
@@ -209,14 +249,17 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
209
  font=dict(size=20)
210
  ),
211
  legend=dict(
212
- orientation="h",
213
  yanchor="top",
214
- y=-0.1, # Position below the plot
215
- xanchor="center",
216
- x=0.5
 
 
 
217
  ),
218
- width=1600,
219
- height=400, # Reduce height to make the visualization more compact
220
  bargap=0,
221
  bargroupgap=0,
222
  )
@@ -285,7 +328,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
285
  def load_graph(_):
286
  # Create the figure when the app loads
287
  return create_pipeline_figure(schedule_data, show_progress=True)
288
-
289
  @app.callback(
290
  Output("download-image", "data"),
291
  Input("btn-download", "n_clicks"),
@@ -326,23 +369,3 @@ def visualize_pipeline_parallelism_dash(
326
  app = create_dash_app(schedule)
327
  print(f"Starting Dash app on http://localhost:{port}/")
328
  app.run_server(debug=debug, port=port)
329
-
330
-
331
- def save_pipeline_visualization_plotly(
332
- schedule: Schedule,
333
- output_file: str = "pipeline_visualization_plotly.png",
334
- ):
335
- """
336
- Save a static image of the pipeline schedule visualization.
337
-
338
- Args:
339
- schedule: Schedule object to visualize
340
- output_file: Path to save the image to
341
- """
342
- schedule_data = convert_schedule_to_visualization_format(schedule)
343
- fig = create_pipeline_figure(schedule_data, show_progress=True)
344
-
345
- print(f"Saving visualization to {output_file}...")
346
- fig.write_image(output_file, width=1600, height=400, scale=2)
347
- print(f"Visualization saved to {output_file}")
348
-
 
55
  empty_color = "whitesmoke"
56
  # Colors for task types
57
  def get_color(op_type: str, stage_id: int):
58
+ # Color palettes for different virtual stages
59
+ forward_colors = [
60
+ "royalblue", # Stage 0
61
+ "lightskyblue", # Stage 1
62
+ "cornflowerblue", # Stage 2
63
+ "steelblue", # Stage 3
64
+ "dodgerblue", # Stage 4
65
+ "deepskyblue", # Stage 5
66
+ "mediumblue", # Stage 6
67
+ "mediumslateblue",# Stage 7
68
+ "slateblue", # Stage 8
69
+ "darkslateblue" # Stage 9
70
+ ]
71
 
72
+ backward_colors = [
73
+ "lightgreen", # Stage 0
74
+ "mediumseagreen", # Stage 1
75
+ "seagreen", # Stage 2
76
+ "lightseagreen", # Stage 3
77
+ "mediumaquamarine", # Stage 4
78
+ "mediumspringgreen", # Stage 5
79
+ "springgreen", # Stage 6
80
+ "palegreen", # Stage 7
81
+ "limegreen", # Stage 8
82
+ "forestgreen" # Stage 9
83
+ ]
84
+
85
  virtual_stage = stage_id // num_devices
86
 
87
+ # If virtual_stage is beyond our color list, cycle through the colors
88
+ color_index = virtual_stage % len(forward_colors)
89
+
90
  if op_type == "forward":
91
+ return forward_colors[color_index]
 
 
 
 
92
  elif op_type == "backward":
93
+ return backward_colors[color_index]
 
 
 
 
94
  else:
95
  raise ValueError(f"Invalid operation type: {op_type}")
96
 
 
183
  progress_bar.update(1)
184
 
185
  # Add custom legend
186
+ legend_items = []
187
+
188
+ # Find the maximum virtual stage in the data
189
+ max_virtual_stage = 0
190
+ for device in schedule_data:
191
+ for task in schedule_data[device]:
192
+ virtual_stage = task["stage"] // num_devices
193
+ max_virtual_stage = max(max_virtual_stage, virtual_stage)
194
+
195
+ # Add forward and backward items for each virtual stage
196
+ for vs in range(max_virtual_stage + 1):
197
+ legend_items.append(dict(
198
+ name=f"Forward (VS {vs})",
199
+ color=get_color("forward", vs * num_devices)
200
+ ))
201
+ legend_items.append(dict(
202
+ name=f"Backward (VS {vs})",
203
+ color=get_color("backward", vs * num_devices)
204
+ ))
205
+
206
+ # If no tasks found, add default legend items
207
+ if not legend_items:
208
+ legend_items = [
209
+ dict(name="Forward (VS 0)", color=get_color("forward", 0)),
210
+ dict(name="Backward (VS 0)", color=get_color("backward", 0)),
211
+ ]
212
 
213
  for i, item in enumerate(legend_items):
214
  fig.add_trace(go.Scatter(
 
249
  font=dict(size=20)
250
  ),
251
  legend=dict(
252
+ orientation="v", # Changed from horizontal to vertical
253
  yanchor="top",
254
+ y=1.02, # Position at the top
255
+ xanchor="right",
256
+ x=1.15, # Position to the right of the plot
257
+ title=dict(text="<b>Operation Types:</b>"),
258
+ itemsizing="constant",
259
+ tracegroupgap=0
260
  ),
261
+ width=1800, # Increase width to accommodate the legend
262
+ height=400, # Maintain current height
263
  bargap=0,
264
  bargroupgap=0,
265
  )
 
328
  def load_graph(_):
329
  # Create the figure when the app loads
330
  return create_pipeline_figure(schedule_data, show_progress=True)
331
+
332
  @app.callback(
333
  Output("download-image", "data"),
334
  Input("btn-download", "n_clicks"),
 
369
  app = create_dash_app(schedule)
370
  print(f"Starting Dash app on http://localhost:{port}/")
371
  app.run_server(debug=debug, port=port)