# Source path: NSAQA/detectron2/projects/MViTv2/configs/mask_rcnn_mvitv2_t_3x.py
# Hosting-page metadata: uploaded by laurenok24 ("Upload 715 files"), revision 07d7c23 (verified), 1.71 kB.
from functools import partial
import torch.nn as nn
from fvcore.common.param_scheduler import MultiStepParamScheduler
from detectron2 import model_zoo
from detectron2.config import LazyCall as L
from detectron2.solver import WarmupParamScheduler
from detectron2.modeling import MViT
from .common.coco_loader import dataloader
# Start from the shared Mask R-CNN + FPN model config and the shared data constants.
model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
constants = model_zoo.get_config("common/data/constants.py").constants

# Use the ImageNet RGB mean/std constants and declare RGB as the input format.
model.pixel_mean = constants.imagenet_rgb256_mean
model.pixel_std = constants.imagenet_rgb256_std
model.input_format = "RGB"

# Swap the backbone's bottom-up network for a lazily constructed MViT.
model.backbone.bottom_up = L(MViT)(
    embed_dim=96,
    depth=10,
    num_heads=1,
    last_block_indexes=(0, 2, 7, 9),
    residual_pooling=True,
    drop_path_rate=0.2,
    # LayerNorm with eps fixed to 1e-6; passed as a factory, not an instance.
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    out_features=("scale2", "scale3", "scale4", "scale5"),
)

# Interpolation string pointing in_features at the out_features declared on
# bottom_up above, so the two cannot drift apart.
model.backbone.in_features = "${.bottom_up.out_features}"
# Initialization and trainer settings
train = model_zoo.get_config("common/train.py").train

# Initialize from the MViTv2-T checkpoint named in the path below.
train.init_checkpoint = "detectron2://ImageNetPretrained/mvitv2/MViTv2_T_in1k.pyth"

# Turn on the AMP flag and the DDP fp16-compression flag.
train.ddp.fp16_compression = True
train.amp.enabled = True

# Total training batch size for the imported COCO dataloader config.
dataloader.train.total_batch_size = 64
# 36 epochs
train.max_iter = 67500

# Learning-rate multiplier: a multi-step schedule wrapped in a warmup schedule.
# NOTE: train.max_iter must be assigned before this point, because
# warmup_length is computed from it when this file is evaluated.
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(MultiStepParamScheduler)(
        # Multiplier values 1.0 -> 0.1 -> 0.01; the last milestone equals
        # train.max_iter.
        values=[1.0, 0.1, 0.01],
        milestones=[52500, 62500, 67500],
    ),
    # Warmup length expressed as a fraction of training: 250 / max_iter.
    warmup_length=250 / train.max_iter,
    warmup_factor=0.001,
)
# Take the AdamW entry from the shared optimizer config and set its learning rate.
optimizer = model_zoo.get_config("common/optim.py").AdamW
optimizer.lr = 1.6e-4

# Set weight decay to 0.0 for the pos_embed, rel_pos_h, and rel_pos_w override
# keys. Each key gets its own inner dict.
optimizer.params.overrides = {
    key: {"weight_decay": 0.0} for key in ("pos_embed", "rel_pos_h", "rel_pos_w")
}