from typing import List
import torch
import torch.nn as nn
from .clip_explainability import load
from .clip import tokenize
from torch import device
import numpy as np
import torch.nn.functional as nnf
import itertools


def zeroshot_classifier(clip_model, classnames, templates, device):
    """Build zero-shot text classifier weights by averaging the prompt-template embeddings of each class."""
    with torch.no_grad():
        texts = list(
            itertools.chain(
                *[
                    [template.format(classname) for template in templates]
                    for classname in classnames
                ]
            )
        )  # format with class
        texts = tokenize(texts).to(device)  # tokenize
        class_embeddings = clip_model.encode_text(texts)
        class_embeddings = class_embeddings.view(len(classnames), len(templates), -1)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        zeroshot_weights = class_embeddings.mean(dim=1)
    return zeroshot_weights.T  # shape: [dim, n classes]
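
# Example (sketch, hypothetical inputs): with
#   classnames = ["cat", "dog"]
#   templates = ["a photo of a {}.", "a drawing of a {}."]
# zeroshot_classifier encodes the 4 formatted prompts, averages the two template
# embeddings per class, and returns a tensor of shape [embedding_dim, 2]
# (one column of text features per class).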


class ClipGradcam(nn.Module):
    def __init__(
        self,
        clip_model_name: str,
        classes: List[str],
        templates: List[str],
        device: device,
        num_layers=10,
        positive_attn_only=False,
        **kwargs
    ):
        super(ClipGradcam, self).__init__()
        self.clip_model_name = clip_model_name
        self.model, self.preprocess = load(clip_model_name, device=device, **kwargs)
        self.templates = templates
        self.device = device
        self.target_classes = None
        self.set_classes(classes)
        self.num_layers = num_layers
        self.positive_attn_only = positive_attn_only
        # per-model value used to split the flattened (batch * heads) dimension
        # of each block's attn_probs in interpret() below
        self.num_res_attn_blocks = {
            "ViT-B/32": 12,
            "ViT-B/16": 12,
            "ViT-L/14": 16,
            "ViT-L/14@336px": 16,
        }[clip_model_name]

    def forward(self, x: torch.Tensor, o: List[str]):
        """
        Non-standard hack around an nn.Module: returns per-prompt relevance maps
        rather than logits. Really should be more principled here.
        """
        image_features = self.model.encode_image(x.to(self.device))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        zeroshot_weights = torch.cat(
            [self.class_to_language_feature[prompt] for prompt in o], dim=1
        )
        logits_per_image = 100.0 * image_features @ zeroshot_weights
        return self.interpret(logits_per_image, self.model, self.device)

    def interpret(self, logits_per_image, model, device):
        # modified from: https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV
        batch_size = logits_per_image.shape[0]
        num_prompts = logits_per_image.shape[1]
        # one scalar per prompt: that prompt's logits summed over the batch
        one_hot = [logit for logit in logits_per_image.sum(dim=0)]
        model.zero_grad()
        image_attn_blocks = list(
            dict(model.visual.transformer.resblocks.named_children()).values()
        )
        num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
        # relevance matrix R starts as the identity for every (prompt, image) pair
        R = torch.eye(
            num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
        ).to(device)
        R = R[None, None, :, :].repeat(num_prompts, batch_size, 1, 1)
        for i, block in enumerate(image_attn_blocks):
            if i <= self.num_layers:
                continue
            # TODO try scaling block.attn_probs by value magnitude
            # TODO actual parallelized prompt gradients
            grad = torch.stack(
                [
                    torch.autograd.grad(logit, [block.attn_probs], retain_graph=True)[
                        0
                    ].detach()
                    for logit in one_hot
                ]
            )
            grad = grad.view(
                num_prompts,
                batch_size,
                self.num_res_attn_blocks,
                num_tokens,
                num_tokens,
            )
            cam = (
                block.attn_probs.view(
                    1, batch_size, self.num_res_attn_blocks, num_tokens, num_tokens
                )
                .detach()
                .repeat(num_prompts, 1, 1, 1, 1)
            )
            cam = cam.reshape(num_prompts, batch_size, -1, cam.shape[-1], cam.shape[-1])
            grad = grad.reshape(
                num_prompts, batch_size, -1, grad.shape[-1], grad.shape[-1]
            )
            # element-wise gradient weighting of the attention maps
            cam = grad * cam
            cam = cam.reshape(
                num_prompts * batch_size, -1, cam.shape[-1], cam.shape[-1]
            )
            if self.positive_attn_only:
                cam = cam.clamp(min=0)
            # average of all heads
            cam = cam.mean(dim=-3)
            R = R + torch.bmm(
                cam, R.view(num_prompts * batch_size, num_tokens, num_tokens)
            ).view(num_prompts, batch_size, num_tokens, num_tokens)

        # CLS-token row of R restricted to the patch tokens, reshaped to the patch grid
        image_relevance = R[:, :, 0, 1:]
        img_dim = int(np.sqrt(num_tokens - 1))
        image_relevance = image_relevance.reshape(
            num_prompts, batch_size, img_dim, img_dim
        )
        return image_relevance

    def set_classes(self, classes):
        """Precompute text features for `classes` so forward() can look them up per prompt."""
        self.target_classes = classes
        language_features = zeroshot_classifier(
            self.model, self.target_classes, self.templates, self.device
        )
        self.class_to_language_feature = {}
        for i, c in enumerate(self.target_classes):
            self.class_to_language_feature[c] = language_features[:, [i]]
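

# Minimal usage sketch (not part of the original module): the model name, class
# names, prompt template, and the random input tensor below are illustrative
# assumptions; in practice the image batch would come from `gradcam.preprocess`.
if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    gradcam = ClipGradcam(
        clip_model_name="ViT-B/32",
        classes=["cat", "dog"],
        templates=["a photo of a {}."],
        device=dev,
    )
    image = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed image batch
    relevance = gradcam(image, ["cat", "dog"])  # [n_prompts, batch, grid, grid]
    # upsample the patch-level relevance to the input resolution for visualization
    heatmaps = nnf.interpolate(relevance, size=image.shape[-2:], mode="bilinear")
    print(heatmaps.shape)  # expected: torch.Size([2, 1, 224, 224])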