import json

import torch
import torch.nn as nn
from open_clip.factory import *

# def create_model_and_transforms(
#         model_name: str,
#         config: str,
#         device: Union[str, torch.device] = 'cpu',
#         cache_dir: Optional[str] = None,
#         force_preprocess_cfg: Optional[Dict[str, Any]] = None,
# ):
#     force_preprocess_cfg = force_preprocess_cfg or {}
#     preprocess_cfg = asdict(PreprocessCfg())
#     with open(config, 'r') as f:
#         config = json.load(f)
#     checkpoint_path = os.path.join(cache_dir, 'open_clip_pytorch_model.bin')
#     preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
#     model_cfg = config['model_cfg']
#     if isinstance(device, str):
#         device = torch.device(device)
#     print(f'Loaded {model_name} model config.')
#     # load pretrained weights for HF text model IFF no CLIP weights being loaded
#     model_cfg['text_cfg']['hf_model_pretrained'] = False
#     model = CustomTextCLIP(**model_cfg)
#     model.to(device=device)
#     print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
#     load_checkpoint(model, checkpoint_path)
#     # set image preprocessing configuration in model attributes for convenience
#     if getattr(model.visual, 'image_size', None) is not None:
#         # use image_size set on model creation (via config or force_image_size arg)
#         force_preprocess_cfg['size'] = model.visual.image_size
#     set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
#     pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
#     preprocess_train = image_transform_v2(
#         pp_cfg,
#         is_train=True,
#         aug_cfg=None,
#     )
#     preprocess_val = image_transform_v2(
#         pp_cfg,
#         is_train=False,
#     )
#     return model, preprocess_train, preprocess_val


def get_my_tokenizer(
        config: str,
        context_length: Optional[int] = None,
        **kwargs,
):
    """Build the text tokenizer described by an open_clip config JSON file."""
    with open(config, 'r') as f:
        config = json.load(f)
    text_config = config['model_cfg']['text_cfg']
    # Merge tokenizer kwargs from the config with any caller-supplied overrides.
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs
    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
    # Use a Hugging Face tokenizer when the config names one; otherwise fall back
    # to open_clip's SimpleTokenizer.
    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )
    return tokenizer
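
# Example usage, assuming a local open_clip_config.json like the one referenced in the
# commented-out code above (the path is illustrative, not shipped with this file):
#   tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
#   tokens = tokenizer(['An imaging mass cytometry staining image of CD45 protein.'])
#   # tokens: LongTensor of shape (1, context_length), ready for model.encode_text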


class BiomedCLIPTextEncoder(nn.Module):
    """Frozen BiomedCLIP text tower used to embed protein-name prompts."""

    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # self.model, _, _ = create_model_and_transforms(
        #     model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
        #     # config='./ckpt/BiomedCLIP/open_clip_config.json',
        #     cache_dir='./ckpt/BiomedCLIP/'
        # )
        self.model, _, _ = create_model_and_transforms('hf-hub:hsiangyualex/biomedclip4imc')
        self.model.eval()
        self.model.to(device)
        # Freeze the encoder; it is used purely as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        # self.tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
        self.tokenizer = get_tokenizer('hf-hub:hsiangyualex/biomedclip4imc')
        self.device = device

    def forward(self, prompts):
        """
        Args:
            prompts: a sequence of protein names to embed.
        """
        prompts = [f"An imaging mass cytometry staining image of {prompt} protein." for prompt in prompts]
        prompts = self.tokenizer(prompts).to(self.device)
        text_features = self.model.encode_text(prompts).detach()
        return text_features
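

# Minimal usage sketch: the protein names below are illustrative placeholders, and the
# embedding dimension depends on the checkpoint (512 in the original BiomedCLIP config).
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = BiomedCLIPTextEncoder(device)
    with torch.no_grad():
        features = encoder(['CD45', 'E-cadherin'])
    print(features.shape)  # -> (2, embed_dim)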