# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from typing import List, Optional, Tuple

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.models.utils.clip_generator_helper import \
    ResidualAttentionBlock
from mmpretrain.registry import MODELS


@MODELS.register_module()
class CLIPTransformer(nn.Module):
    """Transformer.

    Both visual and text branches use this transformer.

    Args:
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
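
    Example:
        A minimal usage sketch; the width/layers/heads values are
        illustrative, and the input is assumed to follow CLIP's
        ``(seq_len, batch, width)`` token layout:

        >>> import torch
        >>> transformer = CLIPTransformer(width=512, layers=12, heads=8)
        >>> x = torch.randn(77, 2, 512)  # (seq_len, batch, width) tokens
        >>> out, attention, z = transformer(x)  # z: per-layer tokens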
"""

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        # The first `layers - 1` blocks only return tokens; the last one
        # additionally returns its attention map.
        for _ in range(layers - 1):
            self.resblocks.append(
                ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks.append(
            ResidualAttentionBlock(
                width, heads, attn_mask, return_attention=True))

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward function."""
        z = []
        for idx, blk in enumerate(self.resblocks):
            if idx < self.layers - 1:
                x = blk(x)
                # Collect intermediate tokens with the batch dim first.
                z.append(x.permute(1, 0, 2))
            else:
                # The last block also returns its attention map.
                x, attention = blk(x)
                z.append(x.permute(1, 0, 2))
        return x, attention, z


@MODELS.register_module()
class CLIPProjection(BaseModule):
    """Neck with CLIP Projection.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
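
    Example:
        A minimal usage sketch; the 768/512 channel sizes and the batch of
        two are illustrative:

        >>> import torch
        >>> neck = CLIPProjection(in_channels=768, out_channels=512)
        >>> feats = (torch.randn(2, 768), )  # tuple of backbone outputs
        >>> out, = neck(feats)  # projected features with 512 channels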
"""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 init_cfg: Optional[dict] = None):
        super(CLIPProjection, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Scale the random projection init by 1 / sqrt(in_channels).
        scale = in_channels**-0.5
        self.proj = nn.Parameter(scale *
                                 torch.randn(in_channels, out_channels))

    def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            inputs (Tuple): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last
                stage will be used.

        Returns:
            Tuple(torch.Tensor): A tuple of reduced features.
        """
        if isinstance(inputs, tuple):
            # Only the last stage of a multi-stage input is projected.
            inputs = inputs[-1]
            out = inputs @ self.proj
        elif isinstance(inputs, torch.Tensor):
            out = inputs @ self.proj
        else:
            raise TypeError(
                '`CLIPProjection` neck inputs should be tuple or torch.Tensor')
        return (out, )