Spaces:
Runtime error
Runtime error
| # Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models | |
| import re | |
| import math | |
| import functools | |
| import deepinv as dinv | |
| from deepinv.utils import plot, TensorList | |
| import torch | |
| from torch.func import vmap | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from deepinv.optim.utils import conjugate_gradient | |
| from physics.multiscale import MultiScaleLinearPhysics, Pad | |
| from models.blocks import EquivMaxPool, AffineConv2d, ConvNextBlock2, NoiseEmbedding, MPConv, TimestepEmbedding, conv, downsample_strideconv, upsample_convtranspose | |
| from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads | |
# Device-dependent globals: `cuda` is True when torch reports an available CUDA
# device, and `Tensor` is the matching legacy float tensor type.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
| ### --------------- MODEL --------------- | |
class BaseEncBlock(nn.Module):
    """Sequential stack of ``nb`` ``ResBlock`` layers.

    Every constructor argument other than ``nb`` is handed unchanged to each
    ``ResBlock``; see that class for their meaning.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        mode="CRC",
        nb=2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        self.config = config
        shared = dict(
            bias=bias,
            mode=mode,
            embedding=embedding,
            emb_channels=emb_channels,
            emb_physics=emb_physics,
            img_channels=img_channels,
            decode_upscale=decode_upscale,
            config=config,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        )
        # Layers are built one after the other, in order, inside the ModuleList.
        self.enc = nn.ModuleList(
            ResBlock(in_channels, out_channels, **shared) for _ in range(nb)
        )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Feed ``x`` through every layer in turn and return the final features.

        ``emb_in`` is accepted for call-site compatibility but is not forwarded.
        """
        for layer in self.enc:
            x = layer(
                x,
                emb_sigma=emb_sigma,
                physics=physics,
                t=t,
                y=y,
                img_channels=img_channels,
                scale=scale,
            )
        return x
class NextEncBlock(nn.Module):
    """Sequential stack of ``nb`` ``ConvNextBlock2`` layers.

    ``forward`` accepts the same keyword arguments that ``UNeXt.forward_unet``
    passes to every stage (``physics``, ``t``, ``y``, ``emb_in``,
    ``img_channels``, ``scale``). Only ``emb_sigma`` is used; the others are
    ignored. This lets the block be used wherever ``BaseEncBlock`` is.
    """

    def __init__(
        self, in_channels, out_channels, bias=False, mode="", mult_fact=4, nb=2
    ):
        super().__init__()
        self.enc = nn.ModuleList(
            [
                ConvNextBlock2(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    bias=bias,
                    mode=mode,
                    mult_fact=mult_fact,
                )
                for _ in range(nb)
            ]
        )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Apply each ConvNeXt layer in turn, passing it ``emb_sigma``.

        :param x: input features.
        :param emb_sigma: conditioning forwarded positionally to each layer.
        The remaining parameters exist only for call-site compatibility.
        """
        for block in self.enc:
            x = block(x, emb_sigma)
        return x
class UNeXt(nn.Module):
    r"""
    U-Net style restoration network derived from DRUNet.

    The architecture follows
    `Learning deep CNN denoiser prior for image restoration <https://arxiv.org/abs/1704.03264>`_.
    It has a head, three encoder stages separated by downsampling, a bottleneck,
    three decoder stages separated by upsampling, and a tail. Skip connections
    between matching scales are additive.

    With ``cond_type="base"`` the noise level ``sigma`` and ``gamma`` are appended
    to the input as two constant channels. With ``cond_type="edm"`` ``sigma`` is
    instead fed through a ``NoiseEmbedding``.

    :param int, list in_channels: number of image channels. When a list is given,
        the head/tail are selected by the input channel count, and ``forward``
        rejects inputs whose channel count is not in the list.
    :param out_channels: unused (the tail is built from ``in_channels``).
    :param list nc: channel widths of the four scales.
    :param int nb: number of blocks per stage.
    :param str conv_type: ``"next"`` (``NextEncBlock`` stages) or ``"base"`` (``BaseEncBlock`` stages).
    :param str pool_type: ``"next_max"`` (``EquivMaxPool``) or ``"base"``
        (strided convolution down / transposed convolution up).
    :param str cond_type: ``"base"`` or ``"edm"``.
    :param device: if not ``None``, the module is moved there after construction.
    :param bool bias: bias flag for ``"next"`` stages.
    :param str mode: mode string for ``"next"`` stages.
    :param residual: stored as ``self.residual``.
    :param init_type: init scheme passed to ``weights_init_unext`` (``"next"`` stages only).
    :param float gain_init_conv: conv gain for ``weights_init_unext`` (``"next"`` stages only).
    :param float gain_init_linear: linear gain for ``weights_init_unext`` (``"next"`` stages only).
    :param int mult_fact: expansion factor of the ``"next"`` stages.
    :param str antialias: antialiasing option of ``EquivMaxPool``.
    :param bool emb_physics: wrap ``physics`` in ``MultiScaleLinearPhysics`` in ``forward``.
    :param str config: variant selector forwarded to the stages. ``'D'`` freezes
        every parameter except physics blocks, gains, ``fact_realign``, head and tail.
    :param pretrained_pth: checkpoint path loaded with ``load_drunet_weights``.
        ``'jz'`` selects a hard-coded absolute checkpoint path.
    :param N: forwarded to ``BaseEncBlock`` stages.
    :param c_mult: forwarded to ``BaseEncBlock`` stages.
    :param depth_encoding: forwarded to ``BaseEncBlock`` stages.
    :param relu_in_encoding: forwarded to ``BaseEncBlock`` stages.
    :param skip_in_encoding: forwarded to ``BaseEncBlock`` stages.
    :param act_mode: accepted for backward compatibility; unused.
    :param layer_scale_init_value: accepted for backward compatibility; unused.
    :param drop_prob: accepted for backward compatibility; unused.
    :param replk: accepted for backward compatibility; unused.
    """

    def __init__(
        self,
        in_channels=[1, 2, 3],
        out_channels=[1, 2, 3],
        nc=[64, 128, 256, 512],
        nb=4,  # 4 in DRUNet but out of memory
        conv_type="next",  # should be 'base' or 'next'
        pool_type="next",  # accepted values: 'base' or 'next_max'
        cond_type="base",  # conditioning, should be 'base' or 'edm'
        device=None,
        bias=False,
        mode="",
        residual=False,
        act_mode="R",
        layer_scale_init_value=1e-6,
        init_type="ortho",
        gain_init_conv=1.0,
        gain_init_linear=1.0,
        drop_prob=0.0,
        replk=False,
        mult_fact=4,
        antialias="gaussian",
        emb_physics=False,
        config='A',
        pretrained_pth=None,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        if isinstance(in_channels, list):
            # Copy so that instances never share the mutable default list.
            in_channels = list(in_channels)
        self.residual = residual
        self.conv_type = conv_type
        self.pool_type = pool_type
        self.emb_physics = emb_physics
        self.config = config
        self.in_channels = in_channels
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
        self.separate_head = isinstance(in_channels, list)
        assert cond_type in ["base", "edm"], "cond_type should be 'base' or 'edm'"
        self.cond_type = cond_type

        # Channel count seen by the head after conditioning.
        if cond_type == "base" and config != 'E':
            if isinstance(in_channels, list):
                # base_conditioning appends two maps (sigma and gamma).
                in_channels_first = [c + 2 for c in in_channels]
            else:  # old head
                # NOTE(review): base_conditioning appends two maps, but only one
                # extra channel is reserved here; confirm against InHead.
                in_channels_first = in_channels + 1
        else:
            in_channels_first = in_channels
        if cond_type != "base":
            self.noise_embedding = NoiseEmbedding(
                num_channels=in_channels, emb_channels=max(nc), device=device
            )
        self.timestep_embedding = lambda x: x
        self.m_head = InHead(in_channels_first, nc[0])

        # (attribute, width, decode_upscale), in the original registration order.
        # decode_upscale follows the pooling steps before each stage: 1, 2, 4, 8.
        stages = (
            ("m_down1", nc[0], 1),
            ("m_down2", nc[1], 2),
            ("m_down3", nc[2], 4),
            ("m_body", nc[3], 8),
            ("m_up3", nc[2], 4),
            ("m_up2", nc[1], 2),
            ("m_up1", nc[0], 1),
        )
        if conv_type == "next":
            for name, width, _ in stages:
                setattr(
                    self,
                    name,
                    NextEncBlock(width, width, bias=bias, mode=mode, mult_fact=mult_fact, nb=nb),
                )
        elif conv_type == "base":
            embedding = cond_type != "base"
            emb_channels = max(nc)
            for name, width, upscale in stages:
                setattr(
                    self,
                    name,
                    BaseEncBlock(
                        width,
                        width,
                        bias=False,
                        mode="CRC",
                        nb=nb,
                        embedding=embedding,
                        emb_channels=emb_channels,
                        emb_physics=emb_physics,
                        img_channels=in_channels,
                        decode_upscale=upscale,
                        config=config,
                        N=N,
                        c_mult=c_mult,
                        depth_encoding=depth_encoding,
                        relu_in_encoding=relu_in_encoding,
                        skip_in_encoding=skip_in_encoding,
                    ),
                )
        else:
            raise NotImplementedError("conv_type should be 'base' or 'next'")

        if pool_type == "next_max":
            for k in (1, 2, 3):
                setattr(
                    self,
                    f"pool{k}",
                    EquivMaxPool(
                        antialias=antialias,
                        in_channels=nc[k - 1],
                        out_channels=nc[k],
                        device=device,
                    ),
                )
        elif pool_type == "base":
            for k in (1, 2, 3):
                setattr(self, f"pool{k}", downsample_strideconv(nc[k - 1], nc[k], bias=False, mode="2"))
            for k in (3, 2, 1):
                setattr(self, f"up{k}", upsample_convtranspose(nc[k], nc[k - 1], bias=False, mode="2"))
        else:
            raise NotImplementedError("pool_type should be 'base' or 'next_max'")

        self.m_tail = OutTail(nc[0], in_channels)

        if conv_type == "base":
            # Base stages always use orthogonal init with a small conv gain;
            # init_type / gain_init_* are only honoured for "next" stages.
            init_func = functools.partial(
                weights_init_unext, init_type="ortho", gain_conv=0.2
            )
        else:
            init_func = functools.partial(
                weights_init_unext,
                init_type=init_type,
                gain_conv=gain_init_conv,
                gain_linear=gain_init_linear,
            )
        self.apply(init_func)

        if pretrained_pth == 'jz':
            # Hard-coded absolute checkpoint location.
            pretrained_pth = '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth'
        if pretrained_pth is not None:
            self.load_drunet_weights(pretrained_pth)

        if self.config == 'D':
            # Only train physics blocks, gains, fact_realign, head and tail.
            trainable_tags = ('PhysicsBlock', 'gain', 'fact_realign', 'm_head', 'm_tail')
            for name, param in self.named_parameters():
                if not any(tag in name for tag in trainable_tags):
                    param.requires_grad = False

        if device is not None:
            self.to(device)

    def load_drunet_weights(self, ckpt_pth):
        """Load a DRUNet-style checkpoint into this model.

        Keys that contain neither "head" nor "tail" are renamed with
        ``old2new``; keys it maps to ``None`` are reported as unmatched. Head and
        tail keys are converted with ``update_keyvals_headtail``. The result is
        loaded with ``strict=False``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(ckpt_pth, map_location=lambda storage, loc: storage)
        new_state_dict = {}
        matched_keys = []  # (old_key, new_key) pairs that were renamed
        unmatched_keys = []  # keys old2new could not map
        excluded_keys = []  # head/tail keys handled separately below
        exclude_patterns = ("head", "tail")

        for old_key, value in state_dict.items():
            if any(pattern in old_key for pattern in exclude_patterns):
                excluded_keys.append(old_key)
                continue
            new_key = old2new(old_key)
            if new_key is None:
                unmatched_keys.append(old_key)
            else:
                matched_keys.append((old_key, new_key))
                new_state_dict[new_key] = value

        # Snapshot once: the model is not modified until load_state_dict below.
        current_state = self.state_dict()
        for excluded_key in excluded_keys:
            if isinstance(self.in_channels, list):
                prefix = "m_tail" if "tail" in excluded_key else "m_head"
                for i in range(len(self.in_channels)):
                    new_key = f"{prefix}.conv{i}.weight"
                    new_state_dict.update(
                        update_keyvals_headtail(
                            excluded_key,
                            state_dict[excluded_key],
                            init_value=current_state[new_key],
                            new_key_name=new_key,
                            conditioning='base',
                        )
                    )
            else:
                new_state_dict.update(
                    update_keyvals_headtail(excluded_key, state_dict[excluded_key])
                )

        print("Matched keys:")
        for old_key, new_key in matched_keys:
            print(f"{old_key} -> {new_key}")
        self.load_state_dict(new_state_dict, strict=False)
        print("\nUnmatched keys:")
        for unmatched_key in unmatched_keys:
            print(unmatched_key)
        print("Weights loaded from ", ckpt_pth)

    def constant2map(self, value, x):
        """Return a ``(B, 1, H, W)`` map filled with ``value``, for ``x`` of shape ``(B, C, H, W)``.

        ``value`` may be a Python number, a 0-d tensor, or a tensor with either
        ``B`` elements (one per batch item) or a single element (broadcast to
        the whole batch). Tensor values are moved to ``x.device``.
        """
        batch, _, height, width = x.shape
        if isinstance(value, torch.Tensor):
            value = value.to(x.device)
            if value.ndim > 0:
                return value.reshape(-1, 1, 1, 1).expand(batch, 1, height, width)
            return torch.ones((batch, 1, height, width), device=x.device) * value.reshape(1, 1, 1, 1)
        return torch.ones((batch, 1, height, width), device=x.device) * value

    def base_conditioning(self, x, sigma, gamma):
        """Append constant ``sigma`` and ``gamma`` maps to ``x`` along the channel axis."""
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    def realign_input(self, x, physics, y):
        """Replace ``x`` with ``physics.prox_l2(x, y, gamma)`` using an SNR-based ``gamma``.

        ``physics`` may wrap other operators in ``.base`` (inspected up to two
        levels deep). ``factor`` is taken from the outermost level that defines
        it (default 1.0). ``sigma`` is taken from the innermost level whose
        ``noise_model`` defines it (default 1e-6).
        """
        chain = [physics]
        while len(chain) < 3 and hasattr(chain[-1], "base"):
            chain.append(chain[-1].base)

        f = next((level.factor for level in chain if hasattr(level, "factor")), 1.0)

        sigma = 1e-6  # default value
        for level in chain:
            noise_model = getattr(level, "noise_model", None)
            if hasattr(noise_model, "sigma"):
                sigma = noise_model.sigma

        y_ref = y[0] if isinstance(y, TensorList) else y
        num = y_ref.reshape(y_ref.shape[0], -1).abs().mean(1)
        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        # Broadcast the per-sample gamma against x.
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
        return physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

    def _upsample(self, level, x):
        """Upsample from scale ``level`` to ``level - 1``.

        Uses ``pool{level}.upscale`` for equivariant pools and ``up{level}`` otherwise.
        """
        if self.pool_type in ("next", "next_max"):
            return getattr(self, f"pool{level}").upscale(x)
        return getattr(self, f"up{level}")(x)

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, t=None, y=None, img_channels=None):
        """Run head, encoder, bottleneck, decoder and tail on ``x0``."""
        if self.cond_type == "base":
            x0 = self.base_conditioning(x0, sigma, gamma)
            emb_sigma = None
        else:
            # Learned noise-level embedding (cond_type="edm").
            emb_sigma = self.noise_embedding(sigma)
        emb_timestep = self.timestep_embedding(t)

        common = dict(
            emb_sigma=emb_sigma,
            physics=physics,
            t=emb_timestep,
            y=y,
            img_channels=img_channels,
        )
        returns_emb = self.config == 'G'

        def run(block, feats, scale, **extra):
            # Config 'G' stages return (features, embedding) and receive no `scale`.
            # Other configs return features only and are given their scale index.
            if returns_emb:
                return block(feats, **common, **extra)
            return block(feats, scale=scale, **common), None

        x1 = self.m_head(x0)
        x1_, emb1_ = run(self.m_down1, x1, 0)
        x2 = self.pool1(x1_)
        x3_, emb3_ = run(self.m_down2, x2, 1)
        x3 = self.pool2(x3_)
        x4_, emb4_ = run(self.m_down3, x3, 2)
        x4 = self.pool3(x4_)
        # See https://github.com/matthieutrs/ram_project/issues/1 (x4.contiguous()
        # was noted there as a workaround).

        x, _ = run(self.m_body, x4, 3)
        x = self._upsample(3, x + x4)
        x, _ = run(self.m_up3, x, 2, emb_in=emb4_)
        x = self._upsample(2, x + x3)
        x, _ = run(self.m_up2, x, 1, emb_in=emb3_)
        x = self._upsample(1, x + x2)
        x, _ = run(self.m_up1, x, 0, emb_in=emb1_)

        if self.separate_head:
            return self.m_tail(x + x1, img_channels)
        return self.m_tail(x + x1)

    def forward(self, x, sigma=None, gamma=None, physics=None, t=None, y=None):
        r"""
        Run the network on image ``x`` with noise level :math:`\sigma`.

        :param torch.Tensor x: input image.
        :param float, torch.Tensor sigma: noise level. A float, or a 0-d or
            1-element tensor, is used for the whole batch. Otherwise the tensor
            must hold ``batch_size`` values.
        :param gamma: second conditioning value, with the same conventions as ``sigma``.
        :param physics: forward operator. It is used to realign ``x`` when ``y`` is given.
        :param t: passed through ``timestep_embedding``.
        :param y: measurements. When given, ``x`` is first replaced by ``realign_input(x, physics, y)``.
        :raises ValueError: if separate heads are used and no head matches ``x.shape[1]``.
        """
        img_channels = x.shape[1]
        if self.emb_physics:
            physics = MultiScaleLinearPhysics(physics, x.shape[-3:], device=x.device)
        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")
        if y is not None:
            x = self.realign_input(x, physics, y)
        return self.forward_unet(x, sigma=sigma, gamma=gamma, physics=physics, t=t, y=y, img_channels=img_channels)
def krylov_embeddings_old(y, p, factor, v=None, N=4, feat_size=1, x_init=None, img_channels=3):
    """Legacy Krylov embedding: stack ``N`` iterates of ``factor**2 * A^T A`` along channels.

    The start point is ``x0 = A^T y``, or ``x_init[:, :img_channels]`` when
    ``x_init`` is given. The terms are ``x0 - v_0`` and
    ``x_{k+1} = factor**2 * A^T A x_k - v_{k+1}``.

    With ``feat_size > 1``, ``v`` holds one C-channel slice per term (``N * C``
    channels). It defaults to zeros of that size. With ``feat_size == 1`` the
    same ``v`` (default zeros) is subtracted from every term.

    Returns a tensor with ``N`` times the channels of the start point (for ``N >= 1``).
    """
    if x_init is None:
        x = p.A_adjoint(y)
    else:
        x = x_init[:, :img_channels, ...]
    norm = factor ** 2

    def normal_op(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        C = x.shape[1]
        if v is None:
            # One C-channel slice per term: N slices (the previous default of N-1
            # left the last slice empty and crashed).
            v = torch.zeros_like(x).repeat(1, N, 1, 1)
        terms = [x - v[:, :C, ...]]
        for i in range(N - 1):
            x = normal_op(x) - v[:, (i + 1) * C:(i + 2) * C, ...]
            terms.append(x)
    else:
        if v is None:
            v = torch.zeros_like(x)
        terms = [x - v]
        for _ in range(N - 1):
            x = normal_op(x) - v
            terms.append(x)
    # A single concatenation instead of re-concatenating the growing output.
    return torch.cat(terms, dim=1)
def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None, img_channels=3):
    """Stack Krylov-type iterates of the scaled normal operator along the channel axis.

    Starting from ``x0 = A^T y`` (or ``x_init``), returns
    ``cat([x0, x1, ..., x_{N-1}], dim=1)`` with
    ``x_{k+1} = factor**2 * A^T A x_k - v``.

    Args:
        y (torch.Tensor): measurements (only used when ``x_init`` is None).
        p: linear operator exposing ``A`` and ``A_adjoint``.
        factor (float): the normal operator is scaled by ``factor ** 2``.
        v (torch.Tensor, optional): subtracted after every operator application.
            Defaults to zeros.
        N (int, optional): total number of stacked terms (``N - 1`` operator
            applications). Defaults to 4.
        x_init (torch.Tensor, optional): start point used instead of ``A^T y``.
            It is used as given (not sliced to ``img_channels``).
        img_channels (int, optional): unused; kept for call compatibility.

    Returns:
        torch.Tensor: a new tensor with ``N`` times the channels of the start
        point (for ``N >= 1``).
    """
    x = p.A_adjoint(y) if x_init is None else x_init
    norm = factor ** 2
    if v is None:
        v = torch.zeros_like(x)
    terms = [x]
    x_k = x
    for _ in range(N - 1):
        x_k = p.A_adjoint(p.A(x_k)) * norm - v
        terms.append(x_k)
    # One concatenation (always a fresh tensor) instead of re-copying the
    # growing output at every step.
    return torch.cat(terms, dim=1)
def grad_embeddings(y, p, factor, v=None, N=4, feat_size=1):
    """Stack gradient-like residuals ``factor**2 * A^T A v_k - A^T y`` along channels.

    The first term is ``v_0 - A^T y``. The remaining ``N - 1`` terms apply the
    scaled normal operator to ``v_k`` before subtracting ``A^T y``.

    With ``feat_size > 1``, ``v`` holds one C-channel slice per term (``N * C``
    channels). It defaults to zeros of that size. With ``feat_size == 1`` the
    same ``v`` (default zeros) is used for every term, so all terms after the
    first are identical.
    """
    Aty = p.A_adjoint(y)
    norm = factor ** 2

    def normal_op(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        C = Aty.shape[1]
        if v is None:
            # One C-channel slice per term: N slices (N-1 left the last one empty).
            v = torch.zeros_like(Aty).repeat(1, N, 1, 1)
        terms = [v[:, :C, ...] - Aty]
        terms += [normal_op(v[:, (i + 1) * C:(i + 2) * C, ...]) - Aty for i in range(N - 1)]
    else:
        if v is None:
            v = torch.zeros_like(Aty)
        terms = [v - Aty]
        if N > 1:
            # Every remaining term is the same residual; compute it once.
            residual = normal_op(v) - Aty
            terms += [residual] * (N - 1)
    return torch.cat(terms, dim=1)
def prox_embeddings(y, p, factor, v=None, N=4):
    """Stack approximate regularised solutions for several ``gamma`` values, then ``A^T y``.

    For each ``gamma_k`` in ``logspace(-4, -1, N - 1)`` this approximately
    solves ``(factor**2 * A^T A + gamma_k I) u_k = A^T y + gamma_k v`` with 3
    conjugate-gradient iterations. It returns ``cat([u_1, ..., u_{N-1}, A^T y], dim=1)``.

    ``v`` defaults to zeros. It presumably has the same shape as ``A^T y``,
    since it is repeated channel-wise against it (TODO confirm with callers).
    """
    x = p.A_adjoint(y)
    C = x.shape[1]
    if v is None:
        v = torch.zeros_like(x)
    v = v.repeat(1, N - 1, 1, 1)
    # One gamma per C-channel slice, shaped (1, (N-1)*C, 1, 1) for broadcasting.
    gamma = torch.logspace(-4, -1, N - 1, device=x.device).repeat_interleave(C).reshape(1, -1, 1, 1)
    norm = factor ** 2

    def normal_op(u):
        return torch.cat(
            [p.A_adjoint(p.A(u[:, i * C:(i + 1) * C, ...])) * norm for i in range(N - 1)],
            dim=1,
        )

    def lhs(u):
        # Conjugate gradient requires a *linear* operator, so the gamma * v term
        # belongs on the right-hand side (previously it made the operator affine).
        return normal_op(u) + gamma * u

    rhs = x.repeat(1, N - 1, 1, 1) + gamma * v
    u_hat = conjugate_gradient(lhs, rhs, max_iter=3, tol=1e-3)
    return torch.cat([u_hat, x], dim=1)
| # -------------------------------------------- | |
| # Res Block: x + conv(relu(conv(x))) | |
| # -------------------------------------------- | |
| class MeasCondBlock(nn.Module): | |
| def __init__( | |
| self, | |
| out_channels=64, | |
| img_channels=None, | |
| decode_upscale=None, | |
| config = 'A', | |
| N=4, | |
| depth_encoding=1, | |
| relu_in_encoding=False, | |
| skip_in_encoding=True, | |
| c_mult=1, | |
| ): | |
| super(MeasCondBlock, self).__init__() | |
| self.separate_head = isinstance(img_channels, list) | |
| self.config = config | |
| assert img_channels is not None, "decode_dimensions should be provided" | |
| assert decode_upscale is not None, "decode_upscale should be provided" | |
| # if self.separate_head: | |
| if self.config == 'A': | |
| self.relu_encoding = nn.ReLU(inplace=False) | |
| self.N = N | |
| self.c_mult = c_mult | |
| self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult, relu_in=relu_in_encoding, skip_in=skip_in_encoding) | |
| if self.config == 'B': | |
| self.N = N | |
| self.c_mult = c_mult | |
| self.relu_encoding = nn.ReLU(inplace=False) | |
| self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult) | |
| self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult, relu_in=relu_in_encoding, skip_in=skip_in_encoding) | |
| if self.config == 'C': | |
| self.N = N | |
| self.c_mult = c_mult | |
| self.relu_encoding = nn.ReLU(inplace=False) | |
| self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult) | |
| self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=relu_in_encoding, skip_in=skip_in_encoding) | |
| elif self.config == 'D': | |
| self.N = N | |
| self.c_mult = c_mult | |
| self.relu_encoding = nn.ReLU(inplace=False) | |
| self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult) | |
| self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False, c_mult=self.c_mult*N, c_add=N, relu_in=relu_in_encoding, skip_in=skip_in_encoding) | |
| self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True) | |
| self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) | |
| self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) | |
| self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) | |
| self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True) | |
| def forward(self, x, y, physics, t, emb_in=None, img_channels=None, scale=1): | |
| if self.config == 'A': | |
| return self.measurement_conditioning_config_A(x, y, physics, img_channels=img_channels, scale=scale) | |
| elif self.config == 'F': | |
| return self.measurement_conditioning_config_F(x, y, physics, img_channels=img_channels, scale=scale) | |
| elif self.config == 'B': | |
| return self.measurement_conditioning_config_B(x, y, physics, img_channels=img_channels, scale=scale) | |
| elif self.config == 'C': | |
| return self.measurement_conditioning_config_C(x, y, physics, img_channels=img_channels, scale=scale) | |
| elif self.config == 'D': | |
| return self.measurement_conditioning_config_D(x, y, physics, img_channels=img_channels, scale=scale) | |
| elif self.config == 'E': | |
| return self.measurement_conditioning_config_E(x, y, physics, img_channels=img_channels, scale=scale) | |
| else: | |
| raise NotImplementedError('Config not implemented') | |
| def measurement_conditioning_config_A(self, x, y, physics, img_channels, scale=0): | |
| physics.set_scale(scale) | |
| factor = 2**(scale) | |
| meas = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels) | |
| cond = self.encoding_conv(meas) | |
| emb = self.relu_encoding(cond) | |
| return emb | |
| def measurement_conditioning_config_B(self, x, y, physics, img_channels, scale=0): | |
| physics.set_scale(scale) | |
| dec = self.decoding_conv(x, img_channels) | |
| factor = 2**(scale) | |
| meas = krylov_embeddings(y, physics, factor, v=dec, N=self.N, img_channels=img_channels) | |
| cond = self.encoding_conv(meas) | |
| emb = self.relu_encoding(cond) | |
| return emb # * sigma_emb | |
| def measurement_conditioning_config_C(self, x, y, physics, img_channels, scale=0): | |
| physics.set_scale(scale) | |
| dec = self.decoding_conv(x, img_channels) | |
| factor = 2**(scale) | |
| meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels) | |
| meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels) | |
| for c in range(1, self.c_mult): | |
| meas_cur = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, img_channels*c:img_channels*(c+1)], | |
| img_channels=img_channels) | |
| meas_dec = torch.cat([meas_dec, meas_cur], dim=1) | |
| meas = torch.cat([meas_y, meas_dec], dim=1) | |
| cond = self.encoding_conv(meas) | |
| emb = self.relu_encoding(cond) | |
| return emb | |
| def measurement_conditioning_config_D(self, x, y, physics, img_channels, scale=0): | |
| physics.set_scale(scale) | |
| dec = self.decoding_conv(x, img_channels) | |
| factor = 2**(scale) | |
| meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels) | |
| meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels) | |
| for c in range(1, self.c_mult): | |
| meas_cur = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, img_channels*c:img_channels*(c+1)], | |
| img_channels=img_channels) | |
| meas_dec = torch.cat([meas_dec, meas_cur], dim=1) | |
| meas = torch.cat([meas_y, meas_dec], dim=1) | |
| cond = self.encoding_conv(meas) | |
| emb = self.relu_encoding(cond) | |
| return cond | |
def measurement_conditioning_config_F(self, x, y, physics, img_channels):
    """Gradient / pseudo-inverse / null-space conditioning.

    With ``d = relu_decoding(decoding_conv(x))`` and ``r(a, b) = a**2 * A(d) - b**2 * y``:

    * ``grad = A^T r(gain_gradx, gain_grady)``;
    * ``pinv`` is ``A_dagger(r(gain_pinvx, gain_pinvy))``. For physics whose class
      name contains "tomography" it is instead
      ``prox_l2(d, r(gain_pinvx, gain_pinvy), gamma=1e9)``;
    * ``null = d - A^T A d``.

    ``(grad - pinv) + null`` is passed through ``encoding_conv`` and then
    ``relu_encoding``. The original author notes that ``grad - pinv`` vanishes for
    denoising and inpainting, and that the null-space term vanishes for denoising
    only. Unlike the other configs, this one takes no ``scale`` argument and never
    calls ``physics.set_scale``.

    Returns:
        The output of ``self.relu_encoding``.
    """
    # Original note: decoding goes from (B, C, H, W) to (B, 64, 64, 64),
    # independent of modality.
    decoded = self.relu_decoding(self.decoding_conv(x, img_channels))
    A_dec = physics.A(decoded)

    def weighted_residual(gain_x, gain_y):
        return gain_x ** 2 * A_dec - gain_y ** 2 * y

    # TODO: check if we need to have L2 (depending on noise nature, can be automated)
    grad = physics.A_adjoint(weighted_residual(self.gain_gradx, self.gain_grady))

    pinv_input = weighted_residual(self.gain_pinvx, self.gain_pinvy)
    if "tomography" in physics.__class__.__name__.lower():
        # presumably prox_l2 with a huge gamma stands in for A_dagger on
        # tomography operators — confirm
        pinv = physics.prox_l2(decoded, pinv_input, gamma=1e9)
    else:
        # TODO: should this use gain_gradx so the term is 0 during training too?
        pinv = physics.A_dagger(pinv_input)

    range_term = grad - pinv
    null_term = decoded - physics.A_adjoint_A(decoded)  # TODO: add gains here too
    return self.relu_encoding(self.encoding_conv(range_term + null_term))
def measurement_conditioning_config_E(self, x, y, physics, img_channels, scale=1):
    """Proximal-step conditioning with a per-sample, data-driven step size.

    1. Decode ``x`` with ``self.decoding_conv``, then call ``physics.set_scale(scale)``.
       Note that decoding happens first here, and ``scale`` defaults to 1, whereas
       configs C/D default to 0.
    2. Per sample (dim 0 is treated as the batch axis), compute
       ``snr = mean|d| / (mean|A^T(A d - y)| + 1e-4)``.
    3. Compute ``gamma = 1 / (1e-4 + 1 / (snr * f**2 + 1))``, where ``f`` is
       ``physics.factor`` when present and 1.0 otherwise. Reshape ``gamma`` to
       broadcast over the remaining dims of ``d`` and scale it by ``self.fact_prox``.
    4. Compute ``prox = physics.prox_l2(d, y, gamma=...)``.
    5. Mix ``self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * d``.
    6. Pass the mix through ``encoding_conv`` and then ``relu_encoding``.

    Returns:
        The output of ``self.relu_encoding``.
    """
    decoded = self.decoding_conv(x, img_channels)
    physics.set_scale(scale)
    # TODO: check things are batched
    factor = getattr(physics, "factor", 1.0)

    def batch_mean_abs(t):
        # mean absolute value per sample, flattening every non-batch dim
        return t.reshape(t.shape[0], -1).abs().mean(dim=1)

    data_residual = physics.A_adjoint(physics.A(decoded) - y)
    snr = batch_mean_abs(decoded) / (batch_mean_abs(data_residual) + 1e-4)
    # TODO: check square-root / mean / check if we need to add a factor in front
    gamma = 1 / (1e-4 + 1 / (snr * factor ** 2 + 1))
    # append trailing singleton dims so gamma broadcasts against `decoded`
    gamma = gamma.reshape(*gamma.shape, *([1] * (decoded.dim() - 1)))

    prox = physics.prox_l2(decoded, y, gamma=gamma * self.fact_prox)
    mixed = self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * decoded
    return self.relu_encoding(self.encoding_conv(mixed))
class ResBlock(nn.Module):
    """Residual convolutional block with optional measurement conditioning.

    For a body block (``head=False`` and ``tail=False``) ``forward`` computes::

        x + conv2(relu(conv1(x)))                                 # emb_physics=False
        x + conv2(relu(conv1(x))) + gain * PhysicsBlock(...)      # emb_physics=True

    Args:
        in_channels, out_channels: channel counts; they must be equal for a body block.
        kernel_size, stride, padding, bias, negative_slope: passed positionally
            to both ``conv(...)`` calls, which are built with mode ``"C"``.
        mode: accepted for signature compatibility; not used in this class.
        embedding: if True, registers ``self.gain`` and ``self.emb_linear``
            (``MPConv(emb_channels, out_channels, kernel=[])``). ``forward``
            never uses ``emb_linear``.
        emb_channels: input channels of ``emb_linear``.
        emb_physics: if True, (re)registers ``self.gain`` and builds
            ``self.PhysicsBlock`` as a ``MeasCondBlock``.
        img_channels: forwarded to ``InHead`` / ``MeasCondBlock``.
            ``self.separate_head`` records whether it is a list.
        decode_upscale, config, N, c_mult, depth_encoding, relu_in_encoding,
        skip_in_encoding: forwarded to ``MeasCondBlock``. ``config`` is also
            stored as ``self.config``.
        head: build ``self.head = InHead(...)`` and skip the residual layers.
        tail: skip the residual layers.

    NOTE(review): ``forward`` always uses ``conv1``/``conv2``, which exist only
    for body blocks, so head/tail instances cannot be called through ``forward``.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CRC",
        negative_slope=0.2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        head=False,
        tail=False,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        is_body = not head and not tail
        if is_body:
            assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.separate_head = isinstance(img_channels, list)
        self.config = config
        self.is_head = head
        self.is_tail = tail

        if head:
            self.head = InHead(img_channels, out_channels, input_layer=True)
        if not is_body:
            return

        # Both convolutions share every setting except their channel counts.
        shared_conv_args = (kernel_size, stride, padding, bias, "C", negative_slope)
        self.conv1 = conv(in_channels, out_channels, *shared_conv_args)
        self.nl = nn.ReLU(inplace=True)
        self.conv2 = conv(out_channels, out_channels, *shared_conv_args)

        if embedding:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])

        self.emb_physics = emb_physics
        if emb_physics:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.PhysicsBlock = MeasCondBlock(
                out_channels=out_channels,
                config=config,
                c_mult=c_mult,
                img_channels=img_channels,
                decode_upscale=decode_upscale,
                N=N,
                depth_encoding=depth_encoding,
                relu_in_encoding=relu_in_encoding,
                skip_in_encoding=skip_in_encoding,
            )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Apply the residual block.

        Args:
            x: input features.
            emb_sigma, emb_in: accepted but unused.
            physics, t, y, img_channels, scale: forwarded to ``self.PhysicsBlock``
                when ``self.emb_physics`` is set. ``PhysicsBlock`` receives the
                post-ReLU hidden features, not ``x``.

        Returns:
            ``x`` plus the convolutional residual, plus ``gain`` times the
            physics term when enabled.
        """
        hidden = self.nl(self.conv1(x))
        residual = self.conv2(hidden)  # Should we sum this with below?
        if not self.emb_physics:
            return x + residual
        # TODO: add a factor (1+gain) to the emb_meas? that depends on the input snr
        meas_term = self.PhysicsBlock(hidden, y, physics, t, img_channels=img_channels, scale=scale)
        # x - grad (sign does not matter)
        return x + residual + self.gain * meas_term
def calculate_fan_in_and_fan_out(tensor, pytorch_style: bool = True):
    """Return ``(fan_in, fan_out)`` for a 2-D, 4-D or 5-D weight tensor.

    Adapted from
    https://github.com/megvii-research/basecls/blob/main/basecls/layers/wrapper.py#L77

    The layout is assumed to be ``(out, in, *kernel)``. A 5-D grouped weight
    (``GOIKK``) is first reduced to ``OIKK``:

    * ``pytorch_style=True``: the group and output axes are merged;
    * ``pytorch_style=False``: only group 0 is kept.

    Raises:
        ValueError: if the tensor does not have 2, 4 or 5 dimensions.
    """
    ndim = len(tensor.shape)
    if ndim not in (2, 4, 5):
        raise ValueError(
            "fan_in and fan_out can only be computed for tensor with 2/4/5 "
            "dimensions"
        )
    if ndim == 5:
        # `GOIKK` -> `OIKK`
        tensor = tensor.reshape(-1, *tensor.shape[2:]) if pytorch_style else tensor[0]

    out_maps, in_maps, *kernel_dims = tensor.shape
    receptive_field = math.prod(kernel_dims)  # 1 for a 2-D (linear) weight
    return in_maps * receptive_field, out_maps * receptive_field
def weights_init_unext(m, gain_conv=1.0, gain_linear=1.0, init_type="ortho"):
    """Initialise, in place, the weights of every module yielded by ``m.modules()``.

    For each module (``m`` itself included):

    * If its string representation contains ``"skip"``, ``nn.Conv2d`` /
      ``nn.ConvTranspose2d`` weights are set to ones. Other such modules are left
      untouched.
    * Otherwise, ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weights get:
      - orthogonal init with gain 0.2 when the last weight dimension is < 4;
      - normal init ``N(0, sqrt(2 / fan_out))`` otherwise.
    * Otherwise, ``nn.Linear`` weights get ``N(0, 0.01)``.

    Biases are never modified.

    NOTE(review): ``str(submodule)`` is the module's repr. For a plain
    ``nn.Conv2d`` that repr does not include the attribute name the conv is
    registered under, so a conv merely *named* ``skip`` is not matched.
    Containers holding such a child do match, but they are not convs, so
    nothing happens to them. Confirm this is intended.

    Args:
        m: module to initialise. Objects without a ``modules`` attribute are
            ignored.
        gain_conv, gain_linear, init_type: currently ignored (the gains above
            are hard-coded); kept for call-site compatibility.
    """
    if not hasattr(m, "modules"):
        return
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for submodule in m.modules():
        if "skip" in str(submodule):
            if isinstance(submodule, conv_types):
                nn.init.ones_(submodule.weight.data)
        elif isinstance(submodule, conv_types):
            if submodule.weight.data.shape[-1] < 4:
                nn.init.orthogonal_(submodule.weight.data, gain=0.2)
            else:
                _, fan_out = calculate_fan_in_and_fan_out(submodule.weight)
                nn.init.normal_(submodule.weight, 0, math.sqrt(2 / fan_out))
        elif isinstance(submodule, nn.Linear):
            nn.init.normal_(submodule.weight.data, std=0.01)
def old2new(old_key):
    """
    Convert an old DRUNet state-dict key to the corresponding UNExt-style key.

    Only ``.weight`` entries are translated. Any other key (biases, head/tail
    layers, ...) yields ``None``, so a bias can never be mapped onto a weight key.
    Patterns are searched anywhere in the key, in this order:

    1. Encoder residual blocks (block index kept):
       ``m_down3.2.res.0.weight -> m_down3.enc.2.conv1.weight``
    2. Decoder residual blocks. Position 0 of ``m_upN`` is the upsampling layer,
       so block index ``k`` becomes ``k - 1``:
       ``m_up3.2.res.0.weight -> m_up3.enc.1.conv1.weight``
    3. Strided-conv downsampling layer (the position ``4`` is hard-coded):
       ``m_down3.4.weight -> pool3.weight``
    4. Transposed-conv upsampling layer:
       ``m_up3.0.weight -> up3.weight``
    5. Body residual blocks:
       ``m_body.0.res.2.weight -> m_body.enc.0.conv2.weight``

    Inside a residual block, ``res.0`` maps to ``conv1``; any other index
    (in practice ``res.2``) maps to ``conv2``.

    Args:
        old_key (str): The old key from the state dictionary.
    Returns:
        str or None: The new key if matched, otherwise None.
    """

    def _conv_name(res_index):
        # first conv of an old residual block is res.0, the second one res.2
        return "conv1" if int(res_index) == 0 else "conv2"

    match = re.search(r"(m_down\d+)\.(\d+)\.res\.(\d+)\.weight", old_key)
    if match:
        prefix, block, res = match.groups()
        return f"{prefix}.enc.{block}.{_conv_name(res)}.weight"

    match = re.search(r"(m_up\d+)\.(\d+)\.res\.(\d+)\.weight", old_key)
    if match:
        prefix, block, res = match.groups()
        return f"{prefix}.enc.{int(block) - 1}.{_conv_name(res)}.weight"

    # presumably position 4 follows nb=4 residual blocks in each m_down stage
    match = re.search(r"m_down(\d+)\.4\.weight", old_key)
    if match:
        return f"pool{match.group(1)}.weight"

    match = re.search(r"m_up(\d+)\.0\.weight", old_key)
    if match:
        return f"up{match.group(1)}.weight"

    match = re.search(r"(m_body)\.(\d+)\.res\.(\d+)\.weight", old_key)
    if match:
        prefix, block, res = match.groups()
        return f"{prefix}.enc.{block}.{_conv_name(res)}.weight"

    # No pattern matched.
    return None
def update_keyvals_headtail(old_key, old_value, init_value=None, new_key_name='m_head.conv0.weight', conditioning='base'):
    """
    Adapt an old DRUNet head/tail convolution weight to a new channel layout.

    The key is not derived from ``old_key``. The adapted weight is always
    returned under ``new_key_name``. ``init_value`` (required for head/tail keys)
    gives the target shape, dtype and device. Channels that are not copied stay
    zero.

    Head keys (``'head'`` in ``old_key``), where ``c_in = init_value.shape[1]``:
      * ``conditioning == 'base'``: input channels ``[:c_in - 2]`` are copied
        from ``old_value``. The last old input channel is copied into BOTH of
        the last two new input channels. This assumes ``c_in >= 2`` — TODO
        confirm. In DRUNet that last channel is presumably the noise-level map.
      * otherwise: the first ``c_in`` input channels of ``old_value`` are
        copied. NOTE(review): unlike the 'base' branch, the last old channel is
        not duplicated — confirm this is intended.

    Tail keys (``'tail'`` in ``old_key``), where ``c_out = init_value.shape[0]``:
      * same number of output channels: ``old_value.detach()`` is returned;
      * fewer: the first ``c_out`` output channels are copied;
      * more: an all-zero tensor is returned.

    Args:
        old_key (str): The old key from the state dictionary.
        old_value (Tensor): The old weight.
        init_value (Tensor): Freshly initialised weight of the new layer.
        new_key_name (str): Key under which the result is returned.
        conditioning (str): Head layout selector (see above).

    Returns:
        dict or None: ``{new_key_name: new_value}``, or None (after printing a
        message) when ``old_key`` contains neither 'head' nor 'tail'.
    """
    if 'head' in old_key:
        c_in = init_value.shape[1]
        new_value = torch.zeros_like(init_value.detach())
        if conditioning == 'base':
            new_value[:, -2:-1, ...] = old_value[:, -1:, ...]
            new_value[:, -1:, ...] = old_value[:, -1:, ...]
            new_value[:, :c_in - 2, ...] = old_value[:, :c_in - 2, ...]
        else:
            new_value[:, ...] = old_value[:, :c_in, ...]
        return {new_key_name: new_value}

    if 'tail' in old_key:
        c_out = init_value.shape[0]
        c_out_old = old_value.shape[0]
        if c_out == c_out_old:
            return {new_key_name: old_value.detach()}
        new_value = torch.zeros_like(init_value.detach())
        if c_out < c_out_old:
            new_value[:, ...] = old_value[:c_out, ...]
        return {new_key_name: new_value}

    print(f"Key {old_key} does not contain 'head' or 'tail'.")
    return None
# Smoke test: build a default network and run a single forward pass on random data.
if __name__ == "__main__":
    model = UNeXt()
    image = torch.randn(1, 3, 128, 128)
    output = model(image, 0.1)
    # print(output.shape)
    # print(output)

# ---------------------------------------------------------------------------
# Archived design notes (not executed): alternative formulations of the
# measurement embedding term, kept for reference. They refer to locals
# (`dec`, `Adec`, `emb`, `y`, `physics`) and gains of a measurement-conditioning
# method, not to anything defined at module level.
# ---------------------------------------------------------------------------
# Case for diagonal physics
# IDEA 1: kills signal in the image of A
# im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
# IDEA 2: compute norm of signal in ker of A
# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
# im_emb = normker * physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
# IDEA 3: same as above but add the pinv as well
# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
# grad_term = physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y)
# # pinv_term = physics.A_dagger(self.gain_diagpinv_x * physics.A(dec) - self.gain_diagpinv_y * y)
# if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower():
#     pinv_term = physics.prox_l2(dec, self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y, gamma=1e9)
# else:
#     pinv_term = physics.A_dagger(self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess
# im_emb = normker * (grad_term + pinv_term) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
# # Mix it
# if hasattr(physics.noise_model, 'sigma'):
#     sigma = physics.noise_model.sigma # SNR ? x /= sigma ** 2
#     snr = (y.abs().mean()) / (sigma + 1e-4) # SNR equivariant # TODO: add epsilon
#     snr = snr[(...,) + (None,) * (im_emb.dim() - 1)]
# else:
#     snr = 1e4
#
# grad_large = emb + self.gain_diag * (1 + self.gain_noise / snr) * im_emb