Victarry commited on
Commit
7b84c26
·
1 Parent(s): d4ea23e

Update visualization.

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. pipeline_1f1b.png +2 -2
  3. visualizer.py +14 -6
README.md CHANGED
@@ -4,7 +4,7 @@ This tool simulates and visualizes pipeline parallelism scheduling strategies, f
4
 
5
  ## Usage
6
 
7
- ### Example Output
8
 
9
  ```bash
10
  python pipeline.py --num-stages 4 --num-batches 8
 
4
 
5
  ## Usage
6
 
7
+ ### Example
8
 
9
  ```bash
10
  python pipeline.py --num-stages 4 --num-batches 8
pipeline_1f1b.png CHANGED

Git LFS Details

  • SHA256: 7d0431d9bb7235a04456f1753074d76752fd9b64c52c32c15f12c586840ea5fe
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB

Git LFS Details

  • SHA256: 51fb99dc001443186b446c023848fe8e98362dfc4628e9be853f2959c2eedd33
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB
visualizer.py CHANGED
@@ -43,12 +43,13 @@ def visualize_pipeline_parallelism(
43
 
44
  # Plot the schedule
45
  for device_idx, device in enumerate(schedule):
 
46
  for task in schedule[device]:
47
  color = forward_color if task["type"] == "forward" else backward_color
48
  rect = Rectangle(
49
- (task["start_time"], device_idx),
50
  task["duration"],
51
- 0.8,
52
  edgecolor="black",
53
  facecolor=color,
54
  alpha=0.8,
@@ -58,7 +59,7 @@ def visualize_pipeline_parallelism(
58
  # Add text (batch number)
59
  ax.text(
60
  task["start_time"] + task["duration"] / 2,
61
- device_idx + 0.4,
62
  str(task["batch"]),
63
  ha="center",
64
  va="center",
@@ -69,11 +70,18 @@ def visualize_pipeline_parallelism(
69
 
70
  # Set axis limits and labels
71
  ax.set_xlim(0, max_time * 1.05)
72
- ax.set_ylim(-0.2, num_stages + 0.2)
73
- ax.set_yticks(np.arange(num_stages) + 0.4)
74
- ax.set_yticklabels([f"Device {i+1}" for i in range(num_stages)])
 
 
 
75
  ax.set_xlabel("Time")
76
  ax.set_title(f"Pipeline Parallelism Schedule ({schedule_type})")
 
 
 
 
77
 
78
  # Add a legend
79
  forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
 
43
 
44
  # Plot the schedule
45
  for device_idx, device in enumerate(schedule):
46
+ device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
47
  for task in schedule[device]:
48
  color = forward_color if task["type"] == "forward" else backward_color
49
  rect = Rectangle(
50
+ (task["start_time"], device_idx_reversed),
51
  task["duration"],
52
+ 1.0, # Use full height to completely remove gaps
53
  edgecolor="black",
54
  facecolor=color,
55
  alpha=0.8,
 
59
  # Add text (batch number)
60
  ax.text(
61
  task["start_time"] + task["duration"] / 2,
62
+ device_idx_reversed + 0.5, # Center text in the middle of full-height rectangle
63
  str(task["batch"]),
64
  ha="center",
65
  va="center",
 
70
 
71
  # Set axis limits and labels
72
  ax.set_xlim(0, max_time * 1.05)
73
+ ax.set_ylim(-0.05, num_stages + 0.05) # Keep the same tight padding
74
+ ax.set_yticks(np.arange(num_stages) + 0.5) # Center ticks in the middle of each stage
75
+ # Reverse the order: Device 1 at the top, highest number at the bottom
76
+ device_labels = [f"Device {i+1}" for i in range(num_stages)]
77
+ device_labels.reverse() # Reverse to put Device 1 at the top
78
+ ax.set_yticklabels(device_labels)
79
  ax.set_xlabel("Time")
80
  ax.set_title(f"Pipeline Parallelism Schedule ({schedule_type})")
81
+
82
+ # Remove the outer frame/border
83
+ for spine in ax.spines.values():
84
+ spine.set_visible(False)
85
 
86
  # Add a legend
87
  forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)