# Mbi2Spi/models/modules/biomedclip.py
import json
from typing import Optional

import torch
import torch.nn as nn
# The wildcard import supplies create_model_and_transforms, get_tokenizer, HFTokenizer,
# SimpleTokenizer, DEFAULT_CONTEXT_LENGTH, and related open_clip helpers used below.
from open_clip.factory import *

# Local-checkpoint variant of open_clip's create_model_and_transforms, kept for reference.
# It builds the model from a local open_clip_config.json and loads weights from
# 'open_clip_pytorch_model.bin' in cache_dir instead of pulling them from the HF Hub.
# def create_model_and_transforms(
#         model_name: str,
#         config: str,
#         device: Union[str, torch.device] = 'cpu',
#         cache_dir: Optional[str] = None,
#         force_preprocess_cfg: Optional[Dict[str, Any]] = None,
# ):
#     force_preprocess_cfg = force_preprocess_cfg or {}
#     preprocess_cfg = asdict(PreprocessCfg())
#     with open(config, 'r') as f:
#         config = json.load(f)
#     checkpoint_path = os.path.join(cache_dir, 'open_clip_pytorch_model.bin')
#     preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
#     model_cfg = config['model_cfg']
#     if isinstance(device, str):
#         device = torch.device(device)
#     print(f'Loaded {model_name} model config.')
#     # load pretrained weights for HF text model IFF no CLIP weights being loaded
#     model_cfg['text_cfg']['hf_model_pretrained'] = False
#     model = CustomTextCLIP(**model_cfg)
#     model.to(device=device)
#     print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
#     load_checkpoint(model, checkpoint_path)
#     # set image preprocessing configuration in model attributes for convenience
#     if getattr(model.visual, 'image_size', None) is not None:
#         # use image_size set on model creation (via config or force_image_size arg)
#         force_preprocess_cfg['size'] = model.visual.image_size
#     set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
#     pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
#     preprocess_train = image_transform_v2(
#         pp_cfg,
#         is_train=True,
#         aug_cfg=None,
#     )
#     preprocess_val = image_transform_v2(
#         pp_cfg,
#         is_train=False,
#     )
#     return model, preprocess_train, preprocess_val


def get_my_tokenizer(
        config: str,
        context_length: Optional[int] = None,
        **kwargs,
):
    """Build the text tokenizer described by a local open_clip config JSON file."""
    with open(config, 'r') as f:
        config = json.load(f)
    text_config = config['model_cfg']['text_cfg']
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs
    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )
    return tokenizer
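
# Example (sketch): how get_my_tokenizer would be used with a locally downloaded
# BiomedCLIP config. The path below is hypothetical and not guaranteed to exist in
# this repository; the default path in the code uses the HF Hub tokenizer instead.
# tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
# token_ids = tokenizer(["CD45 protein", "Ki-67 protein"])  # LongTensor [2, context_length]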


class BiomedCLIPTextEncoder(nn.Module):
    """Frozen BiomedCLIP text tower used to embed protein-name prompts."""

    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # Local-checkpoint alternative, kept for reference:
        # self.model, _, _ = create_model_and_transforms(
        #     model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
        #     # config='./ckpt/BiomedCLIP/open_clip_config.json',
        #     cache_dir='./ckpt/BiomedCLIP/'
        # )
        self.model, _, _ = create_model_and_transforms('hf-hub:hsiangyualex/biomedclip4imc')
        self.model.eval()
        self.model.to(device)
        # Freeze the text encoder; it is used purely as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        # self.tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
        self.tokenizer = get_tokenizer('hf-hub:hsiangyualex/biomedclip4imc')
        self.device = device

    @torch.no_grad()
    def forward(self, prompts):
        """Encode a list of protein names into CLIP text features.

        Args:
            prompts: an iterable of protein names (strings).

        Returns:
            A tensor of text features, one row per prompt.
        """
        prompts = [f"An imaging mass cytometry staining image of {prompt} protein." for prompt in prompts]
        prompts = self.tokenizer(prompts).to(self.device)
        text_features = self.model.encode_text(prompts).detach()
        return text_features
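

# Minimal usage sketch: build the frozen text encoder and embed a few protein names.
# The protein names are illustrative placeholders, and loading
# hf-hub:hsiangyualex/biomedclip4imc requires network access on first use.
if __name__ == '__main__':
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = BiomedCLIPTextEncoder(device=dev)
    features = encoder(['CD45', 'Ki-67', 'E-cadherin'])
    print(features.shape)  # (3, embedding_dim), e.g. (3, 512) for BiomedCLIP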