from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup


# Compatibility shim: older PyTorch releases only provide these collectives under the
# private names `_all_gather_base` / `_reduce_scatter_base`.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
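

# Illustrative sketch, not part of the original module: the raw helpers return the output
# tensor together with the communication handle, so with `async_op=True` the collective can
# be overlapped with independent computation before waiting on the handle. The function and
# tensor names below are hypothetical examples.
def _example_overlapped_all_gather(x: Tensor, weight: Tensor, process_group: ProcessGroup) -> Tensor:
    gathered, handle = all_gather_raw(x, process_group, async_op=True)
    local = torch.relu(x)  # independent work that can run while the collective is in flight
    handle.wait()  # block until `gathered` is fully written
    return gathered @ weight + local.sum()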


def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


reduce_scatter = ReduceScatterFunc.apply
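

# Illustrative sketch, not part of the original module: in sequence parallelism each rank holds
# a slice of the sequence. `all_gather` reassembles the full sequence before an operation that
# needs it, and `reduce_scatter` both sums partial results and returns to the sliced layout;
# the two wrappers are autograd-aware, each one's backward being the other collective.
# The function and weight names below are hypothetical examples.
def _example_sequence_parallel_mlp(
    x_shard: Tensor, w_col: Tensor, w_row: Tensor, process_group: ProcessGroup
) -> Tensor:
    # x_shard: (seqlen // world_size, d)  sequence shard held by this rank
    # w_col:   (d, ffn_local)             column-parallel shard of the first weight
    # w_row:   (ffn_local, d)             row-parallel shard of the second weight
    x_full = all_gather(x_shard, process_group)  # (seqlen, d); backward is a reduce-scatter
    h = torch.relu(x_full @ w_col)  # (seqlen, ffn_local), computed with the local weight shard
    partial = h @ w_row  # partial sum over the sharded hidden dimension
    return reduce_scatter(partial, process_group)  # (seqlen // world_size, d), summed across ranks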


class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


all_reduce = AllReduceFunc.apply
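

# Illustrative sketch, not part of the original module: `all_reduce` is the non-sequence-parallel
# counterpart of `reduce_scatter` above; partial results are summed and every rank keeps the full
# tensor, and its backward simply passes the gradient through. Names below are hypothetical.
def _example_row_parallel_linear(h_local: Tensor, w_row: Tensor, process_group: ProcessGroup) -> Tensor:
    partial = h_local @ w_row  # each rank contributes a partial sum over its hidden shard
    return all_reduce(partial, process_group)  # summed result, replicated on every rank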


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # Iterate over parameters tagged with `_shared_params` in a deterministic (sorted) order
    # so that every rank issues the broadcasts in the same sequence.
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # `broadcast` takes the global rank of the source, not its rank within the group.
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )
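

# Illustrative usage (an assumption, not from the original module): parameters tagged with
# `_shared_params` are meant to stay replicated across the group, so this is typically called
# once after the model is constructed or a checkpoint is loaded, e.g.
#
#     sync_shared_params(model, tensor_parallel_group)  # `tensor_parallel_group` is hypothetical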


def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # Sum the gradients of parameters tagged with `_sequence_parallel` across the process group,
    # flattening them first so that a single all-reduce is issued.
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
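

# Illustrative sketch, not part of the original module: parameters tagged `_sequence_parallel`
# only see a fraction of the sequence on each rank, so their gradients must be summed across the
# group after backward and before the optimizer step. The step function below is a hypothetical
# example of where that call fits.
def _example_training_step(
    model: torch.nn.Module,
    loss: Tensor,
    optimizer: torch.optim.Optimizer,
    process_group: ProcessGroup,
) -> None:
    loss.backward()
    allreduce_sequence_parallel_grad(model, process_group)  # sync sequence-parallel grads
    optimizer.step()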


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
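

# Illustrative worked example, not part of the original module: splitting dim=96 over 4 ranks in
# chunks that are multiples of 16 gives 6 chunks, and 6 % 4 = 2, so the first two ranks each
# receive one extra chunk; the per-rank shares sum back to the original dimension.
def _example_check_uneven_split() -> None:
    shares = [get_dim_for_local_rank(96, 4, rank, multiple_of=16) for rank in range(4)]
    assert shares == [32, 32, 16, 16]
    assert sum(shares) == 96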