|
import math |
|
import torch |
|
import torch.distributed as dist |
|
|
|
from detectron2.modeling.roi_heads import FastRCNNConvFCHead, MaskRCNNConvUpsampleHead |
|
from detectron2.utils import comm |
|
from fvcore.nn.distributed import differentiable_all_gather |
|
|
|
|
|
def concat_all_gather(input):
    """Gather a tensor from all workers when leading sizes may differ.

    The local tensor is zero-padded along dim 0 up to the largest leading
    size reported by any worker, gathered with ``differentiable_all_gather``,
    and each gathered piece is then cut back to the size its worker reported.

    Args:
        input: tensor to gather; only its first dimension may vary between
            workers (the padded shape reuses ``input.shape[1:]``).

    Returns:
        A ``(tensors, sizes)`` pair: the list of trimmed gathered tensors and
        the list of leading sizes returned by ``comm.all_gather``.
    """
    local_size = input.shape[0]
    sizes = comm.all_gather(local_size)

    # Pad to a common leading size so every worker contributes a tensor of
    # the same shape to the gather.
    padded = input.new_zeros((max(sizes), *input.shape[1:]))
    padded[:local_size] = input

    gathered = differentiable_all_gather(padded)

    # Drop the padding rows again, using the size each worker reported.
    trimmed = [chunk[:size] for chunk, size in zip(gathered, sizes)]
    return trimmed, sizes
|
|
|
|
|
def batch_shuffle(x):
    """Shuffle samples across all workers and return this worker's share.

    Every worker's ``x`` is gathered and concatenated along dim 0, a single
    random permutation (the one drawn on rank 0) is applied, and the permuted
    batch is cut into chunks of ``ceil(total / world_size)`` samples. Worker
    ``r`` receives chunk ``r``; trailing workers may receive fewer samples,
    or none at all.

    Works without an initialised process group as well: rank and world size
    are read through ``detectron2.utils.comm`` and the broadcast is skipped
    when there is only one worker.

    Args:
        x: tensor whose first dimension is the local batch.

    Returns:
        A ``(shuffled, idx_unshuffle)`` pair. ``shuffled`` holds the samples
        assigned to this worker. ``idx_unshuffle`` indexes into the
        rank-ordered concatenation of all workers' shuffled batches and
        selects this worker's original samples in their original order; pass
        it to ``batch_unshuffle``.

    Raises:
        AssertionError: if the gathered size recorded for this rank differs
            from ``x.shape[0]``.
    """
    local_bs = x.shape[0]
    all_xs, batch_size_all = concat_all_gather(x)
    all_xs_concat = torch.cat(all_xs, dim=0)
    total_bs = sum(batch_size_all)

    # Use the comm helpers rather than torch.distributed directly so that a
    # run without an initialised process group does not fail here.
    rank = comm.get_rank()
    world_size = comm.get_world_size()
    assert batch_size_all[rank] == local_bs

    # Rows [start, end) of the concatenated batch belong to this worker.
    start = sum(batch_size_all[:rank])
    end = start + local_bs

    # Every worker draws a permutation, but only rank 0's is kept so that all
    # workers agree on the same shuffle.
    idx_shuffle = torch.randperm(total_bs, device=x.device)
    if world_size > 1:
        dist.broadcast(idx_shuffle, src=0)

    # Inverse permutation, used later to restore the original order.
    idx_unshuffle = torch.argsort(idx_shuffle)

    # Chunk the permutation; when total_bs does not divide evenly there can
    # be fewer chunks than workers, in which case this worker gets no samples.
    splits = torch.split(idx_shuffle, math.ceil(total_bs / world_size))
    if rank < len(splits):
        idx_this = splits[rank]
    else:
        idx_this = idx_shuffle.new_zeros([0])
    return all_xs_concat[idx_this], idx_unshuffle[start:end]
|
|
|
|
|
def batch_unshuffle(x, idx_unshuffle):
    """Undo ``batch_shuffle`` for this worker.

    Gathers ``x`` from all workers, concatenates the pieces along dim 0 and
    picks the rows listed in ``idx_unshuffle``.

    Args:
        x: tensor computed on the shuffled samples held by this worker.
        idx_unshuffle: index tensor returned by ``batch_shuffle``.

    Returns:
        The rows of the gathered tensor selected by ``idx_unshuffle``.
    """
    gathered_chunks = concat_all_gather(x)[0]
    return torch.cat(gathered_chunks, dim=0)[idx_unshuffle]
|
|
|
|
|
def wrap_shuffle(module_type, method):
    """Build a subclass of ``module_type`` that shuffles around ``method``.

    In the returned class, ``method`` is replaced by a wrapper. While
    ``self.training`` is true the wrapper passes its input through
    ``batch_shuffle``, calls the original ``module_type`` method on the
    result, and restores the order with ``batch_unshuffle``. Otherwise it
    simply delegates to the original method.

    Args:
        module_type: class to derive from.
        method: name of the single-argument method to wrap.

    Returns:
        A new class named ``<module_type.__name__>WithShuffle``.
    """

    def new_method(self, x):
        # Outside of training the wrapper is a plain pass-through.
        if not self.training:
            return getattr(module_type, method)(self, x)

        shuffled, restore_idx = batch_shuffle(x)
        # Look the method up on module_type (not on self) so the wrapper
        # calls the original implementation rather than itself.
        result = getattr(module_type, method)(self, shuffled)
        return batch_unshuffle(result, restore_idx)

    shuffled_cls_name = module_type.__name__ + "WithShuffle"
    return type(shuffled_cls_name, (module_type,), {method: new_method})
|
|
|
|
|
from .mask_rcnn_BNhead import model, dataloader, lr_multiplier, optimizer, train |
|
|
|
|
|
# Point the box head at a FastRCNNConvFCHead subclass whose `forward` is
# wrapped with the shuffle/unshuffle logic.
model.roi_heads.box_head._target_ = wrap_shuffle(
    module_type=FastRCNNConvFCHead, method="forward"
)

# Same for the mask head, except that the wrapped method is `layers`.
model.roi_heads.mask_head._target_ = wrap_shuffle(
    module_type=MaskRCNNConvUpsampleHead, method="layers"
)
|
|