"""DINOv2 ViT encoder/decoder modules built on timm backbones, with optional PEFT/LoRA tuning."""

import math

import torch
import torch.nn as nn

import peft
from timm.layers import trunc_normal_, Mlp
from timm.models import create_model

from .to_pixel import ToPixel
from .vision_transformer import Attention, RoPEAttention


class DINOv2Encoder(nn.Module):
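    """Wrap a pretrained DINOv2 ViT as an image encoder with learnable latent tokens.

    The image is patchified by the backbone, `num_latent_tokens` learnable tokens are
    appended to the token sequence, and the latent-token outputs are returned as the
    code. `tuning_method` selects how the backbone is adapted ("full", "lora",
    "lora_unfreeze_patch_embed", "lat_lora", or "frozen"); `product_quant` > 1 splits
    the latent tokens into groups, intended for product quantization downstream.
    """
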
    def __init__(
        self,
        in_channels=3,
        num_latent_tokens=32,
        use_attn_mask=False,
        model_name="vit_small_patch14_dinov2.lvd142m",
        model_kwargs={
            "img_size": 224,
            "patch_size": 14,
            "drop_path_rate": 0.0,
        },
        pretrained=True,
        tuning_method="lora",
        tuning_kwargs={"r": 8},
        abs_pos_embed=False,
        product_quant=1,
    ):
        super().__init__()

        assert model_name in [
            "vit_small_patch14_dinov2.lvd142m",
            "vit_base_patch14_dinov2.lvd142m",
            "vit_large_patch14_dinov2.lvd142m",
            "vit_giant_patch14_dinov2.lvd142m",
            "vit_small_patch14_reg4_dinov2.lvd142m",
            "vit_base_patch14_reg4_dinov2.lvd142m",
            "vit_large_patch14_reg4_dinov2.lvd142m",
            "vit_giant_patch14_reg4_dinov2.lvd142m",
        ], f"{model_name} not found"

        self.num_latent_tokens = num_latent_tokens
        self.use_attn_mask = use_attn_mask
        self.product_quant = product_quant

        model = create_model(model_name, pretrained=pretrained, **model_kwargs)

        self.embed_dim = model.embed_dim
        self.num_img_tokens = model.patch_embed.num_patches
        self.num_prefix_tokens = model.num_prefix_tokens
        self.abs_pos_embed = abs_pos_embed

        if tuning_method == "full":
            self.model = model
        elif tuning_method == "lora":
            # LoRA adapters on the MLP projections; keep the final norm trainable.
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "lora_unfreeze_patch_embed":
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["patch_embed.proj", "patch_embed.norm", "norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "lat_lora":
            from models.peft_models.lora import LatentLoRALinear

            # Latent-aware LoRA also adapts the attention qkv/proj layers and
            # enables the latent attention mask (see below).
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d|.*\.qkv|.*\.proj",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            config._register_custom_module({nn.Linear: LatentLoRALinear})
            self.model = peft.get_peft_model(model, config)
            self.use_attn_mask = True
            self.model.print_trainable_parameters()
        elif tuning_method == "frozen":
            for param in model.parameters():
                param.requires_grad = False
            self.model = model

        if self.num_latent_tokens:
            # Learnable latent tokens appended after the image tokens.
            self.latent_tokens = nn.Parameter(
                torch.zeros(1, self.num_latent_tokens, model.embed_dim)
            )
            nn.init.normal_(self.latent_tokens, std=1e-6)

            if self.abs_pos_embed:
                if self.product_quant > 1:
                    # One "level" embedding for the prefix + image tokens and one per
                    # product-quantization latent group.
                    self.lvl_embed = nn.Embedding(
                        1 + self.product_quant, model.embed_dim
                    )
                    nn.init.trunc_normal_(
                        self.lvl_embed.weight.data,
                        mean=0,
                        std=math.sqrt(1 / model.embed_dim / 3),
                    )
                    # Level ids: 0 for prefix + image tokens, i + 1 for the i-th latent group.
                    lvl1LC = torch.cat(
                        [
                            torch.full(
                                (self.num_img_tokens + self.num_prefix_tokens,), 0
                            ),
                        ]
                        + [
                            torch.full(
                                (self.num_latent_tokens // self.product_quant,), i + 1
                            )
                            for i in range(self.product_quant)
                        ]
                    ).view(1, -1)
                else:
                    self.lvl_embed = nn.Embedding(2, model.embed_dim)
                    nn.init.trunc_normal_(
                        self.lvl_embed.weight.data,
                        mean=0,
                        std=math.sqrt(1 / model.embed_dim / 3),
                    )
                    # Level ids: 0 for prefix + image tokens, 1 for latent tokens.
                    lvl1LC = torch.cat(
                        [
                            torch.full(
                                (self.num_img_tokens + self.num_prefix_tokens,), 0
                            ),
                            torch.full((self.num_latent_tokens,), 1),
                        ]
                    ).view(1, -1)
                self.register_buffer("lvl1LC", lvl1LC)
            else:
                self.latent_pos_embed = nn.Parameter(
                    torch.zeros(1, self.num_latent_tokens, model.embed_dim)
                )
                trunc_normal_(self.latent_pos_embed, std=0.02)

        if self.use_attn_mask:
            # Additive mask that prevents prefix and image tokens from attending to
            # the latent tokens; latent tokens still attend to the full sequence.
            total_length = (
                self.num_img_tokens
                + self.num_latent_tokens
                + self.num_prefix_tokens
            )
            attn_mask = torch.zeros((total_length, total_length))
            attn_mask[
                : self.num_prefix_tokens + self.num_img_tokens,
                -self.num_latent_tokens :,
            ] = -torch.inf
            attn_mask = attn_mask.view(1, 1, total_length, total_length)
            self.register_buffer("attn_mask", attn_mask)

    def finetune(self, tuning_method, tuning_kwargs={"r": 8}):
        if tuning_method == "full":
            return
        elif tuning_method == "lora":
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(self.model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "lora_unfreeze_patch_embed":
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["patch_embed.proj", "patch_embed.norm", "norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(self.model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "lat_lora":
            from models.peft_models.lora import LatentLoRALinear

            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d|.*\.qkv|.*\.proj",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            config._register_custom_module({nn.Linear: LatentLoRALinear})
            self.model = peft.get_peft_model(self.model, config)
            self.use_attn_mask = True
            self.model.print_trainable_parameters()
        elif tuning_method == "frozen":
            for param in self.model.parameters():
                param.requires_grad = False

    def no_weight_decay(self):
        return [
            "model.pos_embed",
            "model.cls_token",
            "model.dist_token",
            "latent_tokens",
            "latent_pos_embed",
        ]

    def forward(self, x, masks=None):
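        """Encode an image batch into latent-token features.

        Args:
            x: images of shape (B, C, H, W).
            masks: unused in this implementation.

        Returns:
            (B, num_latent_tokens, embed_dim) latent features, or the patch-token
            features (B, num_img_tokens, embed_dim) when num_latent_tokens == 0.
        """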
        x = self.model.patch_embed(x)

        # Positional embedding is applied in fp32 (autocast disabled) for stability.
        with torch.cuda.amp.autocast(enabled=False):
            x = self.model._pos_embed(x)
            x = self.model.patch_drop(x)

        if self.num_latent_tokens:
            # Append the learnable latent tokens to the token sequence.
            z = self.latent_tokens.expand(x.size(0), -1, -1)
            if self.abs_pos_embed:
                if self.product_quant > 1:
                    # Each product-quantization group gets its own 2D position grid.
                    H = W = int(
                        math.sqrt(self.num_latent_tokens // self.product_quant)
                    )
                    assert H * W == self.num_latent_tokens // self.product_quant
                    z = z.view(x.size(0), self.product_quant * H, W, -1)
                    z_list = z.chunk(chunks=self.product_quant, dim=1)
                    # Reuse the backbone's position embedding, dropping the leading
                    # prefix token it prepends (assumes a single cls prefix).
                    z_list = [self.model._pos_embed(z)[:, 1:] for z in z_list]
                    x = torch.cat([x] + z_list, dim=1)
                    x += self.lvl_embed(self.lvl1LC.expand(x.size(0), -1))
                else:
                    H = W = int(math.sqrt(self.num_latent_tokens))
                    assert H * W == self.num_latent_tokens
                    z = z.view(x.size(0), H, W, -1)
                    z = self.model._pos_embed(z)[:, 1:]
                    x = torch.cat([x, z], dim=1)
                    x += self.lvl_embed(self.lvl1LC.expand(x.size(0), -1))
            else:
                x = torch.cat([x, z + self.latent_pos_embed], dim=1)

        # Probe the active autocast/compute dtype and cast the tokens to it.
        temp = x.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype
        x = x.to(main_type)

        x = self.model.norm_pre(x)

        if self.use_attn_mask:
            for blk in self.model.blocks:
                x = blk(x, self.attn_mask)
        else:
            x = self.model.blocks(x)
        x = self.model.norm(x)

        if self.num_latent_tokens:
            # Return only the latent-token outputs.
            out = x[:, -self.num_latent_tokens :]
        else:
            # No latent tokens: return the patch-token outputs.
            out = x[:, self.num_prefix_tokens :]
        return out


class DINOv2Decoder(nn.Module):
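    """Reconstruct pixels from latent tokens with a DINOv2 ViT decoder.

    The decoder feeds one learnable mask token per image position, together with the
    latent tokens produced by the encoder, through the backbone, then maps the
    resulting patch tokens back to pixels with `ToPixel`. With `use_rope=True` the
    backbone is built with `RoPEAttention` and no absolute position embedding is
    applied to the mask tokens.
    """
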
    def __init__(
        self,
        in_channels=3,
        model_name="vit_small_patch14_dinov2.lvd142m",
        model_kwargs={"img_size": 224, "patch_size": 14, "drop_path_rate": 0.0},
        pretrained=True,
        tuning_method="lora",
        tuning_kwargs={"r": 8},
        num_latent_tokens=32,
        to_pixel="linear",
        use_rope=False,
        cond_latent=False,
        abs_pos_embed=False,
    ):
        super().__init__()

        assert model_name in [
            "vit_small_patch14_dinov2.lvd142m",
            "vit_base_patch14_dinov2.lvd142m",
            "vit_large_patch14_dinov2.lvd142m",
            "vit_giant_patch14_dinov2.lvd142m",
            "vit_small_patch14_reg4_dinov2.lvd142m",
            "vit_base_patch14_reg4_dinov2.lvd142m",
            "vit_large_patch14_reg4_dinov2.lvd142m",
            "vit_giant_patch14_reg4_dinov2.lvd142m",
        ], f"{model_name} not found"

        if use_rope:
            print("using RoPEAttention")
            attn_layer = RoPEAttention
        else:
            attn_layer = Attention

        # Copy before mutating so the shared default dict is not modified in place.
        model_kwargs = dict(model_kwargs)
        model_kwargs["num_latent_tokens"] = num_latent_tokens
        model_kwargs["attn_layer"] = attn_layer
        model = create_model(model_name, pretrained=pretrained, **model_kwargs)
        self.use_rope = use_rope
        self.embed_dim = model.embed_dim

        self.num_img_tokens = model.patch_embed.num_patches
        self.num_prefix_tokens = model.num_prefix_tokens
        self.num_latent_tokens = num_latent_tokens

        self.abs_pos_embed = abs_pos_embed

        if tuning_method == "full":
            self.model = model
        elif tuning_method == "lora":
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "frozen":
            for param in model.parameters():
                param.requires_grad = False
            self.model = model

        # One shared learnable mask token, expanded to every image position in forward().
        self.mask_token = nn.Parameter(torch.zeros(1, 1, model.embed_dim))
        nn.init.normal_(self.mask_token, std=1e-6)

        if not self.use_rope:
            if self.abs_pos_embed:
                self.lvl_embed = nn.Embedding(2, model.embed_dim)
                nn.init.trunc_normal_(
                    self.lvl_embed.weight.data,
                    mean=0,
                    std=math.sqrt(1 / model.embed_dim / 3),
                )
                # Level ids: 0 for prefix + image tokens, 1 for the latent tokens
                # (which keep their own prefix token after _pos_embed in forward()).
                lvl1LC = torch.cat(
                    [
                        torch.full((self.num_img_tokens + self.num_prefix_tokens,), 0),
                        torch.full(
                            (self.num_latent_tokens + self.num_prefix_tokens,), 1
                        ),
                    ]
                ).view(1, -1)
                self.register_buffer("lvl1LC", lvl1LC)
            else:
                self.latent_pos_embed = nn.Parameter(
                    torch.zeros(1, self.num_latent_tokens, model.embed_dim)
                )
                trunc_normal_(self.latent_pos_embed, std=0.02)

        self.to_pixel = ToPixel(
            to_pixel=to_pixel,
            img_size=model_kwargs["img_size"],
            in_channels=in_channels,
            in_dim=model.embed_dim,
            patch_size=model_kwargs["patch_size"],
        )

        # Optionally condition the mask tokens on the (mean-pooled) latent code.
        self.cond_latent = cond_latent
        if self.cond_latent:
            self.mlp1 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm)
            self.mlp2 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm)
            self.norm1 = nn.LayerNorm(model.embed_dim)

        # The decoder never patchifies pixels, so the patch-embedding projection
        # weights are unused; drop them to save memory.
        del self.model.patch_embed.proj.bias
        del self.model.patch_embed.proj.weight

    def finetune(self, tuning_method, tuning_kwargs={"r": 8}):
        if tuning_method == "full":
            return
        elif tuning_method == "lora":
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(self.model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "frozen":
            for param in self.model.parameters():
                param.requires_grad = False

    def no_weight_decay(self):
        return [
            "model.pos_embed",
            "model.cls_token",
            "model.dist_token",
            "mask_token",
            "latent_pos_embed",
        ]

    @property
    def last_layer(self):
        return self.to_pixel.model.weight

    def forward(self, z):
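        """Decode latent tokens into an image.

        Args:
            z: latent features of shape (B, num_latent_tokens, embed_dim).

        Returns:
            The `ToPixel` reconstruction, typically of shape (B, C, H, W).
        """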
        # One mask token per image position serves as the blank canvas.
        x = self.mask_token.expand(z.size(0), self.num_img_tokens, -1)

        with torch.cuda.amp.autocast(enabled=False):
            if not self.use_rope:
                x = self.model._pos_embed(x)

                if self.cond_latent:
                    # Condition the mask tokens on the mean-pooled latent code.
                    ffnout = x + self.mlp1(torch.mean(z.float(), dim=1, keepdim=True))
                    x = x + self.mlp2(self.norm1(ffnout))
                if self.abs_pos_embed:
                    H = W = int(math.sqrt(self.num_latent_tokens))
                    assert H * W == self.num_latent_tokens
                    z = z.view(x.size(0), H, W, -1)
                    z = self.model._pos_embed(z)
                else:
                    z = z + self.latent_pos_embed
            else:
                # RoPE path: no absolute position embedding, just prepend the
                # backbone's prefix (cls/register) tokens.
                to_cat = []
                if self.model.cls_token is not None:
                    to_cat.append(self.model.cls_token.expand(x.shape[0], -1, -1))
                if self.model.reg_token is not None:
                    to_cat.append(self.model.reg_token.expand(x.shape[0], -1, -1))
                x = torch.cat(to_cat + [x], dim=1)
                x = self.model.patch_drop(x)

            x = torch.cat([x, z], dim=1)
            if self.abs_pos_embed:
                x += self.lvl_embed(self.lvl1LC.expand(x.size(0), -1))

        # Probe the active autocast/compute dtype and cast the tokens to it.
        temp = x.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype
        x = x.to(main_type)

        x = self.model.norm_pre(x)
        x = self.model.blocks(x)
        x = self.model.norm(x)

        # Keep only the reconstructed image (patch) tokens and map them to pixels.
        x = x[:, self.num_prefix_tokens : self.num_img_tokens + self.num_prefix_tokens]
        out = self.to_pixel(x)
        return out


class DINOv2Decoder_(nn.Module):
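    """Plain DINOv2 ViT decoder that maps token features directly to pixels.

    Unlike `DINOv2Decoder`, this variant takes a full sequence of patch-token
    features as input (no mask tokens or latent handling) and simply runs it
    through the backbone before the `ToPixel` head.
    """
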
    def __init__(
        self,
        in_channels=3,
        model_name="vit_small_patch14_dinov2.lvd142m",
        model_kwargs={"img_size": 224, "patch_size": 14, "drop_path_rate": 0.0},
        pretrained=True,
        tuning_method="lora",
        tuning_kwargs={"r": 8},
        to_pixel="linear",
        use_rope=False,
        cond_latent=False,
    ):
        super().__init__()

        assert model_name in [
            "vit_small_patch14_dinov2.lvd142m",
            "vit_base_patch14_dinov2.lvd142m",
            "vit_large_patch14_dinov2.lvd142m",
            "vit_giant_patch14_dinov2.lvd142m",
            "vit_small_patch14_reg4_dinov2.lvd142m",
            "vit_base_patch14_reg4_dinov2.lvd142m",
            "vit_large_patch14_reg4_dinov2.lvd142m",
            "vit_giant_patch14_reg4_dinov2.lvd142m",
        ], f"{model_name} not found"

        if use_rope:
            print("using RoPEAttention")
            attn_layer = RoPEAttention
        else:
            attn_layer = Attention

        # Copy before mutating so the shared default dict is not modified in place.
        model_kwargs = dict(model_kwargs)
        model_kwargs["attn_layer"] = attn_layer
        model = create_model(model_name, pretrained=pretrained, **model_kwargs)
        self.use_rope = use_rope
        self.embed_dim = model.embed_dim

        self.num_img_tokens = model.patch_embed.num_patches
        self.num_prefix_tokens = model.num_prefix_tokens

        if tuning_method == "full":
            self.model = model
        elif tuning_method == "lora":
            config = peft.LoraConfig(
                target_modules=r".*\.mlp\.fc\d",
                modules_to_save=["norm"],
                **tuning_kwargs,
            )
            self.model = peft.get_peft_model(model, config)
            self.model.print_trainable_parameters()
        elif tuning_method == "frozen":
            for param in model.parameters():
                param.requires_grad = False
            self.model = model

        self.to_pixel = ToPixel(
            to_pixel=to_pixel,
            img_size=model_kwargs["img_size"],
            in_channels=in_channels,
            in_dim=model.embed_dim,
            patch_size=model_kwargs["patch_size"],
        )

        self.cond_latent = cond_latent
        if self.cond_latent:
            self.mlp1 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm)
            self.mlp2 = Mlp(model.embed_dim, model.embed_dim, norm_layer=nn.LayerNorm)
            self.norm1 = nn.LayerNorm(model.embed_dim)

    def no_weight_decay(self):
        return [
            "model.pos_embed",
            "model.cls_token",
            "model.dist_token",
            "mask_token",
            "latent_pos_embed",
        ]

    def forward(self, x):
        with torch.cuda.amp.autocast(enabled=False):
            x = self.model._pos_embed(x)
            x = self.model.patch_drop(x)

        # Probe the active autocast/compute dtype and cast the tokens to it.
        temp = x.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype
        x = x.to(main_type)

        x = self.model.norm_pre(x)
        x = self.model.blocks(x)
        x = self.model.norm(x)

        # Keep only the image (patch) tokens and map them to pixels.
        x = x[:, self.num_prefix_tokens : self.num_img_tokens + self.num_prefix_tokens]
        out = self.to_pixel(x)
        return out


if __name__ == "__main__":
    encoder = DINOv2Encoder(
        model_name="vit_small_patch14_dinov2.lvd142m",
        model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.0},
        tuning_method="lat_lora",
        tuning_kwargs={"r": 8},
        num_latent_tokens=32,
    )
    decoder = DINOv2Decoder(
        model_name="vit_small_patch14_dinov2.lvd142m",
        model_kwargs={"img_size": 256, "patch_size": 16, "drop_path_rate": 0.0},
        tuning_method="full",
        tuning_kwargs={"r": 8},
        num_latent_tokens=32,
        use_rope=True,
    )
    x = torch.randn(1, 3, 256, 256)
    out = encoder(x)
    out = decoder(out)
    print(out.shape)
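    # Expected output: torch.Size([1, 3, 256, 256]), assuming the "linear" ToPixel
    # head unpatchifies the 16x16 grid of patch tokens back to a 256x256 image.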