Victarry committed
Commit 86eaa70 · 1 Parent(s): bb52925

Update implementation for 1F1B overlapping.

Files changed (5)
  1. README.md +1 -1
  2. main.py +2 -1
  3. src/execution_model.py +78 -16
  4. src/strategies.py +20 -22
  5. src/visualizer.py +15 -1
README.md CHANGED
@@ -52,7 +52,7 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
 
 
 Running for 1F1B-batch-overlap strategy:
-```bah
+```bash
 uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
 ```
 ![1f1b_overlap](assets/1f1b_overlap.png)
main.py CHANGED
@@ -105,7 +105,8 @@ def run_1f1b_overlap(cfg: DictConfig) -> None:
     )
     schedule = generate_1f1b_overlap_schedule(schedule_config)
     schedule.execute()
-    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+    schedule.show()
+    # visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
 
 if __name__ == "__main__":
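
A minimal standalone sketch of the updated flow, assuming `ScheduleConfig` accepts the same keyword arguments as the CLI flags above (the real entry point reads them from the Hydra config):

```python
# Sketch only: build, time, and print the 1F1B-overlap schedule.
from src.execution_model import ScheduleConfig
from src.strategies import generate_1f1b_overlap_schedule

schedule_config = ScheduleConfig(num_devices=4, num_stages=4, num_batches=8)  # assumed kwargs
schedule = generate_1f1b_overlap_schedule(schedule_config)
schedule.execute()

# Prints the per-device table (Batch | Stage | Type | Start | End | Duration)
# instead of launching the Dash app.
schedule.show()
```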
src/execution_model.py CHANGED
@@ -13,7 +13,52 @@ class Operation:
 
         self.start_time = None
         self.end_time = None
+
+    def set_end_time(self, end_time: float):
+        self.end_time = end_time
+
+    def set_start_time(self, start_time: float):
+        self.start_time = start_time
+
+    def __repr__(self) -> str:
+        return f"Operation(batch_id={self.batch_id}, stage_id={self.stage_id}, op_type={self.op_type})"
 
+class OverlappedOperation:
+    """Represents multiple operations that are overlapped/executed concurrently."""
+
+    def __init__(self, operations: List[Operation]):
+        self.operations = operations
+        self.device_id = operations[0].device_id
+
+        # Validate all operations are on the same device
+        for op in operations:
+            assert op.device_id == self.device_id, "All operations must be on the same device"
+
+        # Create a combined op_type (e.g., "overlapped_forward_backward")
+        self.op_type = "overlapped_" + "_".join([op.op_type for op in operations])
+
+        # Use the batch_id and stage_id of the first operation for identification
+        # (though we'll track all operations internally)
+        self.batch_id = operations[0].batch_id
+        self.stage_id = operations[0].stage_id
+
+        # Initialize timing information
+        self.start_time = None
+        self.end_time = None
+
+    def set_end_time(self, end_time: float):
+        self.end_time = end_time
+        for op in self.operations:
+            op.set_end_time(end_time)
+
+    def set_start_time(self, start_time: float):
+        self.start_time = start_time
+        for op in self.operations:
+            op.set_start_time(start_time)
+
+    def __repr__(self) -> str:
+        op_str = ", ".join([f"({op.batch_id},{op.stage_id},{op.op_type})" for op in self.operations])
+        return f"OverlappedOperation([{op_str}])"
 
 class DeviceQueue:
     def __init__(self, stages: List[int], device_id: int):
@@ -45,6 +90,7 @@ class ScheduleConfig:
         self.p2p_latency = p2p_latency
         self.placement_strategy = placement_strategy
         self.split_backward = split_backward
+        self.overlapped_op_times = {}
 
         # Initialize default operation times
         if self.split_backward:
@@ -104,9 +150,20 @@ class ScheduleConfig:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
 
     def get_op_time(self, op_type: str, stage_id: int):
+        # For overlapped operations, extract the original operation types
+        if op_type.startswith("overlapped_"):
+            op_parts = op_type.split("_")[1:]
+            if len(op_parts) >= 2:
+                op_type1, op_type2 = op_parts[0], op_parts[1]
+                # Check if we have a specific time for this combination
+                if (op_type1, op_type2) in self.overlapped_op_times:
+                    return self.overlapped_op_times[(op_type1, op_type2)]
+                # Otherwise, use the sum of individual times
+                return (self.get_op_time(op_type1, stage_id) +
+                        self.get_op_time(op_type2, stage_id))
+
         if op_type not in self.op_times:
             raise ValueError(f"Invalid operation type: {op_type}")
-
         times = self.op_times[op_type]
         if isinstance(times, dict):
             # If we have stage-specific times, use those
@@ -121,9 +178,9 @@
 class Schedule:
     def __init__(self, config: ScheduleConfig):
         self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
-        self.dev_queues: List[DeviceQueue] = []
+        self.device_queues: List[DeviceQueue] = []
         for dev_id in range(config.num_devices):
-            self.dev_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
+            self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
         self.config = config
 
         self.init_operations()
@@ -142,7 +199,7 @@ class Schedule:
     def get_op(self, batch_id: int, stage_id: int, op_type: str):
         return self.ops[(batch_id, stage_id, op_type)]
 
-    def get_dependencies(self, op: Operation):
+    def get_dependencies(self, op: Operation, include_device_dependency=True):
         deps = []
         if op.op_type == "forward":
             if op.stage_id > 0:
@@ -179,9 +236,10 @@
                 )
             )
 
-        device_index = self.dev_queues[op.device_id].ops.index(op)
-        if device_index > 0:
-            deps.append((self.dev_queues[op.device_id].ops[device_index - 1], 0.0))
+        if include_device_dependency:
+            device_index = self.device_queues[op.device_id].ops.index(op)
+            if device_index > 0:
+                deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
         return deps
 
     def show(self):
@@ -192,12 +250,12 @@
         print("\n=== DEVICE QUEUES ===")
 
         for dev_id in range(self.config.num_devices):
-            print(f"\nDEVICE {dev_id} (Stages: {self.dev_queues[dev_id].stages}):")
+            print(f"\nDEVICE {dev_id} (Stages: {self.device_queues[dev_id].stages}):")
             print("-" * 80)
             print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
             print("-" * 80)
 
-            for op in self.dev_queues[dev_id].ops:
+            for op in self.device_queues[dev_id].ops:
                 op_type = op.op_type
                 start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
                 end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"
@@ -207,7 +265,7 @@
                     duration = f"{op.end_time - op.start_time:.2f}"
 
                 print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")
-
+
         # Find the total execution time (if timing info is available)
         if all(op.end_time is not None for op in self.ops.values()):
             total_time = max(op.end_time for op in self.ops.values())
@@ -215,22 +273,26 @@
 
     def execute(self):
         def execute_op(op: Operation):
+            if op.end_time is not None:
+                return
             deps = self.get_dependencies(op)
             if len(deps) == 0:
-                op.start_time = 0.0
+                op.set_start_time(0.0)
             else:
                 for dep, gap in deps:
                     if dep.end_time is None or dep.start_time is None:
                         execute_op(dep)
-                op.start_time = max(dep.end_time + gap for dep, gap in deps)
-            op.end_time = op.start_time + self.config.get_op_time(
+                op.set_start_time(max(dep.end_time + gap for dep, gap in deps))
+            op.set_end_time(op.start_time + self.config.get_op_time(
                 op.op_type, op.stage_id
-            )
+            ))
 
-        op_num = len(self.dev_queues[0].ops)
+        op_num = len(self.device_queues[0].ops)
         for i in range(op_num):
             for dev_id in range(self.config.num_devices):
-                op = self.dev_queues[dev_id].ops[i]
+                if len(self.device_queues[dev_id].ops) <= i:
+                    continue
+                op = self.device_queues[dev_id].ops[i]
                 execute_op(op)
 
         for op in self.ops.values():
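
A small sketch of how the new `OverlappedOperation` and the extended `get_op_time` interact; the `ScheduleConfig` keyword arguments are assumed, everything else uses the classes and methods added above:

```python
# Sketch only: wrap a forward and a backward op from the same device into one
# overlapped slot and resolve its duration (ScheduleConfig kwargs are assumed).
from src.execution_model import OverlappedOperation, Schedule, ScheduleConfig

config = ScheduleConfig(num_devices=4, num_stages=4, num_batches=8)  # assumed kwargs
schedule = Schedule(config)

fwd = schedule.get_op(2, 0, "forward")   # forward of a later microbatch on stage 0
bwd = schedule.get_op(1, 0, "backward")  # backward of an earlier microbatch on stage 0
overlapped = OverlappedOperation([fwd, bwd])
print(overlapped.op_type)                # "overlapped_forward_backward"

# Duration defaults to forward time + backward time; an explicit pair time
# registered in overlapped_op_times takes precedence.
config.overlapped_op_times[("forward", "backward")] = 1.5
duration = config.get_op_time(overlapped.op_type, stage_id=0)

# Timing set on the wrapper is mirrored onto both wrapped operations.
overlapped.set_start_time(0.0)
overlapped.set_end_time(duration)
print(fwd.end_time, bwd.end_time)        # both equal duration
```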
src/strategies.py CHANGED
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from src.execution_model import Schedule, ScheduleConfig
+from src.execution_model import OverlappedOperation, Schedule, ScheduleConfig
 
 
 def generate_1f1b_schedule(config: ScheduleConfig):
@@ -14,23 +14,23 @@ def generate_1f1b_schedule(config: ScheduleConfig):
         steady_batches = config.num_batches - warmup_batches
 
         for _ in range(warmup_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(fwd_batch_id, i, "forward")
             )
             fwd_batch_id += 1
 
         for _ in range(steady_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(fwd_batch_id, i, "forward")
             )
             fwd_batch_id += 1
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_batch_id, i, "backward")
             )
             bwd_batch_id += 1
 
         for _ in range(cooldown_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_batch_id, i, "backward")
             )
             bwd_batch_id += 1
@@ -53,20 +53,20 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
         steady_batches = total_batches - warmup_batches
 
         for _ in range(warmup_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(fwd_batch_id, i, "forward")
             )
             fwd_batch_id += 1
 
         for _ in range(steady_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(fwd_batch_id, i, "forward")
             )
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_d_batch_id, i, "backward_D")
             )
             if fwd_batch_id - bwd_w_batch_id >= config.num_devices - 1:
-                schedule.dev_queues[i].add_operation(
+                schedule.device_queues[i].add_operation(
                     schedule.get_op(bwd_w_batch_id, i, "backward_W")
                 )
                 bwd_w_batch_id += 1
@@ -74,11 +74,11 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
             fwd_batch_id += 1
 
         for _ in range(cooldown_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_d_batch_id, i, "backward_D")
             )
 
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_w_batch_id, i, "backward_W")
             )
 
@@ -86,7 +86,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
             bwd_d_batch_id += 1
 
         while bwd_w_batch_id < total_batches:
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_w_batch_id, i, "backward_W")
             )
             bwd_w_batch_id += 1
@@ -106,23 +106,21 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
         steady_batches = config.num_batches - warmup_batches
 
         for _ in range(warmup_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(fwd_batch_id, i, "forward")
             )
             fwd_batch_id += 1
 
         for _ in range(steady_batches):
-            schedule.dev_queues[i].add_operation(
-                schedule.get_op(fwd_batch_id, i, "forward")
-            )
+            fwd_op = schedule.get_op(fwd_batch_id, i, "forward")
+            bwd_op = schedule.get_op(bwd_batch_id, i, "backward")
+            schedule.device_queues[i].add_operation(OverlappedOperation([fwd_op, bwd_op]))
+
             fwd_batch_id += 1
-            schedule.dev_queues[i].add_operation(
-                schedule.get_op(bwd_batch_id, i, "backward")
-            )
             bwd_batch_id += 1
 
         for _ in range(cooldown_batches):
-            schedule.dev_queues[i].add_operation(
+            schedule.device_queues[i].add_operation(
                 schedule.get_op(bwd_batch_id, i, "backward")
             )
             bwd_batch_id += 1
@@ -264,7 +262,7 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
            cur_stage_microbatch_id[i] = 0
            cur_stage_microbatch_id[-i] = 0
         for order_item in order:
-            stage_id = schedule.dev_queues[device_id].stages[abs(order_item)-1]
+            stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
 
             if order_item > 0:
                 op_type = "forward"
@@ -276,7 +274,7 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
             else:
                 raise ValueError(f"Invalid order item: {order_item}")
-            schedule.dev_queues[device_id].add_operation(
+            schedule.device_queues[device_id].add_operation(
                 schedule.get_op(micro_batch_id, stage_id, op_type)
             )
     return schedule
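
After this change, the steady-state slots produced by `generate_1f1b_overlap_schedule` are single `OverlappedOperation` entries rather than separate forward/backward operations, which can be checked by inspecting a device queue (again with assumed `ScheduleConfig` kwargs):

```python
# Sketch only: list device 0's queue; warmup/cooldown print plain Operation entries,
# the steady state prints OverlappedOperation([...]) entries.
from src.execution_model import ScheduleConfig
from src.strategies import generate_1f1b_overlap_schedule

config = ScheduleConfig(num_devices=4, num_stages=4, num_batches=8)  # assumed kwargs
schedule = generate_1f1b_overlap_schedule(config)

for op in schedule.device_queues[0].ops:
    print(op)
```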
src/visualizer.py CHANGED
@@ -5,6 +5,8 @@ import plotly.graph_objects as go
 from typing import List, Dict
 from tqdm import tqdm
 from functools import lru_cache
+import webbrowser
+from threading import Timer
 
 from src.execution_model import Schedule
 
@@ -26,7 +28,7 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
     visualization_data = {}
 
     # Organize operations by device
-    for device_id, device_queue in enumerate(schedule.dev_queues):
+    for device_id, device_queue in enumerate(schedule.device_queues):
         visualization_data[device_id] = []
 
         for op in device_queue.ops:
@@ -494,6 +496,7 @@ def visualize_pipeline_parallelism_dash(
     debug: bool = False,
     enable_caching: bool = True,
     schedule_type="1f1b",
+    open_browser: bool = True,
 ):
     """
     Launch a Dash app to visualize the pipeline schedule interactively.
@@ -504,9 +507,20 @@
         debug: Whether to run the Dash app in debug mode
         enable_caching: Whether to cache schedule data and figures
         schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
+        open_browser: Whether to automatically open a browser window
     """
     app = create_dash_app(
         schedule, schedule_type=schedule_type, enable_caching=enable_caching
     )
+
+    # Define function to open browser after a short delay
+    def open_browser_tab():
+        webbrowser.open_new_tab(f"http://localhost:{port}/")
+
+    # Open browser automatically if requested
+    if open_browser:
+        # Use a timer to open the browser after the server has started
+        Timer(1.0, open_browser_tab).start()
+
     print(f"Starting Dash app on http://localhost:{port}/")
     app.run_server(debug=debug, port=port)
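
The Dash visualization, now commented out in `main.py`, can still be launched directly; the new `open_browser` flag controls the automatic browser tab (the `ScheduleConfig` kwargs are assumed as before):

```python
# Sketch only: serve the interactive timeline without auto-opening a browser tab.
from src.execution_model import ScheduleConfig
from src.strategies import generate_1f1b_overlap_schedule
from src.visualizer import visualize_pipeline_parallelism_dash

config = ScheduleConfig(num_devices=4, num_stages=4, num_batches=8)  # assumed kwargs
schedule = generate_1f1b_overlap_schedule(config)
schedule.execute()

visualize_pipeline_parallelism_dash(schedule, port=8050, open_browser=False)
```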