# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel
from torch import nn
from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES,
IMAGENET_SIMPLE_CATEGORIES)
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT,
OPENAI_IMAGENET_PROMPT_SUB)
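# CIFAR-100 category names use underscores (e.g. 'maple_tree'); replace them
# with spaces so they read naturally inside the prompt templates.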
CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES]
PROTOTYPE_MAP = {
'imagenet': IMAGENET_SIMPLE_CATEGORIES,
'cifar100': CIFAR100_CATEGORIES,
}
PROMPT_MAP = {
'openai_imagenet': OPENAI_IMAGENET_PROMPT,
'openai_cifar100': OPENAI_CIFAR100_PROMPT,
'vanilla': [lambda c: f'a photo of a {c}'],
'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB
}
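# Each PROMPT_MAP entry is a list of template callables, e.g.
# PROMPT_MAP['vanilla'][0]('goldfish') -> 'a photo of a goldfish'. Entries
# with multiple templates enable prompt ensembling when building prototypes.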
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
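# Behaviour sketch: with a half-precision input, the normalization runs in
# fp32 and the result is cast back to the input dtype, e.g.
#   LayerNorm(512)(torch.randn(2, 512).half()).dtype == torch.float16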
class CLIP(BaseModel):
"""The implementation of `CLIP <https://arxiv.org/abs/2103.00020>`_.
Args:
        vision_backbone (dict): Config dict for vision backbone.
        projection (dict): Config dict for the visual projection head.
        text_backbone (dict): Config dict for text backbone.
        tokenizer (dict): Config dict for text tokenizer.
        vocab_size (int): Size of the vocabulary used by the text
            transformer.
        transformer_width (int): Channel dimension of the text transformer.
        proj_dim (int): Projection dimension for similarity computation.
context_length (int): The context length to use. Defaults to 77.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None or no specified type, it will use
"MultiModalDataPreprocessor" as type.
See :class:`MultiModalDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): The config to control the initialization.
Defaults to None.
"""
def __init__(self,
vision_backbone: dict,
projection: dict,
text_backbone: dict,
tokenizer: dict,
vocab_size: int,
transformer_width: int,
proj_dim: int,
context_length: int = 77,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.context_length = context_length
# build the vision transformer
self.visual = MODELS.build(vision_backbone)
# build the visual projection
self.visual_proj = MODELS.build(projection)
        # build the causal attention mask for the text transformer
text_backbone['attn_mask'] = self.build_attention_mask()
# build the text transformer
self.transformer = MODELS.build(text_backbone)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(
torch.empty(transformer_width, proj_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
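        # logit_scale is a learnable temperature: logit_scale.exp() starts at
        # 1 / 0.07 (~14.3), the initialization used in the CLIP paper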
self.initialize_parameters()
self.tokenizer = TOKENIZER.build(tokenizer)
        self.tokenizer.vocab = self.tokenizer.get_vocab(
        )  # CLIPTokenizer has no attribute named 'vocab', so set it manually
def initialize_parameters(self) -> None:
"""Initialize the parameters.
The pretrained weight will override the initialized parameters by this
function.
"""
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)
    def build_attention_mask(self):
        """Build the additive causal attention mask for the text transformer.

        PyTorch uses an additive attention mask, so positions a token must
        not attend to are filled with ``-inf``.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero the diagonal and below; keep -inf strictly above
        return mask
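    # Illustration: for context_length == 3 the mask built above is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so token i may attend only to positions j <= i.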
def forward(
self,
images: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'predict',
**kwargs,
):
"""The unified entry for a forward process in both training and test.
The method accepts the following modes:
- "predict": Forward and return a list of data samples contain the
predict results.
Args:
images (torch.Tensor): the preprocessed image tensor of shape
``(N, C, H, W)``.
data_samples (List[DataSample], optional): The annotation data
of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'predict'.
"""
if mode == 'predict':
return self.predict(images, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
"""The function to extract image latent features."""
return self.visual_proj(self.visual(images))[0]
def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
"""The function to extract text latent features."""
x = self.token_embedding(texts) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)[0]
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding
# (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
texts.argmax(dim=-1)] @ self.text_projection
return x
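    # Shape walk-through for extract_text_feat (illustrative):
    #   texts (B, context_length) -> token_embedding (B, L, width)
    #   -> transformer -> ln_final -> EOT-token rows (B, width)
    #   -> @ text_projection -> (B, proj_dim)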
def extract_feat(
self, images: torch.Tensor,
texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
"""The function to extract image and text latent features, the input
image or text can not both be None."""
assert images is not None or texts is not None, \
'text and image cannot both be None!'
if images is None:
return self.extract_text_feat(texts)
elif texts is None:
return self.extract_image_feat(images)
image_features = self.extract_image_feat(images)
text_features = self.extract_text_feat(texts)
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
return image_features, text_features
def compute_similarity(self, images, texts):
"""Extract images and texts features and compute cosine similarity."""
image_features, text_features = self.extract_feat(
images=images, texts=texts)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape (N, N)
return logits_per_image, logits_per_text
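    # Usage sketch (hypothetical tensors; the InfoNCE-style loss is shown
    # only for illustration, this class itself only implements prediction):
    #   logits_i, logits_t = model.compute_similarity(images, tokens)
    #   targets = torch.arange(images.size(0), device=images.device)
    #   loss = (F.cross_entropy(logits_i, targets) +
    #           F.cross_entropy(logits_t, targets)) / 2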
    @abstractmethod
    def predict(self,
                images: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None
                ) -> List[DataSample]:
        raise NotImplementedError
def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
"""Returns the tokenized representation of given input string(s)
Args:
texts (Union[str, List[str]]): An input string or a list of input
strings to tokenize
context_length (int): The context length to use. Defaults to 52.
Returns:
torch.Tensor: Resulting tokens.
"""
if isinstance(texts, str):
texts = [texts]
all_tokens = []
        for text in texts:
            # add special tokens: <|startoftext|> plays the role of the
            # [CLS] token and <|endoftext|> closes the sequence
            all_tokens.append(
                [self.tokenizer.vocab['<|startoftext|>']] +
                self.tokenizer.convert_tokens_to_ids(
                    self.tokenizer.tokenize(text))[:self.context_length - 2] +
                [self.tokenizer.vocab['<|endoftext|>']])
result = torch.zeros(
len(all_tokens), self.context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
assert len(tokens) <= self.context_length
result[i, :len(tokens)] = torch.tensor(tokens)
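        # Example (hypothetical call): self.tokenize(['a photo of a cat'])
        # returns a LongTensor of shape (1, context_length) laid out as
        # <|startoftext|>, the BPE token ids, <|endoftext|>, then zero padding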
return result
@MODELS.register_module()
class CLIPZeroShot(CLIP):
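    """CLIP for zero-shot image classification.

    Args:
        text_prototype (Union[str, List[str]]): Text prototype, which can be
            a key in ``PROTOTYPE_MAP`` or a list of texts.
            Defaults to 'imagenet'.
        text_prompt (str): The prompt key in ``PROMPT_MAP`` used to build the
            text prototype. Defaults to 'vanilla', which refers to
            "a photo of a {cls}".

    The remaining arguments are the same as :class:`CLIP`.
    """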
def __init__(
self,
vision_backbone: dict,
projection: dict,
text_backbone: dict,
tokenizer: dict,
vocab_size: int,
transformer_width: int,
proj_dim: int,
context_length: int = 77,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None,
text_prototype: Union[str, List[str]] = 'imagenet',
text_prompt: str = 'vanilla',
):
        super().__init__(vision_backbone, projection, text_backbone,
                         tokenizer, vocab_size, transformer_width, proj_dim,
                         context_length, data_preprocessor, init_cfg)
# for zero-shot classification
        if isinstance(text_prototype, str) and text_prototype in PROTOTYPE_MAP:
self.prototype = PROTOTYPE_MAP[text_prototype]
else:
self.prototype = text_prototype
self.text_prototype_embeds = None
self.prompt = PROMPT_MAP[text_prompt]
    def predict(self,
                images: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None
                ) -> List[DataSample]:
        """Predict the classes of the input images.

        The prediction is for zero-shot classification, and the text
        prototypes will be built in this function if not prepared yet.

Args:
images (torch.Tensor): The input images.
data_samples (DataSample): The data samples with information from
dataset.
Returns:
            List[DataSample]: The results of prediction.
"""
if self.text_prototype_embeds is None:
self.prepare_text_prototype(device=images.device)
image_features = self.extract_image_feat(images=images)
image_features /= image_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logits_per_image = image_features @ self.text_prototype_embeds.to(
image_features.device) * self.logit_scale.exp()
pred_scores = F.softmax(logits_per_image, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
out_data_samples = []
if data_samples is None:
data_samples = [None for _ in range(pred_scores.size(0))]
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
if data_sample is None:
data_sample = DataSample()
data_sample.set_pred_score(score).set_pred_label(label)
out_data_samples.append(data_sample)
return out_data_samples
    def prepare_text_prototype(self, device) -> None:
        """Prepare the text prototype embeddings from the prompt templates."""
class_embeddings = []
for classname in track_on_main_process(self.prototype,
'Prepare text prototype...'):
# format with class
texts = [prompt(classname) for prompt in self.prompt]
tokenized_texts = self.tokenize(texts)
class_features = self.extract_text_feat(tokenized_texts.to(device))
class_features /= class_features.norm(dim=-1, keepdim=True)
class_feature = class_features.mean(dim=0)
class_feature /= class_feature.norm()
class_embeddings.append(class_feature)
self.text_prototype_embeds = torch.stack(
class_embeddings, dim=1).to(device)
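# A minimal usage sketch. The config values below are illustrative
# assumptions, not the shipped mmpretrain configs (see the project's configs
# for the real ones):
#
#   from mmpretrain.registry import MODELS
#   cfg = dict(
#       type='CLIPZeroShot',
#       vision_backbone=dict(type='VisionTransformer', arch='base',
#                            patch_size=16),
#       projection=dict(type='CLIPProjection', in_channels=768,
#                       out_channels=512),
#       text_backbone=dict(type='CLIPTransformer', width=512, layers=12,
#                          heads=8),
#       tokenizer=dict(type='AutoTokenizer',
#                      name_or_path='openai/clip-vit-base-patch16',
#                      use_fast=False),
#       vocab_size=49408,
#       transformer_width=512,
#       proj_dim=512,
#       text_prototype='cifar100',
#       text_prompt='openai_cifar100',
#   )
#   model = MODELS.build(cfg)
#   data_samples = model(images, mode='predict')  # images: (N, 3, H, W)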