Spaces:
Running
Running
Update.
Browse files- pipeline.py +10 -10
pipeline.py
CHANGED
@@ -1,11 +1,8 @@
|
|
1 |
-
import matplotlib.pyplot as plt
|
2 |
-
import numpy as np
|
3 |
import argparse
|
4 |
import json
|
5 |
import yaml
|
6 |
import os
|
7 |
-
from
|
8 |
-
from typing import List, Tuple, Dict, Literal
|
9 |
|
10 |
# Import visualization function from the new module
|
11 |
from visualizer import visualize_pipeline_parallelism
|
@@ -205,7 +202,7 @@ def calculate_operation_timing(
|
|
205 |
return schedule
|
206 |
|
207 |
|
208 |
-
def
|
209 |
num_stages = len(schedule)
|
210 |
|
211 |
max_time = 0
|
@@ -215,7 +212,6 @@ def get_bubble_rate(schedule: Dict[int, List[Dict]]):
|
|
215 |
if end_time > max_time:
|
216 |
max_time = end_time
|
217 |
|
218 |
-
print(f"Max time: {max_time}")
|
219 |
total_execution_time = max_time * num_stages
|
220 |
|
221 |
total_computation_time = 0
|
@@ -231,7 +227,11 @@ def get_bubble_rate(schedule: Dict[int, List[Dict]]):
|
|
231 |
bubble_rate = (
|
232 |
total_execution_time - total_computation_time
|
233 |
) / total_computation_time
|
234 |
-
|
|
|
|
|
|
|
|
|
235 |
|
236 |
|
237 |
def read_config_file(config_path):
|
@@ -476,12 +476,12 @@ def main():
|
|
476 |
)
|
477 |
|
478 |
# Analyze the schedule
|
479 |
-
|
480 |
-
print(
|
481 |
|
482 |
return {
|
483 |
"schedule": schedule,
|
484 |
-
"
|
485 |
"num_stages": num_stages,
|
486 |
"num_batches": num_batches,
|
487 |
}
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import json
|
3 |
import yaml
|
4 |
import os
|
5 |
+
from typing import List, Dict
|
|
|
6 |
|
7 |
# Import visualization function from the new module
|
8 |
from visualizer import visualize_pipeline_parallelism
|
|
|
202 |
return schedule
|
203 |
|
204 |
|
205 |
+
def get_schedule_info(schedule: Dict[int, List[Dict]]):
|
206 |
num_stages = len(schedule)
|
207 |
|
208 |
max_time = 0
|
|
|
212 |
if end_time > max_time:
|
213 |
max_time = end_time
|
214 |
|
|
|
215 |
total_execution_time = max_time * num_stages
|
216 |
|
217 |
total_computation_time = 0
|
|
|
227 |
bubble_rate = (
|
228 |
total_execution_time - total_computation_time
|
229 |
) / total_computation_time
|
230 |
+
|
231 |
+
return {
|
232 |
+
"bubble_rate": f"{bubble_rate*100:.2f}%",
|
233 |
+
"execution_time": f"{max_time / 1000:.2f} s",
|
234 |
+
}
|
235 |
|
236 |
|
237 |
def read_config_file(config_path):
|
|
|
476 |
)
|
477 |
|
478 |
# Analyze the schedule
|
479 |
+
schedule_info = get_schedule_info(schedule)
|
480 |
+
print(schedule_info)
|
481 |
|
482 |
return {
|
483 |
"schedule": schedule,
|
484 |
+
"schedule_info": schedule_info,
|
485 |
"num_stages": num_stages,
|
486 |
"num_batches": num_batches,
|
487 |
}
|