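# 1F1B pipeline-parallelism scheduler: builds a per-stage schedule, reports
# its bubble rate, and optionally renders it through the `visualizer` module.
# (The scraped page header "Spaces: Running" was not part of the program.)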
import matplotlib.pyplot as plt | |
import numpy as np | |
import argparse | |
import json | |
import yaml | |
import os | |
from matplotlib.patches import Rectangle | |
from typing import List, Tuple, Dict, Literal | |
# Import visualization function from the new module | |
from visualizer import visualize_pipeline_parallelism | |
def create_1f1b_schedule(
    num_stages: int,
    num_batches: int,
    forward_times: List[float],
    backward_times: List[float],
    p2p_time: float = 0.0,
) -> Dict[int, List[Dict]]:
    """
    Create a 1F1B (One-Forward-One-Backward) schedule for pipeline parallelism.

    The schedule is built in two steps:
    1. Determine the operation order of every stage (which micro-batch is
       processed when), without any timing.
    2. Assign start times from the dependencies between those operations.

    The 1F1B pattern has three phases on each stage:
    - Warmup: up to ``num_stages - stage`` forward passes
    - Steady state: alternating backward and forward passes
    - Cooldown: backward passes for the micro-batches still in flight

    Args:
        num_stages: Number of pipeline stages (devices).
        num_batches: Number of micro-batches.
        forward_times: Forward-pass duration, indexed by stage.
        backward_times: Backward-pass duration, indexed by stage.
        p2p_time: Communication time added between adjacent stages.

    Returns:
        A dictionary mapping stage IDs to lists of tasks.
        Each task is a dictionary with keys:
        - 'type': 'forward' or 'backward'
        - 'batch': micro-batch number (1-based)
        - 'start_time': start time of the task
        - 'duration': duration of the task
    """
    # Step 1: order of operations per stage, no timing information yet.
    operation_sequence = determine_1f1b_operation_sequence(num_stages, num_batches)

    # Step 2: attach start times and durations, honouring dependencies.
    return calculate_operation_timing(
        operation_sequence, num_stages, forward_times, backward_times, p2p_time
    )
def determine_1f1b_operation_sequence(
    num_stages: int, num_batches: int
) -> Dict[int, List[Dict]]:
    """
    Determine the sequence of operations (forward/backward) for each stage in 1F1B scheduling.

    Every stage runs a warmup of forward passes, then alternates
    backward/forward, then drains the remaining backward passes. Each stage
    ends up with exactly one forward and one backward per micro-batch, so
    all stages have the same number of operations.

    Args:
        num_stages: Number of pipeline stages
        num_batches: Number of micro-batches

    Returns:
        Dictionary mapping stage ID to a list of operations in sequence.
        Each operation is a dict with keys 'type' ('forward' or 'backward')
        and 'batch' (1-based micro-batch number).
    """
    operation_sequence = {i: [] for i in range(num_stages)}
    for current_stage in range(num_stages):
        ops = operation_sequence[current_stage]

        # A stage can never warm up with more micro-batches than exist.
        # Without this clamp, num_batches < num_stages - current_stage would
        # produce forwards for non-existent batches and backward ids <= 0.
        warmup_batches = max(0, min(num_stages - current_stage, num_batches))
        steady_batches = num_batches - warmup_batches

        # Warmup: forward passes only.
        for j in range(1, warmup_batches + 1):
            ops.append({"type": "forward", "batch": j})

        # Steady state: retire the oldest in-flight batch, then admit a new one.
        for j in range(warmup_batches + 1, num_batches + 1):
            ops.append({"type": "backward", "batch": j - warmup_batches})
            ops.append({"type": "forward", "batch": j})

        # Cooldown: backward passes for the batches still in flight.
        for j in range(steady_batches + 1, num_batches + 1):
            ops.append({"type": "backward", "batch": j})
    return operation_sequence
def calculate_operation_timing(
    operation_sequence: Dict[int, List[Dict]],
    num_stages: int,
    forward_times: List[float],
    backward_times: List[float],
    p2p_time: float = 0.0,
) -> Dict[int, List[Dict]]:
    """
    Attach a start time and a duration to every operation of a 1F1B sequence.

    Operations are resolved step by step across the stages. When an operation
    depends on one that has not been resolved yet, that dependency is
    resolved first through recursion and its result is cached.

    Args:
        operation_sequence: Operation sequence for each stage
        num_stages: Number of pipeline stages
        forward_times: Forward propagation time for each stage
        backward_times: Backward propagation time for each stage
        p2p_time: Point-to-point communication time between stages

    Returns:
        Schedule keyed by stage; each entry is a copy of the input operation
        extended with 'start_time' and 'duration'.
    """
    schedule = {stage: [] for stage in range(num_stages)}
    # (stage, batch, op_type) -> (start_time, end_time) of resolved operations
    resolved = {}
    # End time of the most recently resolved operation on each stage
    busy_until = [0.0] * num_stages

    def op_duration(stage, op_type):
        # Anything that is not a forward pass is timed as a backward pass.
        if op_type == "forward":
            return forward_times[stage]
        return backward_times[stage]

    def compute_op_time(stage, batch, op_type):
        key = (stage, batch, op_type)
        if key in resolved:
            return resolved[key]

        duration = op_duration(stage, op_type)
        # The stage is serial: start no earlier than its previous operation
        # ended. This is read before any dependency is resolved.
        earliest = busy_until[stage]

        if op_type == "backward":
            # A backward pass needs this stage's own forward pass ...
            _, own_forward_end = compute_op_time(stage, batch, "forward")
            earliest = max(earliest, own_forward_end)
            # ... and the gradient from the next stage, unless this is the last one.
            if stage < num_stages - 1:
                _, downstream_end = compute_op_time(stage + 1, batch, "backward")
                earliest = max(earliest, downstream_end + p2p_time)
        elif op_type == "forward" and stage > 0:
            # A forward pass needs the activation from the previous stage.
            _, upstream_end = compute_op_time(stage - 1, batch, "forward")
            earliest = max(earliest, upstream_end + p2p_time)

        window = (earliest, earliest + duration)
        resolved[key] = window
        busy_until[stage] = window[1]
        return window

    # Stage 0 defines the number of steps; every stage is indexed with it.
    for step in range(len(operation_sequence[0])):
        for stage in range(num_stages):
            op = operation_sequence[stage][step]
            batch = op["batch"]
            op_type = op["type"]
            begin, _ = compute_op_time(stage, batch, op_type)

            timed_op = op.copy()
            timed_op["start_time"] = begin
            timed_op["duration"] = op_duration(stage, op_type)
            schedule[stage].append(timed_op)
    return schedule
def get_bubble_rate(schedule: Dict[int, List[Dict]]):
    """
    Compute the bubble rate of a schedule: idle time relative to compute time.

    The available time is the latest task end time multiplied by the number
    of devices; the compute time is the sum of all task durations. The result
    is ``(available - compute) / compute``.

    Args:
        schedule: Mapping of device ID to its tasks. Each task must provide
            'start_time' and 'duration'.

    Returns:
        The bubble rate, or 0.0 when the schedule contains no compute time
        (empty schedule or only zero-duration tasks), where the ratio is
        undefined.
    """
    end_times = [
        task["start_time"] + task["duration"]
        for tasks in schedule.values()
        for task in tasks
    ]
    # The leading 0 keeps the makespan non-negative and covers the empty case.
    max_time = max([0, *end_times])
    total_execution_time = max_time * len(schedule)

    total_computation_time = sum(
        sum(task["duration"] for task in tasks) for tasks in schedule.values()
    )
    if total_computation_time == 0:
        # Avoid ZeroDivisionError when nothing was computed.
        return 0.0

    return (total_execution_time - total_computation_time) / total_computation_time
def read_config_file(config_path):
    """
    Read configuration from a JSON or YAML file.

    Args:
        config_path: Path to the config file (.json, .yaml or .yml)

    Returns:
        Dictionary containing configuration parameters. An empty document
        (for example an empty YAML file) yields an empty dictionary.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the extension is unsupported, the file cannot be read
            or parsed, or its top-level content is not a mapping.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    file_ext = os.path.splitext(config_path)[1].lower()
    # Reject unknown formats before opening the file, so this error is not
    # re-wrapped by the generic read/parse handler below.
    if file_ext not in (".json", ".yaml", ".yml"):
        raise ValueError(
            f"Unsupported config file format: {file_ext}. Use .json, .yaml, or .yml"
        )

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            if file_ext == ".json":
                config = json.load(f)
            else:
                config = yaml.safe_load(f)
    except Exception as e:
        # Keep the original exception as the cause for easier debugging.
        raise ValueError(f"Error reading config file: {str(e)}") from e

    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file must contain a mapping at the top level, got {type(config).__name__}"
        )
    return config
def parse_args():
    """
    Parse command-line arguments for the pipeline parallelism tool.

    Every value option defaults to ``None`` so that the caller can tell
    "not given on the command line" apart from an explicit value. Concrete
    argparse defaults here would be indistinguishable from user input and
    would always override values loaded from a config file.

    Returns:
        Parsed arguments namespace with attributes ``config``,
        ``num_stages``, ``num_batches``, ``forward_times``,
        ``backward_times``, ``output``, ``no_visualization`` and
        ``p2p_time``. Options not given are ``None``
        (``no_visualization`` is ``False``).
    """
    parser = argparse.ArgumentParser(
        description="Pipeline Parallelism Scheduler and Visualizer",
    )
    # Config file option
    parser.add_argument(
        "--config", "-c", type=str, help="Path to config file (JSON or YAML)"
    )
    # Main parameters
    parser.add_argument(
        "--num-stages",
        "-s",
        type=int,
        default=None,
        help="Number of pipeline stages (devices); "
        "if omitted, the config file value or 4 is used",
    )
    parser.add_argument(
        "--num-batches",
        "-b",
        type=int,
        default=None,
        help="Number of micro-batches; "
        "if omitted, the config file value or 10 is used",
    )
    # Forward and backward times
    parser.add_argument(
        "--forward-times",
        "-f",
        type=float,
        nargs="+",
        help="Time for forward pass at each stage (space-separated list)",
    )
    parser.add_argument(
        "--backward-times",
        "-bw",
        type=float,
        nargs="+",
        help="Time for backward pass at each stage (space-separated list)",
    )
    # Output options
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        help="Output file path for visualization; "
        "if omitted, the config file value or pipeline_1f1b.png is used",
    )
    parser.add_argument(
        "--no-visualization", action="store_true", help="Skip visualization generation"
    )
    parser.add_argument(
        "--p2p-time",
        type=float,
        default=None,
        help="Time for point-to-point communication between stages; "
        "if omitted, the config file value or 0.0 is used",
    )
    return parser.parse_args()
def example_usage():
    """Build, render and analyse a fixed demo 1F1B schedule (4 stages, 10 micro-batches)."""
    stage_count = 4  # pipeline stages (devices)
    microbatch_count = 10

    # Uniform per-stage costs: a backward pass takes twice as long as a forward pass.
    per_stage_forward = [1.0] * 4
    per_stage_backward = [2.0] * 4

    demo_schedule = create_1f1b_schedule(
        num_stages=stage_count,
        num_batches=microbatch_count,
        forward_times=per_stage_forward,
        backward_times=per_stage_backward,
    )

    # Render the schedule to an image file.
    visualize_pipeline_parallelism(
        schedule=demo_schedule,
        schedule_type="1f1b",
        output_file="pipeline_1f1b.png",
    )

    # Report how much of the schedule is idle relative to compute.
    rate = get_bubble_rate(demo_schedule)
    print(f"Bubble rate: {rate:.4f}")
def _fit_times_to_stages(times, num_stages, label, default):
    """Return a list of exactly ``num_stages`` per-stage times.

    A missing or empty ``times`` yields ``[default] * num_stages``. A list of
    the wrong length is padded with its last value or truncated, and the
    adjustment is reported on stdout using ``label`` as the parameter name.
    """
    if not times:
        # None, or an empty list (which has no last value to pad with).
        return [default] * num_stages
    if len(times) == num_stages:
        return times

    print(
        f"Warning: {label} length ({len(times)}) doesn't match num_stages ({num_stages})"
    )
    times = list(times)
    if len(times) < num_stages:
        # Extend with repeats of the last value
        times = times + [times[-1]] * (num_stages - len(times))
    else:
        # Truncate
        times = times[:num_stages]
    print(f"Adjusted {label}: {times}")
    return times


def main():
    """
    Main function that parses arguments and runs the pipeline parallelism analysis.

    Parameters are resolved in three layers: built-in defaults, then the
    config file (if ``--config`` is given and readable), then command-line
    values. The resulting 1F1B schedule is rendered unless
    ``--no-visualization`` is set, and its bubble rate is printed.

    Returns:
        Dictionary with keys 'schedule', 'bubble_rate', 'num_stages' and
        'num_batches'.

    Raises:
        ValueError: If the resolved num_stages or num_batches is not a
            positive integer.
    """
    args = parse_args()

    # Initialize with default values
    num_stages = 4
    num_batches = 10
    forward_times = None
    backward_times = None
    output_file = "pipeline_1f1b.png"
    p2p_time = 0.0

    # Read from config file if provided
    if args.config:
        try:
            print(f"Reading configuration from {args.config}")
            config = read_config_file(args.config)
            # Update parameters from config
            num_stages = config.get("num_stages", num_stages)
            num_batches = config.get("num_batches", num_batches)
            forward_times = config.get("forward_times")
            backward_times = config.get("backward_times")
            output_file = config.get("output_file", output_file)
            p2p_time = config.get("p2p_time", p2p_time)
        except Exception as e:
            # Best effort: a broken config must not prevent a run with
            # command-line values or defaults.
            print(f"Error reading config file: {str(e)}")
            print("Falling back to command line arguments or defaults")

    # Command line arguments override config file.
    # NOTE: these are truthiness checks, so a falsy command-line value
    # (for example --p2p-time 0) never replaces a config-file value.
    if args.num_stages:
        num_stages = args.num_stages
    if args.num_batches:
        num_batches = args.num_batches
    if args.forward_times:
        forward_times = args.forward_times
    if args.backward_times:
        backward_times = args.backward_times
    if args.output:
        output_file = args.output
    if args.p2p_time:
        p2p_time = args.p2p_time

    # Validate inputs: fail early with a clear message instead of an obscure
    # error further down in the scheduling code.
    if not isinstance(num_stages, int) or num_stages < 1:
        raise ValueError(f"num_stages must be a positive integer, got {num_stages!r}")
    if not isinstance(num_batches, int) or num_batches < 1:
        raise ValueError(f"num_batches must be a positive integer, got {num_batches!r}")

    forward_times = _fit_times_to_stages(
        forward_times, num_stages, "forward_times", 1.0
    )
    backward_times = _fit_times_to_stages(
        backward_times, num_stages, "backward_times", 2.0
    )

    print("Running with parameters:")
    print(f" num_stages: {num_stages}")
    print(f" num_batches: {num_batches}")
    print(f" forward_times: {forward_times}")
    print(f" backward_times: {backward_times}")
    print(f" p2p_time: {p2p_time}")
    print(f" output_file: {output_file}")

    # Create 1F1B schedule
    schedule = create_1f1b_schedule(
        num_stages=num_stages,
        num_batches=num_batches,
        forward_times=forward_times,
        backward_times=backward_times,
        p2p_time=p2p_time,
    )

    # Create visualization unless --no-visualization is specified
    if not args.no_visualization:
        visualize_pipeline_parallelism(
            schedule=schedule, schedule_type="1f1b", output_file=output_file
        )

    # Analyze the schedule
    bubble_rate = get_bubble_rate(schedule)
    print(f"Bubble rate: {bubble_rate:.4f}")

    return {
        "schedule": schedule,
        "bubble_rate": bubble_rate,
        "num_stages": num_stages,
        "num_batches": num_batches,
    }
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()