Victarry committed on
Commit
2623d17
·
1 Parent(s): dc262c1

Add formula for DualPipe

Browse files
Files changed (1) hide show
  1. formula.py +82 -0
formula.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PP schedule config
2
+ from src.execution_model import ScheduleConfig
3
+ from src.strategies import generate_dualpipe_v_schedule
4
+
5
+
6
# Pipeline shape.
p = 4   # PP size
v = 2   # number of virtual stages
m = 10  # total microbatches

# Durations of each operation in one PP rank, covering all stages on that rank.
F = 2.0    # forward
W = 2.0    # backward_W
D = 2.0    # backward_D
B = W + D  # full backward is the sum of its backward_W and backward_D parts
FwB = 6    # overlapped forward + backward

# The same durations keyed by operation name.
op_times = dict(
    forward=F,
    backward=B,
    backward_D=D,
    backward_W=W,
    overlapped_forward_backward=FwB,
)
24
+
25
def dualpipe_v_execution_time_by_formula():
    """Evaluate the three-term closed-form DualPipe-V time estimate.

    Reads the module-level settings ``p``, ``m``, ``F``, ``B`` and ``FwB``,
    prints the three intermediate terms, and returns their sum.

    The original comment attributes the formula to an image that is not part
    of this file, so the meaning of each term is not documented here.
    """
    # Both the second and third terms use the same (p / 2 + 1) factor.
    half_p_plus_one = p / 2 + 1

    first_term = ((p - 1) / 2) * F
    second_term = (p + 0.5) * F + half_p_plus_one * B
    third_term = (m - half_p_plus_one) * FwB
    print(f"item_1: {first_term}, item_2: {second_term}, item_3: {third_term}")

    return first_term + second_term + third_term
33
+
34
def dualpipe_v_execution_time_by_formula_detailed():
    """Compute the DualPipe-V total time from a phase-by-phase breakdown.

    Reads the module-level settings ``p``, ``m``, ``F``, ``B``, ``D``, ``W``
    and ``FwB``. Each duration is halved first, the same halving the
    emulation applies to build its per-stage operation times.

    The total is the sum of five phases: forward bubble, forward, overlapped
    forward/backward, backward, and a remaining "other" term. As a sanity
    check, the bubble (total minus active time) is compared against the
    closed form ``(p - 1) * (FwB' + B' - 3 * W')``.

    NOTE: expanding the expressions below with ``B = W + D`` gives
    ``total - active = (p - 1) * (F' + FwB' - 2 * W')``, which equals the
    closed form only when ``F == D`` (or ``p == 1``). The check is therefore
    expected to fail for configurations where forward and backward_D times
    differ.

    Returns:
        The total execution time (sum of the five phase durations).

    Raises:
        AssertionError: if the derived bubble time does not match the
            closed-form bubble expression.
    """
    # Function-scope import keeps this block self-contained.
    import math

    local_F = F / 2
    local_B = B / 2
    local_D = D / 2
    local_W = W / 2
    local_FwB = FwB / 2

    forward_bubble = (p - 1) * local_F  # forward bubble
    forward_time = 2 * p * local_F
    overlapped_time = (2 * (m - p) - 1) * local_FwB + (p - 1) * local_FwB
    backward_time = (2 * p - 1) * local_D + local_W
    other_time = 2 * local_B + local_F

    active_time = (2 * (m - p) - 1) * local_FwB + (2 * p + 1) * (local_F + local_B)
    total_time = forward_bubble + forward_time + overlapped_time + backward_time + other_time
    bubble_time = total_time - active_time

    expected_bubble = (p - 1) * (local_FwB + local_B - 3 * local_W)
    # Compare with a tolerance: exact float equality can fail on rounding
    # alone when the configured times are not exactly representable.
    assert math.isclose(bubble_time, expected_bubble, rel_tol=1e-9, abs_tol=1e-12), (
        f"bubble time {bubble_time} does not match closed form {expected_bubble}"
    )

    return total_time
55
+
56
def dualpipe_v_execution_time_by_emulate():
    """Emulate a DualPipe-V schedule and return its total execution time.

    Halves each module-level operation duration to get the per-stage times,
    prints them, builds a ``ScheduleConfig`` with ``p`` devices, ``p * 2``
    stages and ``m`` batches, generates the schedule, executes it, and
    returns the value reported by ``get_total_execution_time()``.
    """
    rank_level_times = (
        ("forward", F),
        ("backward", B),
        ("backward_D", D),
        ("backward_W", W),
        ("overlapped_forward_backward", FwB),
    )
    # Every operation's duration is halved to get its per-stage value.
    op_times_per_stage = {
        op_name: rank_time / 2 for op_name, rank_time in rank_level_times
    }
    print(f"op_times_per_stage: {op_times_per_stage}")

    config = ScheduleConfig(
        num_devices=p,
        num_stages=p * 2,
        num_batches=m,
        p2p_latency=0.0,
        op_times=op_times_per_stage,
        split_backward=True,
        placement_strategy="dualpipe_v",
    )

    schedule = generate_dualpipe_v_schedule(config)
    schedule.execute()
    return schedule.get_total_execution_time()
80
+
81
# Report the emulated total first, then the detailed-formula total.
emulated_total = dualpipe_v_execution_time_by_emulate()
print(f"DualPipe-V by emulate: {emulated_total}")

detailed_formula_total = dualpipe_v_execution_time_by_formula_detailed()
print(f"DualPipe-V by formula detailed: {detailed_formula_total}")