Spaces:
Running
Running
Add visualization for 1F1B overlap.
Browse files
- main.py +1 -2
- src/execution_model.py +20 -3
- src/strategies.py +3 -1
- src/visualizer.py +100 -13
main.py
CHANGED
@@ -105,8 +105,7 @@ def run_1f1b_overlap(cfg: DictConfig) -> None:
|
|
105 |
)
|
106 |
schedule = generate_1f1b_overlap_schedule(schedule_config)
|
107 |
schedule.execute()
|
108 |
-
schedule.
|
109 |
-
# visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
110 |
|
111 |
|
112 |
if __name__ == "__main__":
|
|
|
105 |
)
|
106 |
schedule = generate_1f1b_overlap_schedule(schedule_config)
|
107 |
schedule.execute()
|
108 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
|
|
109 |
|
110 |
|
111 |
if __name__ == "__main__":
|
src/execution_model.py
CHANGED
@@ -158,9 +158,8 @@ class ScheduleConfig:
|
|
158 |
# Check if we have a specific time for this combination
|
159 |
if (op_type1, op_type2) in self.overlapped_op_times:
|
160 |
return self.overlapped_op_times[(op_type1, op_type2)]
|
161 |
-
# Otherwise, use the
|
162 |
-
return (self.get_op_time(op_type1, stage_id) +
|
163 |
-
self.get_op_time(op_type2, stage_id))
|
164 |
|
165 |
if op_type not in self.op_times:
|
166 |
raise ValueError(f"Invalid operation type: {op_type}")
|
@@ -184,6 +183,12 @@ class Schedule:
|
|
184 |
self.config = config
|
185 |
|
186 |
self.init_operations()
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
def init_operations(self):
|
189 |
op_types = ["forward", "backward"]
|
@@ -197,10 +202,21 @@ class Schedule:
|
|
197 |
)
|
198 |
|
199 |
def get_op(self, batch_id: int, stage_id: int, op_type: str):
|
|
|
|
|
200 |
return self.ops[(batch_id, stage_id, op_type)]
|
201 |
|
202 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
203 |
deps = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
if op.op_type == "forward":
|
205 |
if op.stage_id > 0:
|
206 |
deps.append(
|
@@ -272,6 +288,7 @@ class Schedule:
|
|
272 |
print(f"\nTotal execution time: {total_time:.2f}")
|
273 |
|
274 |
def execute(self):
|
|
|
275 |
def execute_op(op: Operation):
|
276 |
if op.end_time is not None:
|
277 |
return
|
|
|
158 |
# Check if we have a specific time for this combination
|
159 |
if (op_type1, op_type2) in self.overlapped_op_times:
|
160 |
return self.overlapped_op_times[(op_type1, op_type2)]
|
161 |
+
# Otherwise, use the max of individual times plus a small overhead
|
162 |
+
return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id)) + 0.2
|
|
|
163 |
|
164 |
if op_type not in self.op_times:
|
165 |
raise ValueError(f"Invalid operation type: {op_type}")
|
|
|
183 |
self.config = config
|
184 |
|
185 |
self.init_operations()
|
186 |
+
self.op_to_overlapped = {}
|
187 |
+
|
188 |
+
def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
|
189 |
+
for op in overlapped_op.operations:
|
190 |
+
self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
191 |
+
self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
192 |
|
193 |
def init_operations(self):
|
194 |
op_types = ["forward", "backward"]
|
|
|
202 |
)
|
203 |
|
204 |
def get_op(self, batch_id: int, stage_id: int, op_type: str):
|
205 |
+
if (batch_id, stage_id, op_type) in self.op_to_overlapped:
|
206 |
+
return self.op_to_overlapped[(batch_id, stage_id, op_type)]
|
207 |
return self.ops[(batch_id, stage_id, op_type)]
|
208 |
|
209 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
210 |
deps = []
|
211 |
+
if isinstance(op, OverlappedOperation):
|
212 |
+
for sub_op in op.operations:
|
213 |
+
deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
|
214 |
+
|
215 |
+
if include_device_dependency:
|
216 |
+
device_index = self.device_queues[op.device_id].ops.index(op)
|
217 |
+
if device_index > 0:
|
218 |
+
deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
|
219 |
+
return deps
|
220 |
if op.op_type == "forward":
|
221 |
if op.stage_id > 0:
|
222 |
deps.append(
|
|
|
288 |
print(f"\nTotal execution time: {total_time:.2f}")
|
289 |
|
290 |
def execute(self):
|
291 |
+
# TODO: change the execution order to topological order via DAG
|
292 |
def execute_op(op: Operation):
|
293 |
if op.end_time is not None:
|
294 |
return
|
src/strategies.py
CHANGED
@@ -114,7 +114,9 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
|
|
114 |
for _ in range(steady_batches):
|
115 |
fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
|
116 |
bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
|
117 |
-
|
|
|
|
|
118 |
|
119 |
fwd_batch_id += 1
|
120 |
bwd_batch_id += 1
|
|
|
114 |
for _ in range(steady_batches):
|
115 |
fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
|
116 |
bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
|
117 |
+
overlapped_op = OverlappedOperation([fwd_op, bwd_op])
|
118 |
+
schedule.register_overlapped_operation(overlapped_op)
|
119 |
+
schedule.device_queues[i].add_operation(overlapped_op)
|
120 |
|
121 |
fwd_batch_id += 1
|
122 |
bwd_batch_id += 1
|
src/visualizer.py
CHANGED
@@ -8,7 +8,7 @@ from functools import lru_cache
|
|
8 |
import webbrowser
|
9 |
from threading import Timer
|
10 |
|
11 |
-
from src.execution_model import Schedule
|
12 |
|
13 |
|
14 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
@@ -32,15 +32,37 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
32 |
visualization_data[device_id] = []
|
33 |
|
34 |
for op in device_queue.ops:
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
return visualization_data
|
46 |
|
@@ -103,13 +125,30 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
103 |
"#99cc99", # Pale green
|
104 |
"#c6e6c6", # Pastel green
|
105 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
virtual_stage = stage_id // num_devices
|
108 |
|
109 |
# If virtual_stage is beyond our color list, cycle through the colors
|
110 |
color_index = virtual_stage % len(forward_colors)
|
111 |
|
112 |
-
|
|
|
|
|
|
|
113 |
return forward_colors[color_index]
|
114 |
elif op_type == "backward":
|
115 |
return backward_colors[color_index]
|
@@ -191,6 +230,14 @@ def create_pipeline_figure(
|
|
191 |
color = get_color(task["type"], task["stage"], num_devices)
|
192 |
text_color = "black"
|
193 |
name = "Backward (Weight)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
else:
|
195 |
color = empty_color
|
196 |
text_color = "black"
|
@@ -222,14 +269,34 @@ def create_pipeline_figure(
|
|
222 |
dict(
|
223 |
x=start_time + duration / 2,
|
224 |
y=y_pos,
|
225 |
-
text=f"{task['batch']}",
|
226 |
showarrow=False,
|
227 |
font=dict(color=text_color, size=12, family="Arial, bold"),
|
228 |
)
|
229 |
)
|
230 |
|
231 |
# Prepare hover data (add traces in batches later)
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
hover_traces.append(
|
235 |
dict(
|
@@ -268,6 +335,13 @@ def create_pipeline_figure(
|
|
268 |
virtual_stage = task["stage"] // num_devices
|
269 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
# Add forward and backward items for each virtual stage
|
272 |
for vs in range(max_virtual_stage + 1):
|
273 |
legend_items.append(
|
@@ -300,6 +374,15 @@ def create_pipeline_figure(
|
|
300 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
301 |
)
|
302 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
# If no tasks found, add default legend items
|
305 |
if not legend_items:
|
@@ -314,6 +397,10 @@ def create_pipeline_figure(
|
|
314 |
name="Backward Weight (VS 0)",
|
315 |
color=get_color("backward_W", 0, num_devices),
|
316 |
),
|
|
|
|
|
|
|
|
|
317 |
]
|
318 |
|
319 |
for i, item in enumerate(legend_items):
|
|
|
8 |
import webbrowser
|
9 |
from threading import Timer
|
10 |
|
11 |
+
from src.execution_model import Schedule, OverlappedOperation
|
12 |
|
13 |
|
14 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
|
32 |
visualization_data[device_id] = []
|
33 |
|
34 |
for op in device_queue.ops:
|
35 |
+
# Handle both regular Operations and OverlappedOperations
|
36 |
+
if isinstance(op, OverlappedOperation):
|
37 |
+
visualization_data[device_id].append(
|
38 |
+
{
|
39 |
+
"type": op.op_type,
|
40 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
41 |
+
"stage": op.stage_id,
|
42 |
+
"start_time": op.start_time,
|
43 |
+
"duration": op.end_time - op.start_time,
|
44 |
+
"is_overlapped": True,
|
45 |
+
"operations": [
|
46 |
+
{
|
47 |
+
"type": nested_op.op_type,
|
48 |
+
"batch": nested_op.batch_id + 1,
|
49 |
+
"stage": nested_op.stage_id
|
50 |
+
}
|
51 |
+
for nested_op in op.operations
|
52 |
+
]
|
53 |
+
}
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
visualization_data[device_id].append(
|
57 |
+
{
|
58 |
+
"type": op.op_type,
|
59 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
60 |
+
"stage": op.stage_id,
|
61 |
+
"start_time": op.start_time,
|
62 |
+
"duration": op.end_time - op.start_time,
|
63 |
+
"is_overlapped": False
|
64 |
+
}
|
65 |
+
)
|
66 |
|
67 |
return visualization_data
|
68 |
|
|
|
125 |
"#99cc99", # Pale green
|
126 |
"#c6e6c6", # Pastel green
|
127 |
]
|
128 |
+
|
129 |
+
# Purple palette for overlapped operations
|
130 |
+
overlapped_colors = [
|
131 |
+
"#9966cc", # Amethyst
|
132 |
+
"#8a2be2", # Blue violet
|
133 |
+
"#9370db", # Medium purple
|
134 |
+
"#6a5acd", # Slate blue
|
135 |
+
"#7b68ee", # Medium slate blue
|
136 |
+
"#ba55d3", # Medium orchid
|
137 |
+
"#9932cc", # Dark orchid
|
138 |
+
"#d8bfd8", # Thistle
|
139 |
+
"#e6e6fa", # Lavender
|
140 |
+
"#dda0dd", # Plum
|
141 |
+
]
|
142 |
|
143 |
virtual_stage = stage_id // num_devices
|
144 |
|
145 |
# If virtual_stage is beyond our color list, cycle through the colors
|
146 |
color_index = virtual_stage % len(forward_colors)
|
147 |
|
148 |
+
# Handle overlapped operations
|
149 |
+
if op_type.startswith("overlapped_"):
|
150 |
+
return overlapped_colors[color_index]
|
151 |
+
elif op_type == "forward":
|
152 |
return forward_colors[color_index]
|
153 |
elif op_type == "backward":
|
154 |
return backward_colors[color_index]
|
|
|
230 |
color = get_color(task["type"], task["stage"], num_devices)
|
231 |
text_color = "black"
|
232 |
name = "Backward (Weight)"
|
233 |
+
elif task["type"].startswith("overlapped_"):
|
234 |
+
color = get_color(task["type"], task["stage"], num_devices)
|
235 |
+
text_color = "white"
|
236 |
+
name = "Overlapped"
|
237 |
+
# Create a more descriptive name for the hover text
|
238 |
+
if "is_overlapped" in task and task["is_overlapped"]:
|
239 |
+
op_types = [op["type"] for op in task["operations"]]
|
240 |
+
name = f"Overlapped ({', '.join(op_types)})"
|
241 |
else:
|
242 |
color = empty_color
|
243 |
text_color = "black"
|
|
|
269 |
dict(
|
270 |
x=start_time + duration / 2,
|
271 |
y=y_pos,
|
272 |
+
text=f"{task['batch']}" + ("*" if task.get("is_overlapped", False) else ""),
|
273 |
showarrow=False,
|
274 |
font=dict(color=text_color, size=12, family="Arial, bold"),
|
275 |
)
|
276 |
)
|
277 |
|
278 |
# Prepare hover data (add traces in batches later)
|
279 |
+
if task.get("is_overlapped", False):
|
280 |
+
# Enhanced hover text for overlapped operations
|
281 |
+
op_details = "<br>".join([
|
282 |
+
f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
|
283 |
+
for op in task["operations"]
|
284 |
+
])
|
285 |
+
hover_text = (
|
286 |
+
f"Overlapped Operations:<br>{op_details}<br>"
|
287 |
+
f"Start: {task['start_time']:.2f}<br>"
|
288 |
+
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
289 |
+
f"Duration: {task['duration']:.2f}"
|
290 |
+
)
|
291 |
+
else:
|
292 |
+
hover_text = (
|
293 |
+
f"Batch: {task['batch']}<br>"
|
294 |
+
f"Stage: {task['stage']}<br>"
|
295 |
+
f"Type: {name}<br>"
|
296 |
+
f"Start: {task['start_time']:.2f}<br>"
|
297 |
+
f"End: {task['start_time'] + task['duration']:.2f}<br>"
|
298 |
+
f"Duration: {task['duration']:.2f}"
|
299 |
+
)
|
300 |
|
301 |
hover_traces.append(
|
302 |
dict(
|
|
|
335 |
virtual_stage = task["stage"] // num_devices
|
336 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
337 |
|
338 |
+
# Check if overlapped operations exist
|
339 |
+
has_overlapped = any(
|
340 |
+
task.get("is_overlapped", False)
|
341 |
+
for device in schedule_data
|
342 |
+
for task in schedule_data[device]
|
343 |
+
)
|
344 |
+
|
345 |
# Add forward and backward items for each virtual stage
|
346 |
for vs in range(max_virtual_stage + 1):
|
347 |
legend_items.append(
|
|
|
374 |
color=get_color("backward_W", vs * num_devices, num_devices),
|
375 |
)
|
376 |
)
|
377 |
+
|
378 |
+
# Add entry for overlapped operations if they exist
|
379 |
+
if has_overlapped:
|
380 |
+
legend_items.append(
|
381 |
+
dict(
|
382 |
+
name=f"Overlapped (VS {vs})",
|
383 |
+
color=get_color("overlapped_", vs * num_devices, num_devices),
|
384 |
+
)
|
385 |
+
)
|
386 |
|
387 |
# If no tasks found, add default legend items
|
388 |
if not legend_items:
|
|
|
397 |
name="Backward Weight (VS 0)",
|
398 |
color=get_color("backward_W", 0, num_devices),
|
399 |
),
|
400 |
+
dict(
|
401 |
+
name="Overlapped (VS 0)",
|
402 |
+
color=get_color("overlapped_", 0, num_devices),
|
403 |
+
),
|
404 |
]
|
405 |
|
406 |
for i, item in enumerate(legend_items):
|