Victarry commited on
Commit
c684983
·
1 Parent(s): 22a494c

Add microbatch_group_size_per_vp_stage as configurable.

Browse files
Files changed (5) hide show
  1. README.md +2 -0
  2. conf/config.yaml +1 -0
  3. main.py +1 -0
  4. src/execution_model.py +5 -0
  5. src/strategies.py +1 -1
README.md CHANGED
@@ -72,6 +72,8 @@ uv run python main.py strategy=interleave num_devices=4 num_stages=8 num_batches
72
  ```
73
  ![interleave](assets/interleave_1f1b.png)
74
 
 
 
75
  ### Running for ZB-1P strategy:
76
  ```bash
77
  uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
 
72
  ```
73
  ![interleave](assets/interleave_1f1b.png)
74
 
75
+ You can optionally set `microbatch_group_size_per_vp_stage` to control how many microbatches are grouped per virtual-pipeline stage.
76
+
77
  ### Running for ZB-1P strategy:
78
  ```bash
79
  uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
conf/config.yaml CHANGED
@@ -5,6 +5,7 @@ num_batches: 8
5
  visualization_port: 8050
6
  strategy: "1f1b" # Options: "1f1b", "interleave"
7
  p2p_latency: 0.0
 
8
 
9
  # Operation time configurations
10
  op_times:
 
5
  visualization_port: 8050
6
  strategy: "1f1b" # Options: "1f1b", "interleave"
7
  p2p_latency: 0.0
8
+ microbatch_group_size_per_vp_stage: null
9
 
10
  # Operation time configurations
11
  op_times:
main.py CHANGED
@@ -71,6 +71,7 @@ def run_interleave(cfg: DictConfig) -> None:
71
  p2p_latency=cfg.p2p_latency,
72
  placement_strategy="interleave",
73
  op_times=op_times,
 
74
  )
75
  schedule = generate_1f1b_interleave_schedule(schedule_config)
76
  schedule.execute()
 
71
  p2p_latency=cfg.p2p_latency,
72
  placement_strategy="interleave",
73
  op_times=op_times,
74
+ microbatch_group_size_per_vp_stage=cfg.microbatch_group_size_per_vp_stage,
75
  )
76
  schedule = generate_1f1b_interleave_schedule(schedule_config)
77
  schedule.execute()
src/execution_model.py CHANGED
@@ -83,6 +83,7 @@ class ScheduleConfig:
83
  placement_strategy: str = "standard",
84
  split_backward: bool = False,
85
  op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
 
86
  ):
87
  self.num_devices = num_devices
88
  self.num_stages = num_stages
@@ -90,6 +91,10 @@ class ScheduleConfig:
90
  self.p2p_latency = p2p_latency
91
  self.placement_strategy = placement_strategy
92
  self.split_backward = split_backward
 
 
 
 
93
 
94
  # Initialize default operation times
95
  if self.split_backward:
 
83
  placement_strategy: str = "standard",
84
  split_backward: bool = False,
85
  op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
86
+ microbatch_group_size_per_vp_stage: Optional[int] = None,
87
  ):
88
  self.num_devices = num_devices
89
  self.num_stages = num_stages
 
91
  self.p2p_latency = p2p_latency
92
  self.placement_strategy = placement_strategy
93
  self.split_backward = split_backward
94
+ if microbatch_group_size_per_vp_stage is None:
95
+ self.microbatch_group_size_per_vp_stage = num_devices
96
+ else:
97
+ self.microbatch_group_size_per_vp_stage = microbatch_group_size_per_vp_stage
98
 
99
  # Initialize default operation times
100
  if self.split_backward:
src/strategies.py CHANGED
@@ -244,7 +244,7 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
244
  schedule = Schedule(config)
245
 
246
  for device_id in range(config.num_devices):
247
- microbatch_group_size_per_vp_stage = config.num_devices
248
  num_warmup_microbatches = _get_pp_rank_microbatches(
249
  config.num_batches,
250
  config.num_devices,
 
244
  schedule = Schedule(config)
245
 
246
  for device_id in range(config.num_devices):
247
+ microbatch_group_size_per_vp_stage = config.microbatch_group_size_per_vp_stage
248
  num_warmup_microbatches = _get_pp_rank_microbatches(
249
  config.num_batches,
250
  config.num_devices,