Victarry committed on
Commit
2ae9b28
·
1 Parent(s): 869d773

Refactor schedule execution model and simplify execution flow

Browse files
Files changed (3) hide show
  1. conf/config.yaml +1 -1
  2. main.py +3 -6
  3. src/execution_model.py +10 -12
conf/config.yaml CHANGED
@@ -1,7 +1,7 @@
1
  # Default configuration for Pipeline Parallelism Emulation
2
  num_devices: 4
3
  num_stages: 4
4
- num_batches: 12
5
  visualization_port: 8050
6
  strategy: "1f1b" # Options: "1f1b", "interleave"
7
  p2p_latency: 0.0
 
1
  # Default configuration for Pipeline Parallelism Emulation
2
  num_devices: 4
3
  num_stages: 4
4
+ num_batches: 8
5
  visualization_port: 8050
6
  strategy: "1f1b" # Options: "1f1b", "interleave"
7
  p2p_latency: 0.0
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from src.execution_model import ScheduleConfig, ScheduleExecutor
2
  from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
3
  from src.visualizer import visualize_pipeline_parallelism_dash
4
  import hydra
@@ -32,8 +32,7 @@ def run_1f1b(cfg: DictConfig) -> None:
32
  placement_strategy="standard"
33
  )
34
  schedule = generate_1f1b_schedule(schedule_config)
35
- executor = ScheduleExecutor(schedule)
36
- executor.execute()
37
 
38
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
39
 
@@ -52,9 +51,7 @@ def run_interleave(cfg: DictConfig) -> None:
52
  op_times=op_times
53
  )
54
  schedule = generate_1f1b_interleave_schedule(schedule_config)
55
- executor = ScheduleExecutor(schedule)
56
- executor.execute()
57
-
58
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
59
 
60
 
 
1
+ from src.execution_model import ScheduleConfig
2
  from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
3
  from src.visualizer import visualize_pipeline_parallelism_dash
4
  import hydra
 
32
  placement_strategy="standard"
33
  )
34
  schedule = generate_1f1b_schedule(schedule_config)
35
+ schedule.execute()
 
36
 
37
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
38
 
 
51
  op_times=op_times
52
  )
53
  schedule = generate_1f1b_interleave_schedule(schedule_config)
54
+ schedule.execute()
 
 
55
  visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
56
 
57
 
src/execution_model.py CHANGED
@@ -184,15 +184,10 @@ class Schedule:
184
  if all(op.end_time is not None for op in self.ops.values()):
185
  total_time = max(op.end_time for op in self.ops.values())
186
  print(f"\nTotal execution time: {total_time:.2f}")
187
-
188
-
189
- class ScheduleExecutor:
190
- def __init__(self, schedule: Schedule):
191
- self.schedule = schedule
192
-
193
  def execute(self):
194
  def execute_op(op: Operation):
195
- deps = self.schedule.get_dependencies(op)
196
  if len(deps) == 0:
197
  op.start_time = 0.0
198
  else:
@@ -200,20 +195,23 @@ class ScheduleExecutor:
200
  if dep.end_time is None or dep.start_time is None:
201
  execute_op(dep)
202
  op.start_time = max(dep.end_time + gap for dep, gap in deps)
203
- op.end_time = op.start_time + self.schedule.config.get_op_time(
204
  op.op_type, op.stage_id
205
  )
206
 
207
- op_num = len(self.schedule.dev_queues[0].ops)
208
  for i in range(op_num):
209
- for dev_id in range(self.schedule.config.num_devices):
210
- op = self.schedule.dev_queues[dev_id].ops[i]
211
  execute_op(op)
212
 
213
- for op in self.schedule.ops.values():
214
  assert (
215
  op.start_time is not None
216
  ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
217
  assert (
218
  op.end_time is not None
219
  ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
 
 
 
 
184
  if all(op.end_time is not None for op in self.ops.values()):
185
  total_time = max(op.end_time for op in self.ops.values())
186
  print(f"\nTotal execution time: {total_time:.2f}")
187
+
 
 
 
 
 
188
  def execute(self):
189
  def execute_op(op: Operation):
190
+ deps = self.get_dependencies(op)
191
  if len(deps) == 0:
192
  op.start_time = 0.0
193
  else:
 
195
  if dep.end_time is None or dep.start_time is None:
196
  execute_op(dep)
197
  op.start_time = max(dep.end_time + gap for dep, gap in deps)
198
+ op.end_time = op.start_time + self.config.get_op_time(
199
  op.op_type, op.stage_id
200
  )
201
 
202
+ op_num = len(self.dev_queues[0].ops)
203
  for i in range(op_num):
204
+ for dev_id in range(self.config.num_devices):
205
+ op = self.dev_queues[dev_id].ops[i]
206
  execute_op(op)
207
 
208
+ for op in self.ops.values():
209
  assert (
210
  op.start_time is not None
211
  ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
212
  assert (
213
  op.end_time is not None
214
  ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
215
+
216
+ def get_total_execution_time(self):
217
+ return max(op.end_time for op in self.ops.values())