# Spaces: Running
from collections import defaultdict
from typing import Dict, List, Optional, Union
class Operation:
    """Operation is a single operation in the pipeline.

    An operation is identified by the batch it processes, the pipeline stage
    it belongs to, and its type (an arbitrary string such as ``"forward"``).
    Device placement and timing are not known at construction time and start
    out as ``None``.

    Args:
        batch_id: Index of the batch this operation processes.
        stage_id: Index of the pipeline stage this operation belongs to.
        op_type: Name of the operation type.
    """

    def __init__(self, batch_id: int, stage_id: int, op_type: str):
        self.batch_id = batch_id
        self.stage_id = stage_id
        self.op_type = op_type
        # Placement and timing are unset until something assigns them.
        self.device_id: Optional[int] = None
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None

    def __repr__(self) -> str:
        """Return a readable description for debugging and assertion messages."""
        return (
            f"{type(self).__name__}("
            f"batch_id={self.batch_id!r}, "
            f"stage_id={self.stage_id!r}, "
            f"op_type={self.op_type!r}, "
            f"device_id={self.device_id!r}, "
            f"start_time={self.start_time!r}, "
            f"end_time={self.end_time!r})"
        )
class DeviceQueue:
    """Ordered list of operations assigned to a single device.

    Args:
        stages: Stage ids that this device is allowed to run.
        device_id: Identifier written onto every operation added to the queue.
    """

    def __init__(self, stages: List[int], device_id: int):
        self.stages = stages
        self.device_id = device_id
        self.ops = []  # List of operations, in insertion order

    def add_operation(self, op: "Operation"):
        """Append ``op`` to this queue and mark it as placed on this device.

        Both preconditions are checked *before* anything is mutated, so a
        rejected operation leaves the queue and the operation untouched.

        Raises:
            AssertionError: If ``op.stage_id`` is not one of this device's
                stages, or if ``op`` already has a ``device_id``.
        """
        # Raised explicitly (rather than via ``assert``) so the checks are not
        # stripped when Python runs with -O; the exception type is unchanged.
        if op.stage_id not in self.stages:
            raise AssertionError(
                f"stage {op.stage_id} is not assigned to device {self.device_id}"
            )
        if op.device_id is not None:
            raise AssertionError(
                f"operation is already assigned to device {op.device_id}"
            )
        op.device_id = self.device_id
        self.ops.append(op)
class ScheduleConfig:
    """Static description of a pipeline schedule problem.

    Holds the problem size, the stage-to-device placement and the duration of
    every operation type.

    Args:
        num_devices: Number of devices; must be positive and divide
            ``num_stages``.
        num_stages: Number of pipeline stages.
        num_batches: Number of batches.
        p2p_latency: Latency value stored on the config.
        placement_strategy: ``"standard"`` puts consecutive stages on the same
            device; ``"interleave"`` assigns stages round-robin.
        split_backward: If true, the default operation types are ``forward``,
            ``backward_D`` and ``backward_W`` (1.0 each) instead of
            ``forward`` (1.0) and ``backward`` (2.0).
        op_times: Optional overrides. Each value is either a single number
            (used for every stage) or a dict mapping stage id to a number.

    Raises:
        ValueError: If ``num_devices`` is not positive or the placement
            strategy is unknown.
        AssertionError: If ``num_stages`` is not divisible by ``num_devices``.
    """

    def __init__(
        self,
        num_devices: int,
        num_stages: int,
        num_batches: int,
        p2p_latency: float = 0.0,
        placement_strategy: str = "standard",
        split_backward: bool = False,
        op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
    ):
        # A non-positive device count would otherwise surface as an accidental
        # ZeroDivisionError or as a nonsensical stage placement.
        if num_devices <= 0:
            raise ValueError("num_devices must be a positive integer")

        self.num_devices = num_devices
        self.num_stages = num_stages
        self.num_batches = num_batches
        self.p2p_latency = p2p_latency
        self.placement_strategy = placement_strategy
        self.split_backward = split_backward

        # Initialize default operation times
        if self.split_backward:
            self.op_times = {
                "forward": 1.0,
                "backward_D": 1.0,
                "backward_W": 1.0,
            }
        else:
            self.op_times = {
                "forward": 1.0,
                "backward": 2.0,
            }

        # Update with user-provided operation times
        if op_times:
            self._apply_op_time_overrides(op_times)

        # Raised explicitly so the check survives ``python -O``; the exception
        # type and message match the original assertion.
        if num_stages % num_devices != 0:
            raise AssertionError("num_stages must be divisible by num_devices")
        self.num_stages_per_device = num_stages // num_devices

        self.init_device_to_stages()
        # Internal sanity check: every stage was placed exactly once.
        assert (
            sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
        )

    def _apply_op_time_overrides(self, op_times):
        """Merge user-supplied operation times into ``self.op_times``."""
        for op_type, times in op_times.items():
            if not isinstance(times, dict):
                # A single value applies to all stages.
                self.op_times[op_type] = times
                continue

            # A dict maps stage_id -> time. Start from whatever is already
            # known for this operation type, expanding a scalar to one entry
            # per stage so unspecified stages keep the previous value.
            if op_type not in self.op_times:
                per_stage = {}
            elif isinstance(self.op_times[op_type], dict):
                per_stage = self.op_times[op_type]
            else:
                scalar = self.op_times[op_type]
                per_stage = {i: scalar for i in range(self.num_stages)}
            per_stage.update(times)
            self.op_times[op_type] = per_stage

    def init_device_to_stages(self):
        """Build ``self.device_to_stages`` according to the placement strategy.

        Raises:
            ValueError: If ``placement_strategy`` is not ``"standard"`` or
                ``"interleave"``.
        """
        if self.placement_strategy == "standard":
            # Evenly distributed: consecutive stages share a device.
            stages_per_device = self.num_stages // self.num_devices
            self.device_to_stages = defaultdict(list)
            for stage_id in range(self.num_stages):
                self.device_to_stages[stage_id // stages_per_device].append(stage_id)
        elif self.placement_strategy == "interleave":
            # Round-robin: stage i goes to device i mod num_devices.
            self.device_to_stages = defaultdict(list)
            for stage_id in range(self.num_stages):
                self.device_to_stages[stage_id % self.num_devices].append(stage_id)
        else:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")

    def get_op_time(self, op_type: str, stage_id: int):
        """Return the duration of ``op_type`` at ``stage_id``.

        Raises:
            ValueError: If the operation type is unknown, or if it has
                stage-specific times but none for ``stage_id``.
        """
        if op_type not in self.op_times:
            raise ValueError(f"Invalid operation type: {op_type}")
        times = self.op_times[op_type]
        if not isinstance(times, dict):
            # A single value is shared by all stages.
            return times
        if stage_id not in times:
            raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
        return times[stage_id]
class Schedule:
    """Per-device operation queues plus a simple timing simulation.

    On construction one operation is created for every
    (batch, stage, operation type) combination and one empty device queue is
    created per device. Operations must be added to the device queues before
    calling :meth:`execute`.

    Args:
        config: Problem description (sizes, placement, operation times).
    """

    def __init__(self, config: "ScheduleConfig"):
        self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
        self.dev_queues: List["DeviceQueue"] = []
        for dev_id in range(config.num_devices):
            self.dev_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
        self.config = config
        self.init_operations()

    def _op_types(self):
        """Return the operation types used by the current configuration."""
        if self.config.split_backward:
            return ["forward", "backward_D", "backward_W"]
        return ["forward", "backward"]

    def init_operations(self):
        """Create one operation per (batch, stage, operation type)."""
        op_types = self._op_types()
        for batch_id in range(self.config.num_batches):
            for stage_id in range(self.config.num_stages):
                for op_type in op_types:
                    self.ops[(batch_id, stage_id, op_type)] = Operation(
                        batch_id, stage_id, op_type
                    )

    def get_op(self, batch_id: int, stage_id: int, op_type: str):
        """Return the operation registered under the given key."""
        return self.ops[(batch_id, stage_id, op_type)]

    def get_dependencies(self, op: "Operation"):
        """Return the ``(operation, gap)`` pairs that must finish before ``op``.

        ``gap`` is the delay added after the dependency's end time. The
        dependencies are:

        * forward: the forward of the previous stage for the same batch;
        * backward / backward_D: the same operation type of the next stage
          for the same batch;
        * backward_W: the backward_D of the same stage and batch;
        * any operation: the operation immediately before it in its device
          queue (gap 0.0).

        ``op`` must already be in the queue of the device named by
        ``op.device_id``.
        """
        deps = []
        last_stage = self.config.num_stages - 1
        latency = self.config.p2p_latency

        if op.op_type == "forward":
            if op.stage_id > 0:
                deps.append(
                    (self.get_op(op.batch_id, op.stage_id - 1, "forward"), latency)
                )

        if self.config.split_backward:
            if op.op_type == "backward_D":
                if op.stage_id < last_stage:
                    deps.append(
                        (self.get_op(op.batch_id, op.stage_id + 1, "backward_D"), latency)
                    )
            elif op.op_type == "backward_W":
                # The same-stage backward_D always exists, including on the
                # last stage, so this dependency is unconditional.
                # NOTE(review): the gap is kept at p2p_latency to match the
                # existing behaviour of the other stages, although both
                # operations belong to the same stage — confirm whether 0.0
                # is intended here.
                deps.append(
                    (self.get_op(op.batch_id, op.stage_id, "backward_D"), latency)
                )
        else:
            if op.op_type == "backward":
                if op.stage_id < last_stage:
                    deps.append(
                        (self.get_op(op.batch_id, op.stage_id + 1, "backward"), latency)
                    )

        # Operations on one device run in queue order.
        queue_ops = self.dev_queues[op.device_id].ops
        device_index = queue_ops.index(op)
        if device_index > 0:
            deps.append((queue_ops[device_index - 1], 0.0))
        return deps

    def show(self):
        """Display detailed information about the schedule for debugging purposes."""
        print("\n=== SCHEDULE DETAILS ===")
        print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
        print(f"Placement Strategy: {self.config.placement_strategy}")
        print("\n=== DEVICE QUEUES ===")
        for dev_id in range(self.config.num_devices):
            print(f"\nDEVICE {dev_id} (Stages: {self.dev_queues[dev_id].stages}):")
            print("-" * 80)
            print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
            print("-" * 80)
            for op in self.dev_queues[dev_id].ops:
                op_type = op.op_type
                start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
                end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"
                duration = "N/A"
                if op.start_time is not None and op.end_time is not None:
                    duration = f"{op.end_time - op.start_time:.2f}"
                print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")
        # Find the total execution time (if timing info is available).
        # The emptiness guard avoids calling max() on an empty sequence.
        if self.ops and all(op.end_time is not None for op in self.ops.values()):
            total_time = max(op.end_time for op in self.ops.values())
            print(f"\nTotal execution time: {total_time:.2f}")

    def execute(self):
        """Compute ``start_time`` and ``end_time`` for every queued operation.

        An operation starts at 0.0 if it has no dependencies, otherwise at the
        latest ``end_time + gap`` over its dependencies; unresolved
        dependencies are computed first, recursively.

        Raises:
            AssertionError: If some operation in ``self.ops`` ends up without
                a start or end time (e.g. it was never queued).
        """

        def execute_op(op: "Operation"):
            deps = self.get_dependencies(op)
            if not deps:
                op.start_time = 0.0
            else:
                for dep, _gap in deps:
                    if dep.end_time is None or dep.start_time is None:
                        execute_op(dep)
                op.start_time = max(dep.end_time + gap for dep, gap in deps)
            op.end_time = op.start_time + self.config.get_op_time(
                op.op_type, op.stage_id
            )

        # Walk the queues position by position. The bound is the longest
        # queue (not just device 0's) so queues of different lengths are
        # neither truncated nor indexed out of range.
        longest = max((len(queue.ops) for queue in self.dev_queues), default=0)
        for i in range(longest):
            for queue in self.dev_queues:
                if i < len(queue.ops):
                    execute_op(queue.ops[i])

        for op in self.ops.values():
            assert (
                op.start_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
            assert (
                op.end_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"

    def get_total_execution_time(self):
        """Return the latest end time over all operations.

        Requires that end times have been assigned, e.g. by :meth:`execute`.
        """
        return max(op.end_time for op in self.ops.values())

    def get_bubble_rate(self):
        """Return ``(actual - ideal) / ideal`` execution time.

        The ideal time is the total duration of all operations of all batches
        divided evenly across the devices. The operation types follow the
        configuration, so this also works when the backward pass is split.
        """
        actual_time = self.get_total_execution_time()
        op_types = self._op_types()
        ideal_time = 0
        for stage_id in range(self.config.num_stages):
            for op_type in op_types:
                ideal_time += self.config.get_op_time(op_type, stage_id)
        ideal_time = ideal_time * self.config.num_batches / self.config.num_devices
        return (actual_time - ideal_time) / ideal_time