|
import os |
|
import argparse |
|
|
|
import torch |
|
import numpy as np |
|
|
|
|
|
def parse_args():
    """Parse command-line options for the reparameterization script.

    Returns:
        argparse.Namespace with ``model``, ``out_dir``, ``text_embed``
        and ``conv_neck`` attributes.
    """
    parser = argparse.ArgumentParser("Reparameterize YOLO-World")
    parser.add_argument('--model', help='model checkpoints to reparameterize')
    parser.add_argument('--out-dir', help='output checkpoints')
    parser.add_argument(
        '--text-embed',
        help='text embeddings to reparameterized into YOLO-World')
    # Flag choosing between the 1x1-conv neck rewrite and the linear one.
    parser.add_argument('--conv-neck',
                        action='store_true',
                        help='whether using 1x1 conv in RepVL-PAN')
    return parser.parse_args()
|
|
|
|
|
def convert_head(scale, bias, text_embed):
    """Fold fixed text embeddings into an equivalent 1x1-conv head.

    Args:
        scale: log-scale tensor; ``scale.exp()`` multiplies the embeddings.
        bias: per-class bias value, broadcast to every class.
        text_embed: ``(num_classes, embed_dim)`` embedding matrix.

    Returns:
        Tuple of conv weight ``(num_classes, embed_dim, 1, 1)`` and
        bias vector ``(num_classes,)``.
    """
    num_classes, embed_dim = text_embed.shape
    conv_weight = (text_embed * scale.exp()).reshape(
        num_classes, embed_dim, 1, 1)
    conv_bias = bias * torch.ones(num_classes)
    return conv_weight, conv_bias
|
|
|
|
|
def reparameterize_head(state_dict, embeds):
    """Replace the three contrastive classification heads by fixed convs.

    For each ``cls_contrasts`` layer, the learned ``logit_scale``/``bias``
    pair plus the text embeddings are converted (via ``convert_head``) into
    ``.conv.weight``/``.conv.bias`` entries; the originals are removed.

    Mutates and returns ``state_dict``.
    """
    prefix = 'bbox_head.head_module.cls_contrasts.'
    for layer in (prefix + '0', prefix + '1', prefix + '2'):
        logit_scale = state_dict.pop(layer + '.logit_scale')
        head_bias = state_dict.pop(layer + '.bias')
        conv_weight, conv_bias = convert_head(logit_scale, head_bias, embeds)
        state_dict[layer + '.conv.weight'] = conv_weight
        state_dict[layer + '.conv.bias'] = conv_bias
    return state_dict
|
|
|
|
|
def convert_neck_split_conv(input_state_dict, block_name, text_embeds,
                            num_heads):
    """Bake the neck guide projection into per-head 1x1 conv weights.

    Computes ``guide = text_embeds @ W.T + b`` from the block's ``guide_fc``
    parameters, splits it into ``num_heads`` equal chunks along the feature
    dim, and stores each chunk as ``.guide_convs.{i}.weight`` with trailing
    1x1 spatial dims. No-op when the block has no ``guide_fc`` (e.g. it was
    already converted). Mutates and returns ``input_state_dict``.
    """
    weight_key = block_name + '.guide_fc.weight'
    bias_key = block_name + '.guide_fc.bias'
    if weight_key not in input_state_dict:
        return input_state_dict
    fc_weight = input_state_dict.pop(weight_key)
    fc_bias = input_state_dict.pop(bias_key)
    guide = text_embeds @ fc_weight.transpose(0, 1) + fc_bias[None, :]
    _, feat_dim = guide.shape
    chunks = guide.split(feat_dim // num_heads, dim=1)
    for head_idx in range(num_heads):
        conv_key = block_name + f'.guide_convs.{head_idx}.weight'
        input_state_dict[conv_key] = chunks[head_idx][:, :, None, None]
    return input_state_dict
|
|
|
|
|
def convert_neck_weight(input_state_dict, block_name, embeds, num_heads):
    """Bake the neck guide projection into a single ``guide_weight`` tensor.

    Computes ``guide = embeds @ W.T + b`` from the block's ``guide_fc``
    parameters and stores it reshaped to ``(N, D // num_heads, num_heads)``
    under ``.guide_weight``, removing the original ``guide_fc`` entries.
    Mutates and returns ``input_state_dict``.
    """
    weight_key = block_name + '.guide_fc.weight'
    bias_key = block_name + '.guide_fc.bias'
    # Fix: mirror convert_neck_split_conv's guard so an absent / already
    # converted block is a no-op instead of raising KeyError.
    if weight_key not in input_state_dict:
        return input_state_dict
    guide_fc_weight = input_state_dict.pop(weight_key)
    guide_fc_bias = input_state_dict.pop(bias_key)
    guide = embeds @ guide_fc_weight.transpose(0, 1) + guide_fc_bias[None, :]
    N, D = guide.shape
    input_state_dict[block_name + '.guide_weight'] = guide.view(
        N, D // num_heads, num_heads)
    return input_state_dict
|
|
|
|
|
def reparameterize_neck(state_dict, embeds, type='conv'):
    """Reparameterize the RepVL-PAN attn blocks in place.

    Dispatches each top-down / bottom-up attn block to the conv or linear
    converter depending on ``type``; a no-op when the first attn block's
    bias is absent (neck already converted or not present).
    Mutates and returns ``state_dict``.
    """
    if "neck.top_down_layers.0.attn_block.bias" not in state_dict:
        return state_dict
    attn_blocks = [
        f'neck.{stage}_layers.{idx}.attn_block'
        for stage in ('top_down', 'bottom_up') for idx in (0, 1)
    ]
    convert = convert_neck_split_conv if type == 'conv' \
        else convert_neck_weight
    for block in attn_blocks:
        # Head count is recovered from the per-block bias length.
        num_heads = state_dict[block + '.bias'].shape[0]
        convert(state_dict, block, embeds, num_heads)
    return state_dict
|
|
|
|
|
def main():
    """Load a checkpoint, drop the text encoder, reparameterize head and
    neck with precomputed text embeddings, and save the result."""
    args = parse_args()

    checkpoint = torch.load(args.model, map_location='cpu')
    embeddings = torch.from_numpy(np.load(args.text_embed))

    # Drop every text-encoder parameter; the embeddings replace it.
    slim_state = {
        key: value
        for key, value in checkpoint['state_dict'].items()
        if "text_model" not in key
    }
    print("removing text encoder")

    slim_state = reparameterize_head(slim_state, embeddings)
    print("reparameterizing head")

    neck_type = "conv" if args.conv_neck else "linear"
    slim_state = reparameterize_neck(slim_state, embeddings, neck_type)
    print("reparameterizing neck")

    checkpoint['state_dict'] = slim_state
    out_name = os.path.basename(args.model).replace(
        '.pth', f'_rep_{neck_type}.pth')
    torch.save(checkpoint, os.path.join(args.out_dir, out_name))


if __name__ == "__main__":
    main()
|
|