Victarry commited on
Commit
8854c6a
·
1 Parent(s): 7b84c26

Change figure style to ZBPP.

Browse files
Files changed (2) hide show
  1. pipeline_1f1b.png +2 -2
  2. visualizer.py +58 -22
pipeline_1f1b.png CHANGED

Git LFS Details

  • SHA256: 51fb99dc001443186b446c023848fe8e98362dfc4628e9be853f2959c2eedd33
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB

Git LFS Details

  • SHA256: ff047349dfa8f855aca47e233be6a5b12b45441c7f45bbe69509d0602dc1a127
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
visualizer.py CHANGED
@@ -15,17 +15,18 @@ def visualize_pipeline_parallelism(
15
  Args:
16
  schedule: Dictionary mapping device IDs to lists of tasks.
17
  Each task is a dictionary with keys:
18
- - 'type': 'forward' or 'backward'
19
  - 'batch': batch number
20
  - 'start_time': start time of the task
21
  - 'duration': duration of the task
22
  schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
23
  output_file: Path to save the visualization
24
  """
25
- # Colors for forward and backward passes
26
  forward_color = "royalblue"
27
- backward_color = "lightgreen"
28
- empty_color = "lightgray"
 
29
 
30
  # Find the number of stages (devices)
31
  num_stages = len(schedule)
@@ -39,63 +40,98 @@ def visualize_pipeline_parallelism(
39
  max_time = end_time
40
 
41
  # Create figure and axis
42
- fig, ax = plt.subplots(figsize=(15, 5))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Plot the schedule
45
  for device_idx, device in enumerate(schedule):
46
  device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
47
  for task in schedule[device]:
48
- color = forward_color if task["type"] == "forward" else backward_color
 
 
 
 
 
 
 
 
 
 
49
  rect = Rectangle(
50
  (task["start_time"], device_idx_reversed),
51
  task["duration"],
52
- 1.0, # Use full height to completely remove gaps
53
  edgecolor="black",
54
  facecolor=color,
55
- alpha=0.8,
56
  )
57
  ax.add_patch(rect)
58
 
59
  # Add text (batch number)
60
  ax.text(
61
  task["start_time"] + task["duration"] / 2,
62
- device_idx_reversed + 0.5, # Center text in the middle of full-height rectangle
63
  str(task["batch"]),
64
  ha="center",
65
  va="center",
66
  fontsize=10,
67
  fontweight="bold",
68
- color="white" if task["type"] == "forward" else "black",
69
  )
70
 
71
  # Set axis limits and labels
72
- ax.set_xlim(0, max_time * 1.05)
73
- ax.set_ylim(-0.05, num_stages + 0.05) # Keep the same tight padding
74
- ax.set_yticks(np.arange(num_stages) + 0.5) # Center ticks in the middle of each stage
 
75
  # Reverse the order: Device 1 at the top, highest number at the bottom
76
  device_labels = [f"Device {i+1}" for i in range(num_stages)]
77
  device_labels.reverse() # Reverse to put Device 1 at the top
78
  ax.set_yticklabels(device_labels)
79
- ax.set_xlabel("Time")
80
- ax.set_title(f"Pipeline Parallelism Schedule ({schedule_type})")
 
 
 
 
 
 
 
81
 
82
  # Remove the outer frame/border
83
  for spine in ax.spines.values():
84
  spine.set_visible(False)
85
 
86
- # Add a legend
87
  forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
88
  backward_patch = Rectangle((0, 0), 1, 1, facecolor=backward_color)
89
- ax.legend(
90
- [forward_patch, backward_patch],
91
- ["Forward Pass", "Backward Pass"],
 
 
92
  loc="upper center",
93
  bbox_to_anchor=(0.5, -0.15),
94
- ncol=2,
 
95
  )
96
 
97
- # Add grid
98
- ax.grid(True, linestyle="--", alpha=0.7)
99
 
100
  # Save the figure
101
  plt.tight_layout()
 
15
  Args:
16
  schedule: Dictionary mapping device IDs to lists of tasks.
17
  Each task is a dictionary with keys:
18
+ - 'type': 'forward', 'backward', or 'optimizer'
19
  - 'batch': batch number
20
  - 'start_time': start time of the task
21
  - 'duration': duration of the task
22
  schedule_type: Type of scheduling algorithm used ("simple" or "1f1b")
23
  output_file: Path to save the visualization
24
  """
25
+ # Colors for task types
26
  forward_color = "royalblue"
27
+ backward_color = "sandybrown" # Changed to match the reference image
28
+ optimizer_color = "#FFEFCF" # Light beige for optimizer steps
29
+ empty_color = "whitesmoke" # Very light gray for empty cells
30
 
31
  # Find the number of stages (devices)
32
  num_stages = len(schedule)
 
40
  max_time = end_time
41
 
42
  # Create figure and axis
43
+ fig, ax = plt.subplots(figsize=(15, 4))
44
+
45
+ # Create an empty grid with light gray color
46
+ for device_idx in range(num_stages):
47
+ device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
48
+ for t in range(int(max_time) + 1):
49
+ rect = Rectangle(
50
+ (t, device_idx_reversed),
51
+ 1.0,
52
+ 1.0,
53
+ edgecolor="lightgray",
54
+ facecolor=empty_color,
55
+ linewidth=0.5,
56
+ )
57
+ ax.add_patch(rect)
58
 
59
  # Plot the schedule
60
  for device_idx, device in enumerate(schedule):
61
  device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
62
  for task in schedule[device]:
63
+ # Determine task color
64
+ if task["type"] == "forward":
65
+ color = forward_color
66
+ text_color = "white"
67
+ elif task["type"] == "backward":
68
+ color = backward_color
69
+ text_color = "black"
70
+ else: # optimizer or any other type
71
+ color = optimizer_color
72
+ text_color = "black"
73
+
74
  rect = Rectangle(
75
  (task["start_time"], device_idx_reversed),
76
  task["duration"],
77
+ 1.0,
78
  edgecolor="black",
79
  facecolor=color,
80
+ linewidth=0.5,
81
  )
82
  ax.add_patch(rect)
83
 
84
  # Add text (batch number)
85
  ax.text(
86
  task["start_time"] + task["duration"] / 2,
87
+ device_idx_reversed + 0.5,
88
  str(task["batch"]),
89
  ha="center",
90
  va="center",
91
  fontsize=10,
92
  fontweight="bold",
93
+ color=text_color,
94
  )
95
 
96
  # Set axis limits and labels
97
+ ax.set_xlim(0, max_time + 0.5)
98
+ ax.set_ylim(-0.5, num_stages + 0.5)
99
+ ax.set_yticks(np.arange(num_stages) + 0.5)
100
+
101
  # Reverse the order: Device 1 at the top, highest number at the bottom
102
  device_labels = [f"Device {i+1}" for i in range(num_stages)]
103
  device_labels.reverse() # Reverse to put Device 1 at the top
104
  ax.set_yticklabels(device_labels)
105
+
106
+ # Add "Time" label and arrow at the bottom
107
+ arrow_y = -0.4
108
+ ax.text(0.5, arrow_y, "Time", ha="right", va="center", fontsize=10)
109
+ ax.annotate("", xy=(2, arrow_y), xytext=(1, arrow_y),
110
+ arrowprops=dict(arrowstyle="->", lw=1))
111
+
112
+ # Remove the x-axis ticks
113
+ ax.set_xticks([])
114
 
115
  # Remove the outer frame/border
116
  for spine in ax.spines.values():
117
  spine.set_visible(False)
118
 
119
+ # Add a legend - using 3 parts like in the reference image
120
  forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
121
  backward_patch = Rectangle((0, 0), 1, 1, facecolor=backward_color)
122
+ optimizer_patch = Rectangle((0, 0), 1, 1, facecolor=optimizer_color)
123
+
124
+ legend = ax.legend(
125
+ [forward_patch, backward_patch, optimizer_patch],
126
+ ["Forward", "Backward", "Optimizer step"],
127
  loc="upper center",
128
  bbox_to_anchor=(0.5, -0.15),
129
+ ncol=3,
130
+ frameon=False,
131
  )
132
 
133
+ # Turn off grid
134
+ ax.grid(False)
135
 
136
  # Save the figure
137
  plt.tight_layout()