import torch
import torchvision
import torch.nn.functional as F


def attn_cosine_sim(x, eps=1e-08):
    x = x[0]  # drop the redundant leading singleton dimension
    norm1 = x.norm(dim=2, keepdim=True)
    factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
    sim_matrix = (x @ x.permute(0, 2, 1)) / factor
    return sim_matrix
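

# Illustrative sketch (not in the original file): attn_cosine_sim expects a
# tensor with a leading singleton dimension, [1, B, N, D], and returns the
# [B, N, N] matrix of pairwise cosine similarities between the N row vectors:
#
#     x = torch.randn(1, 1, 785, 768)
#     sim = attn_cosine_sim(x)  # shape: [1, 785, 785]
#     # each vector has cosine similarity ~1 with itself
#     assert torch.allclose(sim.diagonal(dim1=1, dim2=2), torch.ones(1, 785), atol=1e-4)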


class VitExtractor:
    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {}
        self.outputs_dict = {}
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()
    def _init_hooks_data(self):
        # Hook all 12 transformer blocks by default and reset the
        # per-forward output buffers.
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = list(range(12))
            self.outputs_dict[key] = []

    def _register_hooks(self):
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
            if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        def _get_block_output(module, inp, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
        return _get_block_output

    def _get_attn_hook(self):
        def _get_attn_output(module, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
        return _get_attn_output

    def _get_qkv_hook(self):
        def _get_qkv_output(module, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)
        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_patch_imd_output(module, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
        return _get_patch_imd_output

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_qkv_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.QKV_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_attn_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.ATTN_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_patch_size(self):
        # e.g. dino_vitb8 -> 8, dino_vitb16 -> 16
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        return w // self.get_patch_size()

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        return h // self.get_patch_size()

    def get_patch_num(self, input_img_shape):
        # +1 for the CLS token
        return 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
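
    # Worked example (illustrative): for a 224x224 input to dino_vitb8,
    # get_height_patch_num and get_width_patch_num are each 224 // 8 = 28,
    # so get_patch_num returns 1 + 28 * 28 = 785 tokens (the CLS token plus
    # one token per 8x8 patch).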

    def get_head_num(self):
        # ViT-S has 6 attention heads, ViT-B has 12
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        # ViT-S embeds tokens in 384 dimensions, ViT-B in 768
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def get_queries_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        # [1, N, 3D] -> [3, heads, N, D/heads]; index 0 selects the queries
        q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
        return q

    def get_keys_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        # [1, N, 3D] -> [3, heads, N, D/heads]; index 1 selects the keys
        k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
        return k

    def get_values_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        # [1, N, 3D] -> [3, heads, N, D/heads]; index 2 selects the values
        v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
        return v
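
    # Illustrative shape walk-through (assuming dino_vitb8 on a 224x224
    # input): the hooked qkv output is [1, 785, 3 * 768]; reshaping to
    # [785, 3, 12, 64] and permuting to [3, 12, 785, 64] lets q, k, v be
    # indexed along the first axis, each with shape [12, 785, 64]
    # (heads, tokens, head_dim).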

    def get_keys_from_input(self, input_img, layer_num):
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
        return keys

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape  # heads, tokens, head_dim
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
        return ssim_map
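

# Illustrative sketch (assumes a CUDA device and torch.hub access; not part of
# the original file): extracting the keys self-similarity map from the last
# block of dino_vitb8.
#
#     extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
#     img = torch.randn(1, 3, 224, 224).cuda()  # hypothetical input
#     ssim_map = extractor.get_keys_self_sim_from_input(img, layer_num=11)
#     # keys per token are concatenated across the 12 heads (12 * 64 = 768),
#     # so ssim_map has shape [1, 785, 785] for a 224x224 image.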


class DinoStructureLoss:
    def __init__(self):
        self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def calculate_global_ssim_loss(self, outputs, inputs):
        loss = 0.0
        for a, b in zip(inputs, outputs):  # one image at a time, to avoid memory limitations
            # the target self-similarity map needs no gradients
            with torch.no_grad():
                target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
            keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
            loss += F.mse_loss(keys_ssim, target_keys_self_sim)
        return loss
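

# Minimal usage sketch (an assumption, not part of the original file): compare
# the DINO self-similarity structure of two image batches. Here `inputs` stands
# in for source images and `outputs` for their translated counterparts; random
# tensors serve purely as a smoke test (real images would normally go through
# the `preprocess` transform first).
if __name__ == "__main__":
    loss_fn = DinoStructureLoss()
    inputs = torch.rand(2, 3, 224, 224, device="cuda")   # hypothetical batch
    outputs = torch.rand(2, 3, 224, 224, device="cuda")
    print(loss_fn.calculate_global_ssim_loss(outputs, inputs).item())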