Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Distributed Training Utilities | |
This file contains utility functions for distributed training with PyTorch. | |
It provides tools for setting up distributed environments, efficient file handling | |
across processes, model unwrapping, and synchronization mechanisms to coordinate | |
execution across multiple GPUs and nodes. | |
""" | |
import os | |
import io | |
from contextlib import contextmanager | |
import torch | |
import torch.distributed as dist | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Set up the distributed training environment.

    Exports the standard rendezvous variables (MASTER_ADDR, MASTER_PORT,
    WORLD_SIZE, RANK, LOCAL_RANK) to ``os.environ``, binds this process to CUDA
    device ``local_rank``, and initializes the default process group with the
    NCCL backend.

    Args:
        rank (int): Global rank of the current process.
        local_rank (int): Local rank of the current process on this node; also
            used as the CUDA device index for this process.
        world_size (int): Total number of processes in the distributed job.
        master_addr (str): Address of the master (rank 0) node.
        master_port (str | int): Port on the master node for communication.
            Integers are accepted and converted to ``str``.
    """
    # os.environ only stores str values; coerce so callers may pass e.g. an int port
    # without hitting a TypeError.
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    # Select this process's GPU before NCCL initialization.
    torch.cuda.set_device(local_rank)
    # Initialize the default process group for distributed communication.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
def read_file_dist(path):
    """
    Read a binary file once and share its contents across all ranks.

    When a process group with more than one rank is active, only rank 0 opens
    and reads ``path``. Its byte length and then its bytes are broadcast from
    rank 0 to every other rank as CUDA tensors, so every rank must call this
    function collectively. Otherwise the file is simply read locally.

    Args:
        path (str): Path to the file to be read.

    Returns:
        io.BytesIO: An in-memory buffer holding the file's bytes, positioned at 0.

    Note:
        If rank 0 fails to open the file, the other ranks remain blocked in the
        first broadcast.
    """
    multi_process = dist.is_initialized() and dist.get_world_size() > 1
    if not multi_process:
        # Single-process (or uninitialized) case: plain local read.
        with open(path, 'rb') as handle:
            return io.BytesIO(handle.read())

    is_master = dist.get_rank() == 0
    # One-element holder for the byte count; its contents are set on rank 0 and
    # overwritten by the broadcast on every other rank.
    num_bytes = torch.LongTensor(1).cuda()
    if is_master:
        with open(path, 'rb') as handle:
            raw = handle.read()
        payload = torch.ByteTensor(
            torch.UntypedStorage.from_buffer(raw, dtype=torch.uint8)
        ).cuda()
        num_bytes[0] = payload.shape[0]

    # First share the length so receivers can allocate a matching buffer.
    dist.broadcast(num_bytes, src=0)
    if not is_master:
        payload = torch.ByteTensor(num_bytes[0].item()).cuda()

    # Then share the bytes themselves.
    dist.broadcast(payload, src=0)
    return io.BytesIO(payload.cpu().numpy().tobytes())
def unwrap_dist(model):
    """
    Strip a DistributedDataParallel wrapper from a model, if present.

    Args:
        model: A model that may or may not be wrapped in ``DistributedDataParallel``.

    Returns:
        ``model.module`` when ``model`` is a DDP instance; otherwise ``model`` unchanged.
    """
    is_wrapped = isinstance(model, DDP)
    return model.module if is_wrapped else model
@contextmanager
def master_first():
    """
    Context manager that lets global rank 0 run the enclosed code before all other ranks.

    On rank 0 the ``with`` body runs first, followed by ``dist.barrier()``. Every
    other rank calls ``dist.barrier()`` first, so it runs the body only after rank 0
    has finished it and reached the barrier. Without an initialized process group,
    the body simply runs.

    Usage:
        with master_first():
            ...  # code that should execute on rank 0 first, then on the others

    Note:
        If the body raises on rank 0, rank 0 never reaches the barrier and the
        other ranks remain blocked in ``dist.barrier()``.
    """
    if not dist.is_initialized():
        yield
    elif dist.get_rank() == 0:
        yield
        # Release the waiting ranks only once rank 0 has finished the body.
        dist.barrier()
    else:
        # Wait for rank 0 to finish before running the body.
        dist.barrier()
        yield
def _is_local_master():
    """Return True if this process is the first (local-rank 0) process on its node."""
    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is not None:
        # LOCAL_RANK is exported by setup_dist() and by torchrun; it is correct
        # regardless of how many processes each node runs.
        return int(local_rank) == 0
    # Fallback when LOCAL_RANK is unavailable: assumes every node runs exactly one
    # process per visible CUDA device.
    return dist.get_rank() % torch.cuda.device_count() == 0


@contextmanager
def local_master_first():
    """
    Context manager that lets each node's local master run the enclosed code first.

    The local master runs the ``with`` body and then calls ``dist.barrier()``.
    Every other process calls ``dist.barrier()`` first and runs the body afterwards.
    The local master is local rank 0, taken from the ``LOCAL_RANK`` env var when
    set, else ``global_rank % cuda_device_count``. The barrier is on the default
    (world) group, so non-masters wait until every node's local master has finished.
    Without an initialized process group, the body simply runs.

    Usage:
        with local_master_first():
            ...  # code that should execute on the local master first, then the others
    """
    if not dist.is_initialized():
        yield
    elif _is_local_master():
        yield
        # Release the waiting processes only once the local master is done.
        dist.barrier()
    else:
        # Wait for the local master(s) to finish before running the body.
        dist.barrier()
        yield