|
""" Distributed training/validation utils |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import torch |
|
from torch import distributed as dist |
|
|
|
from .model import unwrap_model |
|
|
|
|
|
def reduce_tensor(tensor, n):
    """Sum a tensor across all distributed processes and average it.

    Args:
        tensor: tensor to reduce (left unmodified; a clone is reduced).
        n: divisor for the averaged result, typically the world size.

    Returns:
        A new tensor holding the cross-process sum divided by ``n``.
    """
    # Work on a copy so the caller's tensor is not mutated by the in-place all-reduce.
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    # In-place divide; div_ returns the same tensor.
    return reduced.div_(n)
|
|
|
|
|
def distribute_bn(model, world_size, reduce=False):
    """Synchronize BatchNorm running statistics across distributed processes.

    Iterates the model's buffers and, for each BN ``running_mean`` /
    ``running_var`` buffer, either averages it across all ranks or copies
    rank 0's value to every other rank.

    Args:
        model: model (possibly wrapped, e.g. DDP/EMA) whose BN stats to sync.
        world_size: number of distributed processes, used as the divisor
            when averaging.
        reduce: if True, all-reduce (sum then divide) the stats so every
            rank ends with the mean; if False, broadcast rank 0's stats.
    """
    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
                # use the module-level `dist` alias for consistency with reduce_tensor
                dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
                dist.broadcast(bn_buf, 0)
|
|