Spaces:
Running
Running
Update visualization.
Browse files- README.md +1 -1
- pipeline_1f1b.png +2 -2
- visualizer.py +14 -6
README.md
CHANGED
@@ -4,7 +4,7 @@ This tool simulates and visualizes pipeline parallelism scheduling strategies, f
|
|
4 |
|
5 |
## Usage
|
6 |
|
7 |
-
### Example
|
8 |
|
9 |
```bash
|
10 |
python pipeline.py --num-stages 4 --num-batches 8
|
|
|
4 |
|
5 |
## Usage
|
6 |
|
7 |
+
### Example
|
8 |
|
9 |
```bash
|
10 |
python pipeline.py --num-stages 4 --num-batches 8
|
pipeline_1f1b.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
visualizer.py
CHANGED
@@ -43,12 +43,13 @@ def visualize_pipeline_parallelism(
|
|
43 |
|
44 |
# Plot the schedule
|
45 |
for device_idx, device in enumerate(schedule):
|
|
|
46 |
for task in schedule[device]:
|
47 |
color = forward_color if task["type"] == "forward" else backward_color
|
48 |
rect = Rectangle(
|
49 |
-
(task["start_time"],
|
50 |
task["duration"],
|
51 |
-
0
|
52 |
edgecolor="black",
|
53 |
facecolor=color,
|
54 |
alpha=0.8,
|
@@ -58,7 +59,7 @@ def visualize_pipeline_parallelism(
|
|
58 |
# Add text (batch number)
|
59 |
ax.text(
|
60 |
task["start_time"] + task["duration"] / 2,
|
61 |
-
|
62 |
str(task["batch"]),
|
63 |
ha="center",
|
64 |
va="center",
|
@@ -69,11 +70,18 @@ def visualize_pipeline_parallelism(
|
|
69 |
|
70 |
# Set axis limits and labels
|
71 |
ax.set_xlim(0, max_time * 1.05)
|
72 |
-
ax.set_ylim(-0.
|
73 |
-
ax.set_yticks(np.arange(num_stages) + 0.
|
74 |
-
|
|
|
|
|
|
|
75 |
ax.set_xlabel("Time")
|
76 |
ax.set_title(f"Pipeline Parallelism Schedule ({schedule_type})")
|
|
|
|
|
|
|
|
|
77 |
|
78 |
# Add a legend
|
79 |
forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
|
|
|
43 |
|
44 |
# Plot the schedule
|
45 |
for device_idx, device in enumerate(schedule):
|
46 |
+
device_idx_reversed = num_stages - device_idx - 1 # Reverse the device index for plotting
|
47 |
for task in schedule[device]:
|
48 |
color = forward_color if task["type"] == "forward" else backward_color
|
49 |
rect = Rectangle(
|
50 |
+
(task["start_time"], device_idx_reversed),
|
51 |
task["duration"],
|
52 |
+
1.0, # Use full height to completely remove gaps
|
53 |
edgecolor="black",
|
54 |
facecolor=color,
|
55 |
alpha=0.8,
|
|
|
59 |
# Add text (batch number)
|
60 |
ax.text(
|
61 |
task["start_time"] + task["duration"] / 2,
|
62 |
+
device_idx_reversed + 0.5, # Center text in the middle of full-height rectangle
|
63 |
str(task["batch"]),
|
64 |
ha="center",
|
65 |
va="center",
|
|
|
70 |
|
71 |
# Set axis limits and labels
|
72 |
ax.set_xlim(0, max_time * 1.05)
|
73 |
+
ax.set_ylim(-0.05, num_stages + 0.05) # Keep the same tight padding
|
74 |
+
ax.set_yticks(np.arange(num_stages) + 0.5) # Center ticks in the middle of each stage
|
75 |
+
# Reverse the order: Device 1 at the top, highest number at the bottom
|
76 |
+
device_labels = [f"Device {i+1}" for i in range(num_stages)]
|
77 |
+
device_labels.reverse() # Reverse to put Device 1 at the top
|
78 |
+
ax.set_yticklabels(device_labels)
|
79 |
ax.set_xlabel("Time")
|
80 |
ax.set_title(f"Pipeline Parallelism Schedule ({schedule_type})")
|
81 |
+
|
82 |
+
# Remove the outer frame/border
|
83 |
+
for spine in ax.spines.values():
|
84 |
+
spine.set_visible(False)
|
85 |
|
86 |
# Add a legend
|
87 |
forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
|