import torch
import timm
import gradio as gr
from huggingface_hub import hf_hub_download
import os
from ViT.ViT_new import vit_base_patch16_224 as vit
import torchvision.transforms as transforms
import requests
from PIL import Image
import numpy as np
import cv2


# create heatmap from mask on image
def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam

start_layer = 0  # index of the first transformer block included in the relevance accumulation

# rule 5 from the paper: weight the attention map by its gradient,
# keep only positive contributions, and average over the heads
def avg_heads(cam, grad):
    cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
    cam = grad * cam
    cam = cam.clamp(min=0).mean(dim=0)
    return cam

# rule 6 from the paper: propagate the accumulated relevance through a self-attention layer
def apply_self_attention_rules(R_ss, cam_ss):
    R_ss_addition = torch.matmul(cam_ss, R_ss)
    return R_ss_addition
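
# Taken together, the two rules above implement the per-block update
#     R <- R + mean_over_heads((grad(A) * A).clamp(min=0)) @ R
# starting from R = I over the tokens; generate_relevance below applies it
# block by block and returns the row of R that belongs to the [CLS] token.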

def generate_relevance(model, input, index=None):
    # forward pass with hooks registered so attention maps and their gradients are stored
    output = model(input, register_hook=True)
    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    # backpropagate a one-hot vector for the target class
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot * output)
    model.zero_grad()
    one_hot.backward(retain_graph=True)

    # accumulate relevance across the transformer blocks, starting from the identity
    num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]
    R = torch.eye(num_tokens, num_tokens)
    for i, blk in enumerate(model.blocks):
        if i < start_layer:
            continue
        grad = blk.attn.get_attn_gradients()
        cam = blk.attn.get_attention_map()
        cam = avg_heads(cam, grad)
        R += apply_self_attention_rules(R, cam)
    # relevance of the [CLS] token with respect to every image patch token
    return R[0, 1:]

def generate_visualization(model, original_image, class_index=None):
    # relevance needs gradients even if the caller runs under torch.no_grad()
    with torch.enable_grad():
        transformer_attribution = generate_relevance(model, original_image.unsqueeze(0), index=class_index).detach()
    # reshape the 14x14 patch relevance map and upsample it to the 224x224 input resolution
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())

    # normalize the input image to [0, 1] and overlay the relevance heatmap
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis
	
model_finetuned = None
model = None

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_224 = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def image_classifier(inp):
    image = transform_224(inp)
    with torch.no_grad():
        # fine-tuned (robust) model: class probabilities and relevance heatmap
        prediction = torch.nn.functional.softmax(model_finetuned(image.unsqueeze(0))[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
        heatmap = generate_visualization(model_finetuned, image)

        # original pretrained model, for comparison
        prediction_orig = torch.nn.functional.softmax(model(image.unsqueeze(0))[0], dim=0)
        confidences_orig = {labels[i]: float(prediction_orig[i]) for i in range(1000)}
        heatmap_orig = generate_visualization(model, image)
    return confidences, heatmap, confidences_orig, heatmap_orig

def _load_model(model_name: str):
    global model_finetuned, model
    # download the fine-tuned (robust) checkpoint from the Hub
    path = hf_hub_download('Hila/RobustViT', model_name)

    # original pretrained ViT-B/16 as the baseline
    model = vit(pretrained=True)
    model.eval()

    # fine-tuned model loaded from the downloaded checkpoint
    model_finetuned = vit()
    checkpoint = torch.load(path, map_location='cpu')
    model_finetuned.load_state_dict(checkpoint['state_dict'])
    model_finetuned.eval()

_load_model('ar_base.tar')
#demo = gr.Interface(image_classifier, gr.inputs.Image(shape=(224,224)), [gr.outputs.Label(label="Our Classification", num_top_classes=3), gr.Image(label="Our Relevance",shape=(64,64)), gr.outputs.Label(label="Original Classification", num_top_classes=3), gr.Image(label="Original Relevance",shape=(64,64))],examples=["samples/augreg_base/tank.png", "samples/augreg_base/sundial.png", "samples/augreg_base/lizard.png", "samples/augreg_base/storck.png", "samples/augreg_base/hummingbird2.png", "samples/augreg_base/hummingbird.png"], capture_session=True)
#demo.launch(debug=True)
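
# A rough, commented-out sketch for sanity-checking the pipeline without the UI.
# It reuses one of the sample images from the examples list above; the output
# file name is arbitrary and only meant for local inspection.
#
#     img = Image.open('samples/augreg_base/tank.png').convert('RGB').resize((224, 224))
#     tensor = transform_224(img)                        # (3, 224, 224)
#     heatmap = generate_visualization(model_finetuned, tensor)
#     cv2.imwrite('relevance.png', heatmap)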

demo = gr.Blocks()

with demo:
    gr.Markdown('Select an image and then click **submit** to see the output.')

    with gr.Row():
        inp = gr.Image(shape=(224, 224))

    with gr.Row():
        out1 = gr.Label(label="Our Classification", num_top_classes=3)
        out2 = gr.Image(label="Our Relevance", shape=(224, 224))

    with gr.Row():
        out3 = gr.Label(label="Original Classification", num_top_classes=3)
        out4 = gr.Image(label="Original Relevance", shape=(224, 224))

    btn = gr.Button('Submit')
    btn.click(fn=image_classifier, inputs=inp, outputs=[out1, out2, out3, out4])

demo.launch()