Spaces:
Running
Running
Update visualizer.
Browse files- main.py +1 -1
- src/visualizer.py +67 -44
main.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from src.execution_model import ScheduleConfig, ScheduleExecutor
|
2 |
from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
|
3 |
-
from src.visualizer import visualize_pipeline_parallelism_dash
|
4 |
import hydra
|
5 |
from omegaconf import DictConfig, OmegaConf
|
6 |
|
|
|
1 |
from src.execution_model import ScheduleConfig, ScheduleExecutor
|
2 |
from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
|
3 |
+
from src.visualizer import visualize_pipeline_parallelism_dash
|
4 |
import hydra
|
5 |
from omegaconf import DictConfig, OmegaConf
|
6 |
|
src/visualizer.py
CHANGED
@@ -55,24 +55,42 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
55 |
empty_color = "whitesmoke"
|
56 |
# Colors for task types
|
57 |
def get_color(op_type: str, stage_id: int):
|
58 |
-
#
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
virtual_stage = stage_id // num_devices
|
63 |
|
|
|
|
|
|
|
64 |
if op_type == "forward":
|
65 |
-
|
66 |
-
return forward_base_color
|
67 |
-
else:
|
68 |
-
# Lighter shade for virtual_stage > 0
|
69 |
-
return "lightskyblue"
|
70 |
elif op_type == "backward":
|
71 |
-
|
72 |
-
return backward_base_color
|
73 |
-
else:
|
74 |
-
# Lighter shade for virtual_stage > 0
|
75 |
-
return "lightseagreen"
|
76 |
else:
|
77 |
raise ValueError(f"Invalid operation type: {op_type}")
|
78 |
|
@@ -165,10 +183,32 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
165 |
progress_bar.update(1)
|
166 |
|
167 |
# Add custom legend
|
168 |
-
legend_items = [
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
for i, item in enumerate(legend_items):
|
174 |
fig.add_trace(go.Scatter(
|
@@ -209,14 +249,17 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
209 |
font=dict(size=20)
|
210 |
),
|
211 |
legend=dict(
|
212 |
-
orientation="
|
213 |
yanchor="top",
|
214 |
-
y
|
215 |
-
xanchor="
|
216 |
-
x=
|
|
|
|
|
|
|
217 |
),
|
218 |
-
width=
|
219 |
-
height=400, #
|
220 |
bargap=0,
|
221 |
bargroupgap=0,
|
222 |
)
|
@@ -285,7 +328,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
|
|
285 |
def load_graph(_):
|
286 |
# Create the figure when the app loads
|
287 |
return create_pipeline_figure(schedule_data, show_progress=True)
|
288 |
-
|
289 |
@app.callback(
|
290 |
Output("download-image", "data"),
|
291 |
Input("btn-download", "n_clicks"),
|
@@ -326,23 +369,3 @@ def visualize_pipeline_parallelism_dash(
|
|
326 |
app = create_dash_app(schedule)
|
327 |
print(f"Starting Dash app on http://localhost:{port}/")
|
328 |
app.run_server(debug=debug, port=port)
|
329 |
-
|
330 |
-
|
331 |
-
def save_pipeline_visualization_plotly(
|
332 |
-
schedule: Schedule,
|
333 |
-
output_file: str = "pipeline_visualization_plotly.png",
|
334 |
-
):
|
335 |
-
"""
|
336 |
-
Save a static image of the pipeline schedule visualization.
|
337 |
-
|
338 |
-
Args:
|
339 |
-
schedule: Schedule object to visualize
|
340 |
-
output_file: Path to save the image to
|
341 |
-
"""
|
342 |
-
schedule_data = convert_schedule_to_visualization_format(schedule)
|
343 |
-
fig = create_pipeline_figure(schedule_data, show_progress=True)
|
344 |
-
|
345 |
-
print(f"Saving visualization to {output_file}...")
|
346 |
-
fig.write_image(output_file, width=1600, height=400, scale=2)
|
347 |
-
print(f"Visualization saved to {output_file}")
|
348 |
-
|
|
|
55 |
empty_color = "whitesmoke"
|
56 |
# Colors for task types
|
57 |
def get_color(op_type: str, stage_id: int):
|
58 |
+
# Color palettes for different virtual stages
|
59 |
+
forward_colors = [
|
60 |
+
"royalblue", # Stage 0
|
61 |
+
"lightskyblue", # Stage 1
|
62 |
+
"cornflowerblue", # Stage 2
|
63 |
+
"steelblue", # Stage 3
|
64 |
+
"dodgerblue", # Stage 4
|
65 |
+
"deepskyblue", # Stage 5
|
66 |
+
"mediumblue", # Stage 6
|
67 |
+
"mediumslateblue",# Stage 7
|
68 |
+
"slateblue", # Stage 8
|
69 |
+
"darkslateblue" # Stage 9
|
70 |
+
]
|
71 |
|
72 |
+
backward_colors = [
|
73 |
+
"lightgreen", # Stage 0
|
74 |
+
"mediumseagreen", # Stage 1
|
75 |
+
"seagreen", # Stage 2
|
76 |
+
"lightseagreen", # Stage 3
|
77 |
+
"mediumaquamarine", # Stage 4
|
78 |
+
"mediumspringgreen", # Stage 5
|
79 |
+
"springgreen", # Stage 6
|
80 |
+
"palegreen", # Stage 7
|
81 |
+
"limegreen", # Stage 8
|
82 |
+
"forestgreen" # Stage 9
|
83 |
+
]
|
84 |
+
|
85 |
virtual_stage = stage_id // num_devices
|
86 |
|
87 |
+
# If virtual_stage is beyond our color list, cycle through the colors
|
88 |
+
color_index = virtual_stage % len(forward_colors)
|
89 |
+
|
90 |
if op_type == "forward":
|
91 |
+
return forward_colors[color_index]
|
|
|
|
|
|
|
|
|
92 |
elif op_type == "backward":
|
93 |
+
return backward_colors[color_index]
|
|
|
|
|
|
|
|
|
94 |
else:
|
95 |
raise ValueError(f"Invalid operation type: {op_type}")
|
96 |
|
|
|
183 |
progress_bar.update(1)
|
184 |
|
185 |
# Add custom legend
|
186 |
+
legend_items = []
|
187 |
+
|
188 |
+
# Find the maximum virtual stage in the data
|
189 |
+
max_virtual_stage = 0
|
190 |
+
for device in schedule_data:
|
191 |
+
for task in schedule_data[device]:
|
192 |
+
virtual_stage = task["stage"] // num_devices
|
193 |
+
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
194 |
+
|
195 |
+
# Add forward and backward items for each virtual stage
|
196 |
+
for vs in range(max_virtual_stage + 1):
|
197 |
+
legend_items.append(dict(
|
198 |
+
name=f"Forward (VS {vs})",
|
199 |
+
color=get_color("forward", vs * num_devices)
|
200 |
+
))
|
201 |
+
legend_items.append(dict(
|
202 |
+
name=f"Backward (VS {vs})",
|
203 |
+
color=get_color("backward", vs * num_devices)
|
204 |
+
))
|
205 |
+
|
206 |
+
# If no tasks found, add default legend items
|
207 |
+
if not legend_items:
|
208 |
+
legend_items = [
|
209 |
+
dict(name="Forward (VS 0)", color=get_color("forward", 0)),
|
210 |
+
dict(name="Backward (VS 0)", color=get_color("backward", 0)),
|
211 |
+
]
|
212 |
|
213 |
for i, item in enumerate(legend_items):
|
214 |
fig.add_trace(go.Scatter(
|
|
|
249 |
font=dict(size=20)
|
250 |
),
|
251 |
legend=dict(
|
252 |
+
orientation="v", # Changed from horizontal to vertical
|
253 |
yanchor="top",
|
254 |
+
y=1.02, # Position at the top
|
255 |
+
xanchor="right",
|
256 |
+
x=1.15, # Position to the right of the plot
|
257 |
+
title=dict(text="<b>Operation Types:</b>"),
|
258 |
+
itemsizing="constant",
|
259 |
+
tracegroupgap=0
|
260 |
),
|
261 |
+
width=1800, # Increase width to accommodate the legend
|
262 |
+
height=400, # Maintain current height
|
263 |
bargap=0,
|
264 |
bargroupgap=0,
|
265 |
)
|
|
|
328 |
def load_graph(_):
|
329 |
# Create the figure when the app loads
|
330 |
return create_pipeline_figure(schedule_data, show_progress=True)
|
331 |
+
|
332 |
@app.callback(
|
333 |
Output("download-image", "data"),
|
334 |
Input("btn-download", "n_clicks"),
|
|
|
369 |
app = create_dash_app(schedule)
|
370 |
print(f"Starting Dash app on http://localhost:{port}/")
|
371 |
app.run_server(debug=debug, port=port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|