import numpy as np
import torch
from monai.networks.nets import ViT

from .modeling import (
    MaskDecoder,
    PromptEncoder,
    Sam,
    TwoWayTransformer,
)

""" |
|
Examples:: |
|
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. |
|
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) |
|
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. |
|
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) |
|
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. |
|
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) |
|
""" |
|
|
|
def build_sam_vit_3d(checkpoint=None):
    print('build_sam_vit_3d...')
    return _build_sam(
        image_encoder_type='vit',
        embed_dim=768,
        patch_size=[4, 16, 16],
        checkpoint=checkpoint,
        image_size=[32, 256, 256],
    )


sam_model_registry = {
    "vit": build_sam_vit_3d,
}
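
# Usage sketch: callers select the image encoder by name through the registry;
# "vit" is the only encoder registered in this module.
#
#     build_fn = sam_model_registry["vit"]
#     sam = build_fn(checkpoint=None)
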
def _build_sam(
    image_encoder_type,
    embed_dim,
    patch_size,
    checkpoint,
    image_size,
):
    # ViT-Base hyperparameters for the image encoder.
    mlp_dim = 3072
    num_layers = 12
    num_heads = 12
    pos_embed = 'perceptron'  # NOTE: newer MONAI releases rename this ViT argument to `proj_type`.
    dropout_rate = 0.0

    image_encoder = ViT(
        in_channels=1,
        img_size=image_size,
        patch_size=patch_size,
        hidden_size=embed_dim,
        mlp_dim=mlp_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        pos_embed=pos_embed,
        classification=False,
        dropout_rate=dropout_rate,
    )
    # Token-grid size of the encoder output, e.g. [32, 256, 256] / [4, 16, 16] -> [8, 16, 16].
    image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]

    if checkpoint is not None:
        # Keep only the image-encoder weights; the checkpoint's state_dict is
        # expected to prefix them with 'model.encoder.'.
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location='cpu')['state_dict']
        encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
        image_encoder.load_state_dict(encoder_dict)
        print(f'===> image_encoder.load_param: {checkpoint}')
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=embed_dim,
            image_embedding_size=image_embedding_size,
            input_image_size=image_size,
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            image_encoder_type=image_encoder_type,
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            image_size=np.array(image_size),
            patch_size=np.array(patch_size),
        ),
        # SAM's default normalization constants (ImageNet mean/std scaled to [0, 255]).
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    return sam
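

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original API). Assumes a
    # MONAI version that still accepts ViT's `pos_embed` argument and that this
    # file is executed as part of its package (e.g. via `python -m ...`) so the
    # relative `.modeling` import resolves. Builds the model with random
    # weights and reports its size; no forward pass is attempted.
    sam = build_sam_vit_3d(checkpoint=None)
    n_params = sum(p.numel() for p in sam.parameters())
    print(f"3D SAM built: {n_params / 1e6:.1f}M parameters, eval mode: {not sam.training}")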