"""GradCAM-style relevance maps for CLIP image encoders.

Wraps an explainability-enabled CLIP model and, given a batch of images and a
list of text prompts, propagates attention-gradient relevance (adapted from
Hila Chefer's Transformer-MM-Explainability notebook) into a per-prompt
spatial relevance map over image patches.
"""
from typing import List
import itertools

import numpy as np
import torch
import torch.nn as nn
from torch import device

from .clip import tokenize
from .clip_explainability import load


def zeroshot_classifier(clip_model, classnames, templates, device):
    """Build zero-shot text classifier weights for `classnames`.

    Each class name is formatted into every prompt template, the prompts are
    encoded with the CLIP text encoder, L2-normalized, and averaged over
    templates, giving one embedding per class.

    Returns a tensor of shape [embedding_dim, n_classes].
    """
    with torch.no_grad():
        texts = list(
            itertools.chain(
                *[
                    [template.format(classname) for template in templates]
                    for classname in classnames
                ]
            )
        )  # format with class
        texts = tokenize(texts).to(device)  # tokenize
        class_embeddings = clip_model.encode_text(texts)
        class_embeddings = class_embeddings.view(len(classnames), len(templates), -1)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        zeroshot_weights = class_embeddings.mean(dim=1)
        return zeroshot_weights.T  # shape: [dim, n classes]


class ClipGradcam(nn.Module):
    def __init__(
        self,
        clip_model_name: str,
        classes: List[str],
        templates: List[str],
        device: device,
        num_layers=10,
        positive_attn_only=False,
        **kwargs
    ):

        super(ClipGradcam, self).__init__()
        self.clip_model_name = clip_model_name
        self.model, self.preprocess = load(clip_model_name, device=device, **kwargs)
        self.templates = templates
        self.device = device
        self.target_classes = None
        self.set_classes(classes)
        self.num_layers = num_layers
        self.positive_attn_only = positive_attn_only
        # Number of attention heads per residual attention block for the chosen
        # visual backbone; used to unflatten attn_probs (and their gradients),
        # which are stored with a flattened [batch * heads, tokens, tokens] layout.
        self.num_res_attn_blocks = {
            "ViT-B/32": 12,
            "ViT-B/16": 12,
            "ViT-L/14": 16,
            "ViT-L/14@336px": 16,
        }[clip_model_name]

    def forward(self, x: torch.Tensor, o: List[str]):
        """
        Return per-prompt relevance maps for a batch of images.

        Non-standard use of nn.Module.forward (it returns attention-gradient
        relevance maps rather than logits); a more principled interface would
        separate scoring from attribution.
        """
        image_features = self.model.encode_image(x.to(self.device))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        zeroshot_weights = torch.cat(
            [self.class_to_language_feature[prompt] for prompt in o], dim=1
        )
        logits_per_image = 100.0 * image_features @ zeroshot_weights
        return self.interpret(logits_per_image, self.model, self.device)

    def interpret(self, logits_per_image, model, device):
        # modified from: https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV
        batch_size = logits_per_image.shape[0]
        num_prompts = logits_per_image.shape[1]
        # one scalar objective per prompt: that prompt's logit summed over the batch
        one_hot = [logit for logit in logits_per_image.sum(dim=0)]
        model.zero_grad()

        image_attn_blocks = list(
            dict(model.visual.transformer.resblocks.named_children()).values()
        )
        num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
        # per-(prompt, image) relevance matrices, initialized to the identity
        R = torch.eye(
            num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
        ).to(device)
        R = R[None, None, :, :].repeat(num_prompts, batch_size, 1, 1)
        for i, block in enumerate(image_attn_blocks):
            # only blocks after index `num_layers` contribute to the relevance rollout
            if i <= self.num_layers:
                continue
            # TODO try scaling block.attn_probs by value magnitude
            # TODO actual parallelized prompt gradients
            grad = torch.stack(
                [
                    torch.autograd.grad(logit, [block.attn_probs], retain_graph=True)[
                        0
                    ].detach()
                    for logit in one_hot
                ]
            )
            grad = grad.view(
                num_prompts,
                batch_size,
                self.num_res_attn_blocks,
                num_tokens,
                num_tokens,
            )
            cam = (
                block.attn_probs.view(
                    1, batch_size, self.num_res_attn_blocks, num_tokens, num_tokens
                )
                .detach()
                .repeat(num_prompts, 1, 1, 1, 1)
            )
            cam = cam.reshape(num_prompts, batch_size, -1, cam.shape[-1], cam.shape[-1])
            grad = grad.reshape(
                num_prompts, batch_size, -1, grad.shape[-1], grad.shape[-1]
            )
            cam = grad * cam
            cam = cam.reshape(
                num_prompts * batch_size, -1, cam.shape[-1], cam.shape[-1]
            )
            if self.positive_attn_only:
                cam = cam.clamp(min=0)
            # average of all heads
            cam = cam.mean(dim=-3)
            R = R + torch.bmm(
                cam, R.view(num_prompts * batch_size, num_tokens, num_tokens)
            ).view(num_prompts, batch_size, num_tokens, num_tokens)
        # relevance of the CLS token with respect to every image patch token
        image_relevance = R[:, :, 0, 1:]
        # patch tokens form a square grid of side sqrt(num_tokens - 1)
        img_dim = int(np.sqrt(num_tokens - 1))
        image_relevance = image_relevance.reshape(
            num_prompts, batch_size, img_dim, img_dim
        )
        return image_relevance

    def set_classes(self, classes):
        """Precompute and cache one text-feature column per class in `classes`."""
        self.target_classes = classes
        language_features = zeroshot_classifier(
            self.model, self.target_classes, self.templates, self.device
        )

        self.class_to_language_feature = {}
        for i, c in enumerate(self.target_classes):
            self.class_to_language_feature[c] = language_features[:, [i]]
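

# Minimal usage sketch. The backbone name, class names, prompt template, and
# image path below are illustrative placeholders, not values required by this
# module; `preprocess` is the transform returned by the explainability-enabled
# `load()` call in `ClipGradcam.__init__`.
if __name__ == "__main__":
    from PIL import Image

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gradcam = ClipGradcam(
        clip_model_name="ViT-B/32",
        classes=["a dog", "a cat"],
        templates=["a photo of {}."],
        device=dev,
    )
    # preprocess a single image into a [1, 3, H, W] tensor
    image = gradcam.preprocess(Image.open("example.jpg")).unsqueeze(0)
    # relevance has shape [n_prompts, batch, grid, grid]; upsample it to the
    # input resolution to overlay it on the image
    relevance = gradcam(image, ["a dog"])
    print(relevance.shape)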