Victarry committed on
Commit
a9586a0
·
1 Parent(s): f4c58ee

Update overlapped time.

Browse files
Files changed (1) hide show
  1. src/execution_model.py +14 -9
src/execution_model.py CHANGED
@@ -90,7 +90,6 @@ class ScheduleConfig:
90
  self.p2p_latency = p2p_latency
91
  self.placement_strategy = placement_strategy
92
  self.split_backward = split_backward
93
- self.overlapped_op_times = {}
94
 
95
  # Initialize default operation times
96
  if self.split_backward:
@@ -152,14 +151,13 @@ class ScheduleConfig:
152
  def get_op_time(self, op_type: str, stage_id: int):
153
  # For overlapped operations, extract the original operation types
154
  if op_type.startswith("overlapped_"):
155
- op_parts = op_type.split("_")[1:]
156
- if len(op_parts) >= 2:
157
- op_type1, op_type2 = op_parts[0], op_parts[1]
158
- # Check if we have a specific time for this combination
159
- if (op_type1, op_type2) in self.overlapped_op_times:
160
- return self.overlapped_op_times[(op_type1, op_type2)]
161
- # Otherwise, use the max of individual times
162
- return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id))
163
 
164
  if op_type not in self.op_times:
165
  raise ValueError(f"Invalid operation type: {op_type}")
@@ -332,3 +330,10 @@ class Schedule:
332
  ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
333
 
334
  return (actual_time - ideal_time) / ideal_time
 
 
 
 
 
 
 
 
90
  self.p2p_latency = p2p_latency
91
  self.placement_strategy = placement_strategy
92
  self.split_backward = split_backward
 
93
 
94
  # Initialize default operation times
95
  if self.split_backward:
 
151
  def get_op_time(self, op_type: str, stage_id: int):
152
  # For overlapped operations, extract the original operation types
153
  if op_type.startswith("overlapped_"):
154
+ if op_type in self.op_times and self.op_times[op_type][stage_id]:
155
+ return self.op_times[op_type][stage_id]
156
+ else:
157
+ op_parts = op_type.split("_")[1:]
158
+ if len(op_parts) >= 2:
159
+ op_type1, op_type2 = op_parts[0], op_parts[1]
160
+ return self.get_op_time(op_type1, stage_id) + self.get_op_time(op_type2, stage_id)
 
161
 
162
  if op_type not in self.op_times:
163
  raise ValueError(f"Invalid operation type: {op_type}")
 
330
  ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
331
 
332
  return (actual_time - ideal_time) / ideal_time
333
+
334
+ def get_device_running_time(self):
335
+ device_time = [0] * self.config.num_devices
336
+ for dev_id in range(self.config.num_devices):
337
+ for op in self.device_queues[dev_id].ops:
338
+ device_time[dev_id] += op.end_time - op.start_time
339
+ return device_time