Victarry commited on
Commit
22a494c
·
1 Parent(s): 6688c1a

Update dualpipe implementaion

Browse files
Files changed (1) hide show
  1. src/strategies.py +14 -11
src/strategies.py CHANGED
@@ -589,17 +589,20 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
589
  # Step 4 (Main step): nF0B1F1B0
590
  step_4_count = half_num_chunks - num_devices + half_rank + 1
591
  for i in range(step_4_count):
592
- # if i == 0 and is_middle_rank:
593
- # Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
594
- # _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
595
- # _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
596
- # _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
597
- # else:
598
- # Overlap F0 and B1_D, then schedule W1
599
- _schedule_forward_backward_chunk(
600
- device_id, 0, 1, is_in_second_half
601
- ) # F0+B1
602
-
 
 
 
603
  # Overlap F1 and B0_D, then schedule W0
604
  _schedule_forward_backward_chunk(
605
  device_id, 1, 0, is_in_second_half
 
589
  # Step 4 (Main step): nF0B1F1B0
590
  step_4_count = half_num_chunks - num_devices + half_rank + 1
591
  for i in range(step_4_count):
592
+ if i == 0:
593
+ if is_middle_rank:
594
+ # Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
595
+ _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
596
+ _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
597
+ else:
598
+ # Overlap F0 and B1_D, then schedule W1
599
+ _schedule_forward_backward_chunk(
600
+ device_id, 0, 1, is_in_second_half
601
+ ) # F0+B1
602
+ else:
603
+ _schedule_forward_backward_chunk(
604
+ device_id, 0, 1, is_in_second_half
605
+ ) # F0+B1
606
  # Overlap F1 and B0_D, then schedule W0
607
  _schedule_forward_backward_chunk(
608
  device_id, 1, 0, is_in_second_half