import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Parameter


class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip connection: ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra arguments go to ``fn`` only; the skip path adds ``x`` unchanged.
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    Uses ``half_dim = dim // 2`` geometrically spaced frequencies running from
    1 down to 1/10000. Each input value ``t`` is mapped to
    ``[sin(t * f_0), ..., sin(t * f_last), cos(t * f_0), ..., cos(t * f_last)]``.
    The output therefore has ``2 * (dim // 2)`` features, which is ``dim - 1``
    when ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: tensor of positions. It is indexed as ``x[:, None]``, so it is
                presumably 1-D of shape ``(N,)`` — confirm with callers.

        Returns:
            Tensor of shape ``(N, 2 * (dim // 2))`` on ``x.device``.
        """
        half_dim = self.dim // 2
        # Clamp the denominator so dim in {2, 3} (a single frequency) does not
        # raise ZeroDivisionError. With one frequency the exponent is
        # 0 * step = 0, so the denominator's value does not affect the result.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Rezero(nn.Module):
    """Scale the output of ``fn`` by a learnable scalar gate initialised to 0.

    At initialisation the module outputs zeros; ``g`` is learned in training.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


# building block modules

class Block(nn.Module):
    """Reflection-padded 3x3 conv -> optional GroupNorm -> Mish.

    Padding by 1 before an unpadded 3x3 convolution keeps the spatial size.

    Args:
        dim: input channels.
        dim_out: output channels.
        groups: number of GroupNorm groups; ``0`` disables normalisation.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        # Building the list incrementally gives the same Sequential indices
        # (state-dict keys: conv at ``block.1``, GroupNorm at ``block.2``) as
        # spelling out each variant separately.
        layers = [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim_out, 3)]
        if groups != 0:
            layers.append(nn.GroupNorm(groups, dim_out))
        layers.append(Mish())
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and optional conditioning.

    Args:
        dim: input channels.
        dim_out: output channels.
        time_emb_dim: size of the time embedding. When ``> 0``, an MLP
            (Mish -> Linear) projects it to ``dim_out`` channels. Passing
            ``time_emb`` to ``forward`` requires ``time_emb_dim > 0``
            (``self.mlp`` does not exist otherwise).
        groups: GroupNorm groups forwarded to each ``Block`` (``0`` = no norm).
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=0, groups=8):
        super().__init__()
        if time_emb_dim > 0:
            self.mlp = nn.Sequential(
                Mish(),
                nn.Linear(time_emb_dim, dim_out)
            )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # A 1x1 conv matches channel counts on the skip path only when needed.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, cond=None):
        """Apply the block.

        Args:
            x: feature map, assumed 4-D ``(B, dim, H, W)``.
            time_emb: optional embedding. ``mlp(time_emb)`` is broadcast over
                the spatial dims via ``[:, :, None, None]``.
            cond: optional tensor added to the hidden features after
                ``block1``; it must broadcast against them.

        Returns:
            ``block2(h) + res_conv(x)``.
        """
        h = self.block1(x)
        # Out-of-place additions let the conditioning broadcast in either
        # direction and avoid mutating block1's output in place.
        if time_emb is not None:
            h = h + self.mlp(time_emb)[:, :, None, None]
        if cond is not None:
            h = h + cond
        h = self.block2(h)
        return h + self.res_conv(x)


class Upsample(nn.Module):
    """Double the spatial size with a transposed convolution.

    With kernel 4, stride 2, padding 1 the output is ``2H x 2W``.
    """

    def __init__(self, dim):
        super().__init__()
        # The Sequential wrapper fixes the state-dict keys (``conv.0.*``);
        # keep it for checkpoint compatibility.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
        )

    def forward(self, x):
        return self.conv(x)


class Downsample(nn.Module):
    """Halve the spatial size (rounding up) with a stride-2 3x3 convolution.

    Reflection padding by 1 followed by an unpadded stride-2 3x3 conv yields
    ``ceil(H / 2) x ceil(W / 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, 2),
        )

    def forward(self, x):
        return self.conv(x)


class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Keys are softmax-normalised over positions and combined with the values
    into a per-head ``dim_head x dim_head`` context matrix. That matrix is then
    applied to the (un-normalised) queries, so the cost is linear in ``H * W``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)  # normalise each key channel over positions
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class MultiheadAttention(nn.Module):
    """Multi-head attention for batch-first ``[B, T, C]`` inputs.

    Parameters are stored in the layout ``F.multi_head_attention_forward``
    expects:

    * When ``kdim == vdim == embed_dim``: a packed ``in_proj_weight`` of shape
      ``(3 * embed_dim, embed_dim)``.
    * Otherwise: separate ``q_proj_weight``, ``k_proj_weight`` and
      ``v_proj_weight``.

    The unused variant is registered as ``None``, so it does not appear in
    ``state_dict()``.

    Args:
        embed_dim: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        kdim: key feature size (defaults to ``embed_dim``).
        vdim: value feature size (defaults to ``embed_dim``).
        dropout: dropout probability passed to
            ``F.multi_head_attention_forward``.
        bias: add biases to the input and output projections.
        add_bias_kv: learn extra ``bias_k`` / ``bias_v`` parameters.
        add_zero_attn: forwarded to ``F.multi_head_attention_forward``.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
            # Register the unused separate weights as None so forward() can
            # pass every weight argument uniformly.
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.register_parameter('in_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        # Informational only: forward() calls F.multi_head_attention_forward
        # regardless of this flag.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        # Not written anywhere in this class.
        self.last_attn_probs = None

    def reset_parameters(self):
        """Xavier-initialise the projection weights and zero the biases."""
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
    ):
        """Attend from ``query`` to ``key``/``value``.

        Input shape: ``[B, T, C]`` (batch first). Inputs are transposed to
        ``[T, B, C]`` for ``F.multi_head_attention_forward``, and the output is
        transposed back to ``[B, T, C]``.

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude keys that
                are pads, of shape ``(batch, src_len)``, where padding elements
                are indicated by 1s.
            need_weights (bool, optional): also return the attention weights
                (default: True).
            attn_mask (ByteTensor, optional): typically used to implement
                causal attention, where the mask prevents the attention from
                looking forward in time (default: None).
            before_softmax (bool, optional): accepted for signature
                compatibility; ignored by this implementation.
            need_head_weights (bool, optional): forces ``need_weights=True``.
                The weights are returned exactly as
                ``F.multi_head_attention_forward`` produces them (PyTorch
                averages over heads by default); they are not split per head.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is ``[B, T, C]``;
            ``attn_weights`` is whatever ``F.multi_head_attention_forward``
            returns (``None`` when weights are not requested, per PyTorch).
        """
        if need_head_weights:
            need_weights = True

        # [B, T, C] -> [T, B, C], the sequence-first layout used by the kernel.
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)

        _, _, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value,
            self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            # When kdim/vdim differ from embed_dim, in_proj_weight is None and
            # the separately stored q/k/v projection weights must be used.
            use_separate_proj_weight=not self.qkv_same_dim,
            q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
        )
        attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_output_weights

    def in_proj_qkv(self, query):
        """Project with the packed weight and split into ``(q, k, v)``.

        Requires ``qkv_same_dim``.
        """
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        """Apply the query projection (packed or separate weight)."""
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        """Apply the key projection (packed or separate weight)."""
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        """Apply the value projection (packed or separate weight)."""
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        """Linear projection with rows ``start:end`` of the packed weight/bias."""
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)


class ResidualDenseBlock_5C(nn.Module):
    """Residual dense block of five 3x3 convolutions.

    Each of ``conv1``..``conv4`` receives the concatenation of the block input
    and all previous outputs, and produces ``gc`` channels. ``conv5`` maps back
    to ``nf`` channels. Its output is scaled by 0.2 and added to the input.

    Args:
        nf: number of input/output channels.
        gc: growth channels, i.e. the intermediate channel count per conv.
        bias: whether the convolutions use a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        # In-place LeakyReLU acts only on fresh conv outputs that nothing else
        # references.
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # NOTE: the convolutions keep PyTorch's default initialisation; no
        # scaled init is applied here.

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Chains three ``ResidualDenseBlock_5C`` modules; the chain's output is
    scaled by 0.2 and added to the input.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x