Victarry commited on
Commit
370fc5b
·
1 Parent(s): e178784
Files changed (1) hide show
  1. pipeline.py +10 -10
pipeline.py CHANGED
@@ -1,11 +1,8 @@
1
- import matplotlib.pyplot as plt
2
- import numpy as np
3
  import argparse
4
  import json
5
  import yaml
6
  import os
7
- from matplotlib.patches import Rectangle
8
- from typing import List, Tuple, Dict, Literal
9
 
10
  # Import visualization function from the new module
11
  from visualizer import visualize_pipeline_parallelism
@@ -205,7 +202,7 @@ def calculate_operation_timing(
205
  return schedule
206
 
207
 
208
- def get_bubble_rate(schedule: Dict[int, List[Dict]]):
209
  num_stages = len(schedule)
210
 
211
  max_time = 0
@@ -215,7 +212,6 @@ def get_bubble_rate(schedule: Dict[int, List[Dict]]):
215
  if end_time > max_time:
216
  max_time = end_time
217
 
218
- print(f"Max time: {max_time}")
219
  total_execution_time = max_time * num_stages
220
 
221
  total_computation_time = 0
@@ -231,7 +227,11 @@ def get_bubble_rate(schedule: Dict[int, List[Dict]]):
231
  bubble_rate = (
232
  total_execution_time - total_computation_time
233
  ) / total_computation_time
234
- return bubble_rate
 
 
 
 
235
 
236
 
237
  def read_config_file(config_path):
@@ -476,12 +476,12 @@ def main():
476
  )
477
 
478
  # Analyze the schedule
479
- bubble_rate = get_bubble_rate(schedule)
480
- print(f"Bubble rate: {bubble_rate:.4f}")
481
 
482
  return {
483
  "schedule": schedule,
484
- "bubble_rate": bubble_rate,
485
  "num_stages": num_stages,
486
  "num_batches": num_batches,
487
  }
 
 
 
1
  import argparse
2
  import json
3
  import yaml
4
  import os
5
+ from typing import List, Dict
 
6
 
7
  # Import visualization function from the new module
8
  from visualizer import visualize_pipeline_parallelism
 
202
  return schedule
203
 
204
 
205
+ def get_schedule_info(schedule: Dict[int, List[Dict]]):
206
  num_stages = len(schedule)
207
 
208
  max_time = 0
 
212
  if end_time > max_time:
213
  max_time = end_time
214
 
 
215
  total_execution_time = max_time * num_stages
216
 
217
  total_computation_time = 0
 
227
  bubble_rate = (
228
  total_execution_time - total_computation_time
229
  ) / total_computation_time
230
+
231
+ return {
232
+ "bubble_rate": f"{bubble_rate*100:.2f}%",
233
+ "execution_time": f"{max_time / 1000:.2f} s",
234
+ }
235
 
236
 
237
  def read_config_file(config_path):
 
476
  )
477
 
478
  # Analyze the schedule
479
+ schedule_info = get_schedule_info(schedule)
480
+ print(schedule_info)
481
 
482
  return {
483
  "schedule": schedule,
484
+ "schedule_info": schedule_info,
485
  "num_stages": num_stages,
486
  "num_batches": num_batches,
487
  }