from typing import Optional

import timm
import torch
import torch.nn as nn
from accelerate.logging import get_logger

logger = get_logger(__name__)
|
|
class XUNet(nn.Module):
    def __init__(self, model_name="swin_base_patch4_window12_384_in22k", encoder_feat_dim=384):
        super().__init__()

        self.encoder = timm.create_model(model_name, pretrained=True)
        # Drop the classifier; only intermediate features are used. timm's
        # reset_classifier(0, "") removes both the pooling and the head.
        self.encoder.reset_classifier(0, "")

        # Decoder widths mirror the stage widths of the default Swin-B
        # encoder (128, 256, 512, 1024); adjust them if model_name changes.
        self.upconv4 = self.upconv_block(1024, 512)
        self.upconv3 = self.upconv_block(512, 256)
        self.upconv2 = self.upconv_block(256, 128)

        # Project the finest decoder features to the requested token width.
        self.out_conv = nn.Conv2d(128, encoder_feat_dim, kernel_size=1)

    def upconv_block(self, in_channels, out_channels):
        # 2x spatial upsampling via a strided transposed convolution.
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Per-stage feature maps in NCHW format, shallowest to deepest; Swin
        # has four stages, so the returned list is indexed 0..3.
        enc_output = self.encoder.forward_intermediates(x, stop_early=True, intermediates_only=True)

        enc_out1 = enc_output[0]
        enc_out2 = enc_output[1]
        enc_out3 = enc_output[2]
        enc_out4 = enc_output[3]

        # Decode from the deepest stage, adding the matching encoder stage
        # as a skip connection at each resolution.
        x = self.upconv4(enc_out4)
        x = x + enc_out3
        x = self.upconv3(x)
        x = x + enc_out2
        x = self.upconv2(x)
        x = x + enc_out1

        x = self.out_conv(x)
        return x
|
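# Shape sketch for XUNet (an illustration, assuming the default
# swin_base_patch4_window12_384_in22k encoder with a 384x384 input;
# patch size 4, stage widths 128/256/512/1024):
#   enc_out1: (B, 128, 96, 96)    enc_out2: (B, 256, 48, 48)
#   enc_out3: (B, 512, 24, 24)    enc_out4: (B, 1024, 12, 12)
#   upconv4:  (B, 512, 24, 24)    upconv3:  (B, 256, 48, 48)
#   upconv2:  (B, 128, 96, 96)    out_conv: (B, encoder_feat_dim, 96, 96)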
|
class XnetWrapper(nn.Module):
    """
    Wraps the original XUNet implementation, hacked with modulation: a
    `mod` tensor is accepted but not yet consumed in the forward pass.
    """

    def __init__(self, model_name: str, modulation_dim: Optional[int] = None, freeze: bool = True, encoder_feat_dim: int = 384):
        super().__init__()
        self.modulation_dim = modulation_dim
        self.model = XUNet(model_name=model_name, encoder_feat_dim=encoder_feat_dim)

        if freeze:
            if modulation_dim is not None:
                raise ValueError("Modulated XnetWrapper requires training; freezing is not allowed.")
            self._freeze()

    def _freeze(self):
        logger.warning("======== Freezing XnetWrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @torch.compile
    def forward(self, image: torch.Tensor, mod: Optional[torch.Tensor] = None):
        # `mod` is kept for interface compatibility; modulation is not applied here.
        outs = self.model(image)
        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C): one token per spatial position.
        ret = outs.permute(0, 2, 3, 1).flatten(1, 2)
        return ret
|
|
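if __name__ == "__main__":
    # Minimal smoke test, a sketch: assumes a 384x384 RGB input to match the
    # default model name, and downloads pretrained Swin-B weights on first run.
    wrapper = XnetWrapper(model_name="swin_base_patch4_window12_384_in22k")
    with torch.no_grad():
        tokens = wrapper(torch.randn(1, 3, 384, 384))
    # The 96x96 output grid flattens to 9216 tokens of width encoder_feat_dim=384.
    print(tokens.shape)  # torch.Size([1, 9216, 384])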