Spaces:
Running
Running
Update overlapped time.
Browse files
- src/execution_model.py +14 -9
src/execution_model.py
CHANGED
@@ -90,7 +90,6 @@ class ScheduleConfig:
|
|
90 |
self.p2p_latency = p2p_latency
|
91 |
self.placement_strategy = placement_strategy
|
92 |
self.split_backward = split_backward
|
93 |
-
self.overlapped_op_times = {}
|
94 |
|
95 |
# Initialize default operation times
|
96 |
if self.split_backward:
|
@@ -152,14 +151,13 @@ class ScheduleConfig:
|
|
152 |
def get_op_time(self, op_type: str, stage_id: int):
|
153 |
# For overlapped operations, extract the original operation types
|
154 |
if op_type.startswith("overlapped_"):
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
if (
|
160 |
-
|
161 |
-
|
162 |
-
return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id))
|
163 |
|
164 |
if op_type not in self.op_times:
|
165 |
raise ValueError(f"Invalid operation type: {op_type}")
|
@@ -332,3 +330,10 @@ class Schedule:
|
|
332 |
ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
|
333 |
|
334 |
return (actual_time - ideal_time) / ideal_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
self.p2p_latency = p2p_latency
|
91 |
self.placement_strategy = placement_strategy
|
92 |
self.split_backward = split_backward
|
|
|
93 |
|
94 |
# Initialize default operation times
|
95 |
if self.split_backward:
|
|
|
151 |
def get_op_time(self, op_type: str, stage_id: int):
|
152 |
# For overlapped operations, extract the original operation types
|
153 |
if op_type.startswith("overlapped_"):
|
154 |
+
if op_type in self.op_times and self.op_times[op_type][stage_id]:
|
155 |
+
return self.op_times[op_type][stage_id]
|
156 |
+
else:
|
157 |
+
op_parts = op_type.split("_")[1:]
|
158 |
+
if len(op_parts) >= 2:
|
159 |
+
op_type1, op_type2 = op_parts[0], op_parts[1]
|
160 |
+
return self.get_op_time(op_type1, stage_id) + self.get_op_time(op_type2, stage_id)
|
|
|
161 |
|
162 |
if op_type not in self.op_times:
|
163 |
raise ValueError(f"Invalid operation type: {op_type}")
|
|
|
330 |
ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
|
331 |
|
332 |
return (actual_time - ideal_time) / ideal_time
|
333 |
+
|
334 |
+
def get_device_running_time(self):
|
335 |
+
device_time = [0] * self.config.num_devices
|
336 |
+
for dev_id in range(self.config.num_devices):
|
337 |
+
for op in self.device_queues[dev_id].ops:
|
338 |
+
device_time[dev_id] += op.end_time - op.start_time
|
339 |
+
return device_time
|