import matplotlib.pyplot as plt
import numpy as np
import argparse
import json
import yaml
import os
from matplotlib.patches import Rectangle
from typing import List, Tuple, Dict, Literal

# Import visualization function from the new module
from visualizer import visualize_pipeline_parallelism

# Effective defaults, applied in main() only when neither a config file nor a
# command-line flag supplies a value (see parse_args: CLI defaults are None so
# that "flag not given" is distinguishable from "flag given with default value").
DEFAULT_NUM_STAGES = 4
DEFAULT_NUM_BATCHES = 10
DEFAULT_OUTPUT_FILE = "pipeline_1f1b.png"
DEFAULT_P2P_TIME = 0.0


def create_1f1b_schedule(
    num_stages: int,
    num_batches: int,
    forward_times: List[float],
    backward_times: List[float],
    p2p_time: float = 0.0,
) -> Dict[int, List[Dict]]:
    """
    Create a 1F1B (One-Forward-One-Backward) schedule for pipeline parallelism.

    This implementation takes a data-centric approach:
    1. First determine the operation sequence for each pipeline stage
       (which microbatch to process when)
    2. Then calculate timing based on dependencies between operations

    The 1F1B pattern has three phases:
    - Warmup: Forward passes for first num_stages microbatches
    - Steady state: Alternating between forward and backward passes
    - Cooldown: Backward passes for remaining microbatches

    Args:
        num_stages: Number of pipeline stages (devices)
        num_batches: Number of micro-batches
        forward_times: Forward pass duration per stage (length num_stages)
        backward_times: Backward pass duration per stage (length num_stages)
        p2p_time: Point-to-point communication latency between adjacent stages

    Returns:
        A dictionary mapping device IDs to lists of tasks.
        Each task is a dictionary with keys:
        - 'type': 'forward' or 'backward'
        - 'batch': batch number (1-based)
        - 'start_time': start time of the task
        - 'duration': duration of the task
    """
    # Step 1: Determine the ordered sequence of operations
    # (forward/backward on which microbatch) per stage, without timing.
    operation_sequence = determine_1f1b_operation_sequence(num_stages, num_batches)

    # Step 2: Convert the operation sequence into a schedule with concrete
    # start times, honoring inter-stage and intra-stage dependencies.
    schedule = calculate_operation_timing(
        operation_sequence, num_stages, forward_times, backward_times, p2p_time
    )
    return schedule


def determine_1f1b_operation_sequence(
    num_stages: int, num_batches: int
) -> Dict[int, List[Dict]]:
    """
    Determine the sequence of operations (forward/backward) for each stage
    in 1F1B scheduling.

    NOTE(review): assumes num_batches >= num_stages; with fewer microbatches
    than stages the steady-state count goes negative — confirm callers respect
    this precondition.

    Args:
        num_stages: Number of pipeline stages
        num_batches: Number of micro-batches

    Returns:
        Dictionary mapping stage ID to a list of operations in sequence.
        Each operation is a dict with keys 'type' ('forward' or 'backward')
        and 'batch' (1-based microbatch index).
    """
    operation_sequence = {i: [] for i in range(num_stages)}

    for current_stage in range(num_stages):
        ops = operation_sequence[current_stage]

        # Warmup phase: earlier stages run more forwards before their first
        # backward (stage 0 runs num_stages forwards, the last stage just 1).
        warmup_batches = num_stages - current_stage
        for j in range(1, warmup_batches + 1):
            ops.append({"type": "forward", "batch": j})

        # Steady state: strictly alternate one backward then one forward.
        steady_batches = num_batches - warmup_batches
        for j in range(warmup_batches + 1, warmup_batches + steady_batches + 1):
            ops.append({"type": "backward", "batch": j - warmup_batches})
            ops.append({"type": "forward", "batch": j})

        # Cooldown: drain the remaining backwards for the warmup microbatches.
        for j in range(warmup_batches):
            ops.append({"type": "backward", "batch": j + steady_batches + 1})

    return operation_sequence


def calculate_operation_timing(
    operation_sequence: Dict[int, List[Dict]],
    num_stages: int,
    forward_times: List[float],
    backward_times: List[float],
    p2p_time: float = 0.0,
) -> Dict[int, List[Dict]]:
    """
    Recursively calculate the specific timing of each operation in a 1F1B
    schedule.

    When encountering an operation that depends on a previous operation that
    hasn't been calculated yet, it will recursively calculate the timing of
    those operations.

    Args:
        operation_sequence: Operation sequence for each stage
        num_stages: Number of pipeline stages
        forward_times: Forward propagation time for each stage
        backward_times: Backward propagation time for each stage
        p2p_time: Point-to-point communication time between stages

    Returns:
        Complete schedule with timing information; each operation includes
        'start_time' and 'duration' in addition to 'type' and 'batch'.
    """
    schedule = {i: [] for i in range(num_stages)}

    # Memo of already-computed operations.
    # Format: {(stage, batch, op_type): (start_time, end_time)}
    computed_ops = {}

    # End time of the most recently computed operation on each stage —
    # enforces that a stage executes its operations serially.
    stage_last_end_time = [0.0] * num_stages

    def compute_op_time(stage, batch, op_type):
        """Return (start, end) for one operation, resolving deps recursively."""
        key = (stage, batch, op_type)
        if key in computed_ops:
            return computed_ops[key]

        duration = (
            forward_times[stage] if op_type == "forward" else backward_times[stage]
        )

        # 1. Sequential dependency: wait for the stage's previous operation.
        start_time = stage_last_end_time[stage]

        if op_type == "forward" and stage > 0:
            # 2. Forward depends on the previous stage's forward of the same
            # batch, plus the communication hop.
            prev_stage_key = (stage - 1, batch, "forward")
            if prev_stage_key not in computed_ops:
                prev_start, prev_end = compute_op_time(stage - 1, batch, "forward")
            else:
                _, prev_end = computed_ops[prev_stage_key]
            start_time = max(start_time, prev_end + p2p_time)
        elif op_type == "backward":
            # 3a. Backward needs this stage's forward of the same batch.
            same_stage_forward_key = (stage, batch, "forward")
            if same_stage_forward_key not in computed_ops:
                _, forward_end = compute_op_time(stage, batch, "forward")
            else:
                _, forward_end = computed_ops[same_stage_forward_key]
            start_time = max(start_time, forward_end)

            # 3b. Backward needs the next stage's backward (gradient arrives
            # from downstream), unless this is the last stage.
            if stage < num_stages - 1:
                next_stage_backward_key = (stage + 1, batch, "backward")
                if next_stage_backward_key not in computed_ops:
                    _, next_backward_end = compute_op_time(
                        stage + 1, batch, "backward"
                    )
                else:
                    _, next_backward_end = computed_ops[next_stage_backward_key]
                start_time = max(start_time, next_backward_end + p2p_time)

        end_time = start_time + duration
        computed_ops[key] = (start_time, end_time)
        stage_last_end_time[stage] = end_time
        return start_time, end_time

    # Walk the sequences in lockstep (all stages have equal-length sequences:
    # 2 * num_batches operations each) and materialize timing for every op.
    for i in range(len(operation_sequence[0])):
        for stage in range(num_stages):
            batch = operation_sequence[stage][i]["batch"]
            op_type = operation_sequence[stage][i]["type"]

            start_time, _ = compute_op_time(stage, batch, op_type)

            op_with_timing = operation_sequence[stage][i].copy()
            op_with_timing["start_time"] = start_time
            op_with_timing["duration"] = (
                forward_times[stage] if op_type == "forward" else backward_times[stage]
            )
            schedule[stage].append(op_with_timing)

    return schedule


def get_bubble_rate(schedule: Dict[int, List[Dict]]):
    """
    Compute the pipeline bubble rate of a schedule.

    Bubble rate = (total device-time − total compute-time) / total compute-time,
    where total device-time is the makespan multiplied by the number of stages.

    Args:
        schedule: Mapping of device ID to its timed task list (as returned by
            create_1f1b_schedule).

    Returns:
        The bubble rate as a float (0.0 means a perfectly packed pipeline).
    """
    num_stages = len(schedule)

    # Makespan: latest end time across all tasks on all devices.
    max_time = max(
        (task["start_time"] + task["duration"] for tasks in schedule.values() for task in tasks),
        default=0,
    )

    total_execution_time = max_time * num_stages
    total_computation_time = sum(
        task["duration"] for tasks in schedule.values() for task in tasks
    )

    bubble_rate = (
        total_execution_time - total_computation_time
    ) / total_computation_time
    return bubble_rate


def read_config_file(config_path):
    """
    Read configuration from a JSON or YAML file.

    Args:
        config_path: Path to the config file (JSON or YAML)

    Returns:
        Dictionary containing configuration parameters

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the extension is unsupported or parsing fails.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    file_ext = os.path.splitext(config_path)[1].lower()
    try:
        with open(config_path, "r") as f:
            if file_ext == ".json":
                config = json.load(f)
            elif file_ext in (".yaml", ".yml"):
                # safe_load: never construct arbitrary Python objects from YAML.
                config = yaml.safe_load(f)
            else:
                raise ValueError(
                    f"Unsupported config file format: {file_ext}. Use .json, .yaml, or .yml"
                )
        return config
    except Exception as e:
        raise ValueError(f"Error reading config file: {str(e)}")


def parse_args():
    """
    Parse command-line arguments for the pipeline parallelism tool.

    All value options default to None so that main() can tell "not given on
    the command line" apart from "given with the default value" — previously
    the truthy argparse defaults always overrode config-file settings.

    Returns:
        Parsed arguments namespace
    """
    parser = argparse.ArgumentParser(
        description="Pipeline Parallelism Scheduler and Visualizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Config file option
    parser.add_argument(
        "--config", "-c", type=str, help="Path to config file (JSON or YAML)"
    )

    # Main parameters
    parser.add_argument(
        "--num-stages",
        "-s",
        type=int,
        default=None,
        help="Number of pipeline stages (devices)",
    )
    parser.add_argument(
        "--num-batches", "-b", type=int, default=None, help="Number of micro-batches"
    )

    # Forward and backward times
    parser.add_argument(
        "--forward-times",
        "-f",
        type=float,
        nargs="+",
        help="Time for forward pass at each stage (space-separated list)",
    )
    parser.add_argument(
        "--backward-times",
        "-bw",
        type=float,
        nargs="+",
        help="Time for backward pass at each stage (space-separated list)",
    )

    # Output options
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        help="Output file path for visualization",
    )
    parser.add_argument(
        "--no-visualization", action="store_true", help="Skip visualization generation"
    )
    parser.add_argument(
        "--p2p-time",
        type=float,
        default=None,
        help="Time for point-to-point communication between stages",
    )

    return parser.parse_args()


def example_usage():
    """Example usage of the visualization function and testing the scheduling algorithms."""
    # Example parameters
    num_stages = 4  # Number of pipeline stages (devices)
    num_batches = 10  # Number of micro-batches

    # Example times for forward and backward passes for each stage
    forward_times = [1.0, 1.0, 1.0, 1.0]  # Time for forward pass at each stage
    backward_times = [2.0, 2.0, 2.0, 2.0]  # Time for backward pass at each stage

    # Create 1F1B schedule
    schedule = create_1f1b_schedule(
        num_stages=num_stages,
        num_batches=num_batches,
        forward_times=forward_times,
        backward_times=backward_times,
    )

    # Create visualization with the schedule
    visualize_pipeline_parallelism(
        schedule=schedule, schedule_type="1f1b", output_file="pipeline_1f1b.png"
    )

    # Analyze the schedule
    bubble_rate = get_bubble_rate(schedule)
    print(f"Bubble rate: {bubble_rate:.4f}")


def main():
    """
    Main function that parses arguments and runs the pipeline parallelism analysis.

    Precedence for every parameter: command line > config file > built-in default.

    Returns:
        Dict with the computed schedule, bubble rate, and effective parameters.
    """
    args = parse_args()

    # Start unset; resolve config, then CLI, then built-in defaults below.
    num_stages = None
    num_batches = None
    forward_times = None
    backward_times = None
    output_file = None
    p2p_time = None

    # Read from config file if provided
    if args.config:
        try:
            print(f"Reading configuration from {args.config}")
            config = read_config_file(args.config)

            num_stages = config.get("num_stages", num_stages)
            num_batches = config.get("num_batches", num_batches)
            forward_times = config.get("forward_times")
            backward_times = config.get("backward_times")
            output_file = config.get("output_file", output_file)
            p2p_time = config.get("p2p_time", p2p_time)
        except Exception as e:
            print(f"Error reading config file: {str(e)}")
            print("Falling back to command line arguments or defaults")

    # Command line arguments override config file (explicit None checks so a
    # config value survives when the flag was not given).
    if args.num_stages is not None:
        num_stages = args.num_stages
    if args.num_batches is not None:
        num_batches = args.num_batches
    if args.forward_times:
        forward_times = args.forward_times
    if args.backward_times:
        backward_times = args.backward_times
    if args.output is not None:
        output_file = args.output
    if args.p2p_time is not None:
        p2p_time = args.p2p_time

    # Built-in defaults for anything still unset.
    if num_stages is None:
        num_stages = DEFAULT_NUM_STAGES
    if num_batches is None:
        num_batches = DEFAULT_NUM_BATCHES
    if output_file is None:
        output_file = DEFAULT_OUTPUT_FILE
    if p2p_time is None:
        p2p_time = DEFAULT_P2P_TIME

    # Validate inputs: pad (repeating the last value) or truncate the per-stage
    # time lists so their length matches num_stages.
    if forward_times is None:
        forward_times = [1.0] * num_stages
    elif len(forward_times) != num_stages:
        print(
            f"Warning: forward_times length ({len(forward_times)}) doesn't match num_stages ({num_stages})"
        )
        if len(forward_times) < num_stages:
            # Extend with repeats of the last value
            forward_times = list(forward_times) + [forward_times[-1]] * (
                num_stages - len(forward_times)
            )
        else:
            # Truncate
            forward_times = forward_times[:num_stages]
        print(f"Adjusted forward_times: {forward_times}")

    if backward_times is None:
        backward_times = [2.0] * num_stages
    elif len(backward_times) != num_stages:
        print(
            f"Warning: backward_times length ({len(backward_times)}) doesn't match num_stages ({num_stages})"
        )
        if len(backward_times) < num_stages:
            # Extend with repeats of the last value
            backward_times = list(backward_times) + [backward_times[-1]] * (
                num_stages - len(backward_times)
            )
        else:
            # Truncate
            backward_times = backward_times[:num_stages]
        print(f"Adjusted backward_times: {backward_times}")

    print(f"Running with parameters:")
    print(f"  num_stages: {num_stages}")
    print(f"  num_batches: {num_batches}")
    print(f"  forward_times: {forward_times}")
    print(f"  backward_times: {backward_times}")
    print(f"  output_file: {output_file}")

    # Create 1F1B schedule
    schedule = create_1f1b_schedule(
        num_stages=num_stages,
        num_batches=num_batches,
        forward_times=forward_times,
        backward_times=backward_times,
        p2p_time=p2p_time,
    )

    # Create visualization unless --no-visualization is specified
    if not args.no_visualization:
        visualize_pipeline_parallelism(
            schedule=schedule, schedule_type="1f1b", output_file=output_file
        )

    # Analyze the schedule
    bubble_rate = get_bubble_rate(schedule)
    print(f"Bubble rate: {bubble_rate:.4f}")

    return {
        "schedule": schedule,
        "bubble_rate": bubble_rate,
        "num_stages": num_stages,
        "num_batches": num_batches,
    }


if __name__ == "__main__":
    main()