Spaces:
Running
Running
Update visualizer.
Browse files- src/visualizer.py +68 -66
src/visualizer.py
CHANGED
@@ -70,56 +70,47 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
70 |
# Cache the color calculation as it's repeatedly called with the same parameters
|
71 |
@lru_cache(maxsize=128)
|
72 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
73 |
-
# A more harmonious blue palette with
|
74 |
forward_colors = [
|
75 |
-
"#
|
76 |
-
"#
|
77 |
-
"#
|
78 |
-
"#
|
79 |
-
"#
|
80 |
-
"#809fff", # Medium blue
|
81 |
-
"#99b3e6", # Pale blue
|
82 |
-
"#b3c6ff", # Light blue
|
83 |
]
|
84 |
|
85 |
-
# Orange palette for backward operations
|
86 |
backward_colors = [
|
87 |
-
"#
|
88 |
-
"#
|
89 |
-
"#
|
90 |
-
"#
|
91 |
-
"#ffd699", # Light amber
|
92 |
-
"#ffd6ad", # Pale orange
|
93 |
-
"#ffe0c2", # Very pale orange
|
94 |
-
"#fff0e0", # Lightest orange
|
95 |
]
|
96 |
|
97 |
-
# Improved teal/turquoise palette with
|
98 |
backward_d_colors = [
|
99 |
-
"#
|
100 |
-
"#
|
101 |
-
"#
|
102 |
-
"#
|
103 |
-
"#
|
|
|
|
|
|
|
104 |
"#008080", # Dark teal
|
105 |
-
"#00e6cc", # Turquoise
|
106 |
-
"#4ddbbd", # Aqua
|
107 |
-
"#80d4c8", # Pale teal
|
108 |
-
"#b3e6e0", # Ice
|
109 |
]
|
110 |
|
111 |
-
# Improved green palette with
|
112 |
backward_w_colors = [
|
113 |
-
"#
|
114 |
-
"#
|
115 |
-
"#
|
116 |
-
"#80ffbf", #
|
117 |
-
"#
|
118 |
-
"#
|
119 |
-
"#
|
120 |
-
"#
|
121 |
-
"#
|
122 |
-
"#c6e6c6", # Pastel green
|
123 |
]
|
124 |
|
125 |
virtual_stage = stage_id // num_devices
|
@@ -130,11 +121,11 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
130 |
if op_type == "forward":
|
131 |
return forward_colors[color_index]
|
132 |
elif op_type == "backward":
|
133 |
-
return backward_colors[color_index]
|
134 |
elif op_type == "backward_D":
|
135 |
-
return backward_d_colors[color_index]
|
136 |
elif op_type == "backward_W":
|
137 |
-
return backward_w_colors[color_index]
|
138 |
else:
|
139 |
raise ValueError(f"Invalid operation type: {op_type}")
|
140 |
|
@@ -163,6 +154,15 @@ def create_pipeline_figure(
|
|
163 |
end_time = task["start_time"] + task["duration"]
|
164 |
if end_time > max_time:
|
165 |
max_time = end_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
# Create a figure
|
168 |
fig = go.Figure()
|
@@ -251,22 +251,23 @@ def create_pipeline_figure(
|
|
251 |
)
|
252 |
)
|
253 |
|
254 |
-
# Add batch number text for this sub-operation
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
|
|
|
|
268 |
)
|
269 |
-
)
|
270 |
else:
|
271 |
# Regular (non-overlapped) operation
|
272 |
# Determine task color and text color
|
@@ -305,16 +306,17 @@ def create_pipeline_figure(
|
|
305 |
)
|
306 |
)
|
307 |
|
308 |
-
# Add batch number text
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
|
|
|
|
316 |
)
|
317 |
-
)
|
318 |
|
319 |
# Prepare hover data
|
320 |
hover_text = (
|
|
|
70 |
# Cache the color calculation as it's repeatedly called with the same parameters
|
71 |
@lru_cache(maxsize=128)
|
72 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
73 |
+
# A more harmonious blue palette with low saturation and high brightness
|
74 |
forward_colors = [
|
75 |
+
"#0a5aff", # Intense blue
|
76 |
+
"#4c88ff", # Blue (deeper)
|
77 |
+
"#7aa7ff", # Medium blue
|
78 |
+
"#a8c5ff", # Soft blue
|
79 |
+
"#d6e4ff", # Very light blue
|
|
|
|
|
|
|
80 |
]
|
81 |
|
82 |
+
# Orange palette for backward operations with low saturation and high brightness
|
83 |
backward_colors = [
|
84 |
+
"#f47b00", # Intense orange
|
85 |
+
"#ffa952", # Orange
|
86 |
+
"#ffc78e", # Light orange
|
87 |
+
"#ffe6cc", # Very light orange
|
|
|
|
|
|
|
|
|
88 |
]
|
89 |
|
90 |
+
# Improved teal/turquoise palette with low saturation and high brightness
|
91 |
backward_d_colors = [
|
92 |
+
"#ccffff", # Very light cyan
|
93 |
+
"#b3ffff", # Pale cyan
|
94 |
+
"#99ffff", # Light cyan
|
95 |
+
"#80ffff", # Cyan
|
96 |
+
"#66e6e6", # Soft teal
|
97 |
+
"#4dcccc", # Light teal
|
98 |
+
"#33b3b3", # Teal
|
99 |
+
"#009999", # Medium teal
|
100 |
"#008080", # Dark teal
|
|
|
|
|
|
|
|
|
101 |
]
|
102 |
|
103 |
+
# Improved green palette with low saturation and high brightness
|
104 |
backward_w_colors = [
|
105 |
+
"#ccffe6", # Very light mint
|
106 |
+
"#b3ffd9", # Pale mint
|
107 |
+
"#99ffcc", # Light mint
|
108 |
+
"#80ffbf", # Mint green
|
109 |
+
"#66e6a6", # Soft green
|
110 |
+
"#4dcc8c", # Light green
|
111 |
+
"#33b373", # Medium green
|
112 |
+
"#009959", # Forest green
|
113 |
+
"#008040", # Dark green
|
|
|
114 |
]
|
115 |
|
116 |
virtual_stage = stage_id // num_devices
|
|
|
121 |
if op_type == "forward":
|
122 |
return forward_colors[color_index]
|
123 |
elif op_type == "backward":
|
124 |
+
return backward_colors[color_index % len(backward_colors)]
|
125 |
elif op_type == "backward_D":
|
126 |
+
return backward_d_colors[color_index % len(backward_d_colors)]
|
127 |
elif op_type == "backward_W":
|
128 |
+
return backward_w_colors[color_index % len(backward_w_colors)]
|
129 |
else:
|
130 |
raise ValueError(f"Invalid operation type: {op_type}")
|
131 |
|
|
|
154 |
end_time = task["start_time"] + task["duration"]
|
155 |
if end_time > max_time:
|
156 |
max_time = end_time
|
157 |
+
|
158 |
+
# Determine maximum batch number to decide whether to show text labels
|
159 |
+
max_batch = 0
|
160 |
+
for device in schedule_data:
|
161 |
+
for task in schedule_data[device]:
|
162 |
+
max_batch = max(max_batch, task["batch"])
|
163 |
+
|
164 |
+
# Flag to determine whether to show text labels
|
165 |
+
show_text_labels = max_batch <= 16
|
166 |
|
167 |
# Create a figure
|
168 |
fig = go.Figure()
|
|
|
251 |
)
|
252 |
)
|
253 |
|
254 |
+
# Add batch number text for this sub-operation only if show_text_labels is True
|
255 |
+
if show_text_labels:
|
256 |
+
# Determine text color based on background color
|
257 |
+
if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
|
258 |
+
text_color = "black"
|
259 |
+
else:
|
260 |
+
text_color = "white"
|
261 |
+
|
262 |
+
annotations.append(
|
263 |
+
dict(
|
264 |
+
x=start_time + duration / 2,
|
265 |
+
y=sub_y_center,
|
266 |
+
text=f"{sub_op['batch']}",
|
267 |
+
showarrow=False,
|
268 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
269 |
+
)
|
270 |
)
|
|
|
271 |
else:
|
272 |
# Regular (non-overlapped) operation
|
273 |
# Determine task color and text color
|
|
|
306 |
)
|
307 |
)
|
308 |
|
309 |
+
# Add batch number text only if show_text_labels is True
|
310 |
+
if show_text_labels:
|
311 |
+
annotations.append(
|
312 |
+
dict(
|
313 |
+
x=start_time + duration / 2,
|
314 |
+
y=y_pos,
|
315 |
+
text=f"{task['batch']}",
|
316 |
+
showarrow=False,
|
317 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
318 |
+
)
|
319 |
)
|
|
|
320 |
|
321 |
# Prepare hover data
|
322 |
hover_text = (
|