Victarry commited on
Commit
f140d7b
·
1 Parent(s): 86eaa70

Add visualization for 1F1B overlap.

Browse files
Files changed (4) hide show
  1. main.py +1 -2
  2. src/execution_model.py +20 -3
  3. src/strategies.py +3 -1
  4. src/visualizer.py +100 -13
main.py CHANGED
@@ -105,8 +105,7 @@ def run_1f1b_overlap(cfg: DictConfig) -> None:
105
  )
106
  schedule = generate_1f1b_overlap_schedule(schedule_config)
107
  schedule.execute()
108
- schedule.show()
109
- # visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
110
 
111
 
112
  if __name__ == "__main__":
 
105
  )
106
  schedule = generate_1f1b_overlap_schedule(schedule_config)
107
  schedule.execute()
108
+ visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
109
 
110
 
111
  if __name__ == "__main__":
src/execution_model.py CHANGED
@@ -158,9 +158,8 @@ class ScheduleConfig:
158
  # Check if we have a specific time for this combination
159
  if (op_type1, op_type2) in self.overlapped_op_times:
160
  return self.overlapped_op_times[(op_type1, op_type2)]
161
- # Otherwise, use the sum of individual times
162
- return (self.get_op_time(op_type1, stage_id) +
163
- self.get_op_time(op_type2, stage_id))
164
 
165
  if op_type not in self.op_times:
166
  raise ValueError(f"Invalid operation type: {op_type}")
@@ -184,6 +183,12 @@ class Schedule:
184
  self.config = config
185
 
186
  self.init_operations()
 
 
 
 
 
 
187
 
188
  def init_operations(self):
189
  op_types = ["forward", "backward"]
@@ -197,10 +202,21 @@ class Schedule:
197
  )
198
 
199
  def get_op(self, batch_id: int, stage_id: int, op_type: str):
 
 
200
  return self.ops[(batch_id, stage_id, op_type)]
201
 
202
  def get_dependencies(self, op: Operation, include_device_dependency=True):
203
  deps = []
 
 
 
 
 
 
 
 
 
204
  if op.op_type == "forward":
205
  if op.stage_id > 0:
206
  deps.append(
@@ -272,6 +288,7 @@ class Schedule:
272
  print(f"\nTotal execution time: {total_time:.2f}")
273
 
274
  def execute(self):
 
275
  def execute_op(op: Operation):
276
  if op.end_time is not None:
277
  return
 
158
  # Check if we have a specific time for this combination
159
  if (op_type1, op_type2) in self.overlapped_op_times:
160
  return self.overlapped_op_times[(op_type1, op_type2)]
161
+ # Otherwise, use the max of individual times plus a small overhead
162
+ return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id)) + 0.2
 
163
 
164
  if op_type not in self.op_times:
165
  raise ValueError(f"Invalid operation type: {op_type}")
 
183
  self.config = config
184
 
185
  self.init_operations()
186
+ self.op_to_overlapped = {}
187
+
188
+ def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
189
+ for op in overlapped_op.operations:
190
+ self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
191
+ self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
192
 
193
  def init_operations(self):
194
  op_types = ["forward", "backward"]
 
202
  )
203
 
204
  def get_op(self, batch_id: int, stage_id: int, op_type: str):
205
+ if (batch_id, stage_id, op_type) in self.op_to_overlapped:
206
+ return self.op_to_overlapped[(batch_id, stage_id, op_type)]
207
  return self.ops[(batch_id, stage_id, op_type)]
208
 
209
  def get_dependencies(self, op: Operation, include_device_dependency=True):
210
  deps = []
211
+ if isinstance(op, OverlappedOperation):
212
+ for sub_op in op.operations:
213
+ deps.extend(self.get_dependencies(sub_op, include_device_dependency=False))
214
+
215
+ if include_device_dependency:
216
+ device_index = self.device_queues[op.device_id].ops.index(op)
217
+ if device_index > 0:
218
+ deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
219
+ return deps
220
  if op.op_type == "forward":
221
  if op.stage_id > 0:
222
  deps.append(
 
288
  print(f"\nTotal execution time: {total_time:.2f}")
289
 
290
  def execute(self):
291
+ # TODO: change the execution order to topological order via DAG
292
  def execute_op(op: Operation):
293
  if op.end_time is not None:
294
  return
src/strategies.py CHANGED
@@ -114,7 +114,9 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
114
  for _ in range(steady_batches):
115
  fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
116
  bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
117
- schedule.device_queues[i].add_operation(OverlappedOperation([fwd_op, bwd_op]))
 
 
118
 
119
  fwd_batch_id += 1
120
  bwd_batch_id += 1
 
114
  for _ in range(steady_batches):
115
  fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
116
  bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
117
+ overlapped_op = OverlappedOperation([fwd_op, bwd_op])
118
+ schedule.register_overlapped_operation(overlapped_op)
119
+ schedule.device_queues[i].add_operation(overlapped_op)
120
 
121
  fwd_batch_id += 1
122
  bwd_batch_id += 1
src/visualizer.py CHANGED
@@ -8,7 +8,7 @@ from functools import lru_cache
8
  import webbrowser
9
  from threading import Timer
10
 
11
- from src.execution_model import Schedule
12
 
13
 
14
  def convert_schedule_to_visualization_format(schedule: Schedule):
@@ -32,15 +32,37 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
32
  visualization_data[device_id] = []
33
 
34
  for op in device_queue.ops:
35
- visualization_data[device_id].append(
36
- {
37
- "type": op.op_type,
38
- "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
39
- "stage": op.stage_id,
40
- "start_time": op.start_time,
41
- "duration": op.end_time - op.start_time,
42
- }
43
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  return visualization_data
46
 
@@ -103,13 +125,30 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
103
  "#99cc99", # Pale green
104
  "#c6e6c6", # Pastel green
105
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  virtual_stage = stage_id // num_devices
108
 
109
  # If virtual_stage is beyond our color list, cycle through the colors
110
  color_index = virtual_stage % len(forward_colors)
111
 
112
- if op_type == "forward":
 
 
 
113
  return forward_colors[color_index]
114
  elif op_type == "backward":
115
  return backward_colors[color_index]
@@ -191,6 +230,14 @@ def create_pipeline_figure(
191
  color = get_color(task["type"], task["stage"], num_devices)
192
  text_color = "black"
193
  name = "Backward (Weight)"
 
 
 
 
 
 
 
 
194
  else:
195
  color = empty_color
196
  text_color = "black"
@@ -222,14 +269,34 @@ def create_pipeline_figure(
222
  dict(
223
  x=start_time + duration / 2,
224
  y=y_pos,
225
- text=f"{task['batch']}",
226
  showarrow=False,
227
  font=dict(color=text_color, size=12, family="Arial, bold"),
228
  )
229
  )
230
 
231
  # Prepare hover data (add traces in batches later)
232
- hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  hover_traces.append(
235
  dict(
@@ -268,6 +335,13 @@ def create_pipeline_figure(
268
  virtual_stage = task["stage"] // num_devices
269
  max_virtual_stage = max(max_virtual_stage, virtual_stage)
270
 
 
 
 
 
 
 
 
271
  # Add forward and backward items for each virtual stage
272
  for vs in range(max_virtual_stage + 1):
273
  legend_items.append(
@@ -300,6 +374,15 @@ def create_pipeline_figure(
300
  color=get_color("backward_W", vs * num_devices, num_devices),
301
  )
302
  )
 
 
 
 
 
 
 
 
 
303
 
304
  # If no tasks found, add default legend items
305
  if not legend_items:
@@ -314,6 +397,10 @@ def create_pipeline_figure(
314
  name="Backward Weight (VS 0)",
315
  color=get_color("backward_W", 0, num_devices),
316
  ),
 
 
 
 
317
  ]
318
 
319
  for i, item in enumerate(legend_items):
 
8
  import webbrowser
9
  from threading import Timer
10
 
11
+ from src.execution_model import Schedule, OverlappedOperation
12
 
13
 
14
  def convert_schedule_to_visualization_format(schedule: Schedule):
 
32
  visualization_data[device_id] = []
33
 
34
  for op in device_queue.ops:
35
+ # Handle both regular Operations and OverlappedOperations
36
+ if isinstance(op, OverlappedOperation):
37
+ visualization_data[device_id].append(
38
+ {
39
+ "type": op.op_type,
40
+ "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
41
+ "stage": op.stage_id,
42
+ "start_time": op.start_time,
43
+ "duration": op.end_time - op.start_time,
44
+ "is_overlapped": True,
45
+ "operations": [
46
+ {
47
+ "type": nested_op.op_type,
48
+ "batch": nested_op.batch_id + 1,
49
+ "stage": nested_op.stage_id
50
+ }
51
+ for nested_op in op.operations
52
+ ]
53
+ }
54
+ )
55
+ else:
56
+ visualization_data[device_id].append(
57
+ {
58
+ "type": op.op_type,
59
+ "batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
60
+ "stage": op.stage_id,
61
+ "start_time": op.start_time,
62
+ "duration": op.end_time - op.start_time,
63
+ "is_overlapped": False
64
+ }
65
+ )
66
 
67
  return visualization_data
68
 
 
125
  "#99cc99", # Pale green
126
  "#c6e6c6", # Pastel green
127
  ]
128
+
129
+ # Purple palette for overlapped operations
130
+ overlapped_colors = [
131
+ "#9966cc", # Medium purple
132
+ "#8a2be2", # Blue violet
133
+ "#9370db", # Medium purple
134
+ "#6a5acd", # Slate blue
135
+ "#7b68ee", # Medium slate blue
136
+ "#ba55d3", # Medium orchid
137
+ "#9932cc", # Dark orchid
138
+ "#d8bfd8", # Thistle
139
+ "#e6e6fa", # Lavender
140
+ "#dda0dd", # Plum
141
+ ]
142
 
143
  virtual_stage = stage_id // num_devices
144
 
145
  # If virtual_stage is beyond our color list, cycle through the colors
146
  color_index = virtual_stage % len(forward_colors)
147
 
148
+ # Handle overlapped operations
149
+ if op_type.startswith("overlapped_"):
150
+ return overlapped_colors[color_index]
151
+ elif op_type == "forward":
152
  return forward_colors[color_index]
153
  elif op_type == "backward":
154
  return backward_colors[color_index]
 
230
  color = get_color(task["type"], task["stage"], num_devices)
231
  text_color = "black"
232
  name = "Backward (Weight)"
233
+ elif task["type"].startswith("overlapped_"):
234
+ color = get_color(task["type"], task["stage"], num_devices)
235
+ text_color = "white"
236
+ name = "Overlapped"
237
+ # Create a more descriptive name for the hover text
238
+ if "is_overlapped" in task and task["is_overlapped"]:
239
+ op_types = [op["type"] for op in task["operations"]]
240
+ name = f"Overlapped ({', '.join(op_types)})"
241
  else:
242
  color = empty_color
243
  text_color = "black"
 
269
  dict(
270
  x=start_time + duration / 2,
271
  y=y_pos,
272
+ text=f"{task['batch']}" + ("*" if task.get("is_overlapped", False) else ""),
273
  showarrow=False,
274
  font=dict(color=text_color, size=12, family="Arial, bold"),
275
  )
276
  )
277
 
278
  # Prepare hover data (add traces in batches later)
279
+ if task.get("is_overlapped", False):
280
+ # Enhanced hover text for overlapped operations
281
+ op_details = "<br>".join([
282
+ f"- {op['type']} (Batch {op['batch']}, Stage {op['stage']})"
283
+ for op in task["operations"]
284
+ ])
285
+ hover_text = (
286
+ f"Overlapped Operations:<br>{op_details}<br>"
287
+ f"Start: {task['start_time']:.2f}<br>"
288
+ f"End: {task['start_time'] + task['duration']:.2f}<br>"
289
+ f"Duration: {task['duration']:.2f}"
290
+ )
291
+ else:
292
+ hover_text = (
293
+ f"Batch: {task['batch']}<br>"
294
+ f"Stage: {task['stage']}<br>"
295
+ f"Type: {name}<br>"
296
+ f"Start: {task['start_time']:.2f}<br>"
297
+ f"End: {task['start_time'] + task['duration']:.2f}<br>"
298
+ f"Duration: {task['duration']:.2f}"
299
+ )
300
 
301
  hover_traces.append(
302
  dict(
 
335
  virtual_stage = task["stage"] // num_devices
336
  max_virtual_stage = max(max_virtual_stage, virtual_stage)
337
 
338
+ # Check if overlapped operations exist
339
+ has_overlapped = any(
340
+ task.get("is_overlapped", False)
341
+ for device in schedule_data
342
+ for task in schedule_data[device]
343
+ )
344
+
345
  # Add forward and backward items for each virtual stage
346
  for vs in range(max_virtual_stage + 1):
347
  legend_items.append(
 
374
  color=get_color("backward_W", vs * num_devices, num_devices),
375
  )
376
  )
377
+
378
+ # Add entry for overlapped operations if they exist
379
+ if has_overlapped:
380
+ legend_items.append(
381
+ dict(
382
+ name=f"Overlapped (VS {vs})",
383
+ color=get_color("overlapped_", vs * num_devices, num_devices),
384
+ )
385
+ )
386
 
387
  # If no tasks found, add default legend items
388
  if not legend_items:
 
397
  name="Backward Weight (VS 0)",
398
  color=get_color("backward_W", 0, num_devices),
399
  ),
400
+ dict(
401
+ name="Overlapped (VS 0)",
402
+ color=get_color("overlapped_", 0, num_devices),
403
+ ),
404
  ]
405
 
406
  for i, item in enumerate(legend_items):