# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from typing import List, Optional, Tuple

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.models.utils.clip_generator_helper import \
    ResidualAttentionBlock
from mmpretrain.registry import MODELS


@MODELS.register_module()
class CLIPTransformer(nn.Module):
    """Transformer.

    Both visual and text branches use this transformer.

    Args:
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        for _ in range(layers - 1):
            self.resblocks.append(
                ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks.append(
            ResidualAttentionBlock(
                width, heads, attn_mask, return_attention=True))

    def forward(
            self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward function.

        Returns the output of the last block, its attention map and a list
        with the intermediate output of every block in ``(B, L, C)`` layout.
        """
        z = []
        for idx, blk in enumerate(self.resblocks):
            if idx < self.layers - 1:
                x = blk(x)
                z.append(x.permute(1, 0, 2))
            else:
                # The last block also returns its attention map.
                x, attention = blk(x)
                z.append(x.permute(1, 0, 2))
        return x, attention, z
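

# Usage sketch for ``CLIPTransformer`` (illustrative, not part of the original
# module), assuming the (seq_len, batch, width) token layout of the original
# CLIP transformer:
#
#   transformer = CLIPTransformer(width=64, layers=2, heads=4)
#   tokens = torch.randn(16, 2, 64)
#   out, attention, hidden_states = transformer(tokens)
#   # ``out`` keeps the (16, 2, 64) layout, while ``hidden_states`` holds one
#   # (batch, seq_len, width) tensor per block.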


@MODELS.register_module()
class CLIPProjection(BaseModule):
    """Neck with CLIP Projection.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        scale = in_channels**-0.5
        self.proj = nn.Parameter(scale *
                                 torch.randn(in_channels, out_channels))

    def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
        """forward function.

        Args:
            inputs (Tuple): The features extracted from
                 the backbone. Multiple stage inputs are acceptable but only
                  the last stage will be used.
        Returns:
            Tuple(torch.Tensor)): A tuple of reducted features.
        """
        if isinstance(inputs, tuple):
            inputs = inputs[-1]
            out = inputs @ self.proj
        elif isinstance(inputs, torch.Tensor):
            out = inputs @ self.proj
        else:
            raise TypeError(
                'The inputs of `CLIPProjection` neck should be a tuple '
                'or a torch.Tensor')
        return (out, )
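

# Minimal runnable sketch (not part of the upstream module; the channel sizes
# below are illustrative assumptions). ``CLIPProjection`` accepts a single
# tensor or a tuple of stage features, of which only the last stage is
# projected.
if __name__ == '__main__':
    neck = CLIPProjection(in_channels=64, out_channels=32)
    feats = torch.randn(2, 16, 64)  # (batch, seq_len, in_channels)
    projected, = neck((feats, ))
    print(projected.shape)  # torch.Size([2, 16, 32])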