import json

import torch
import torch.nn as nn
# Wildcard import provides create_model_and_transforms, get_tokenizer,
# HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH and related helpers used below.
from open_clip.factory import *


# def create_model_and_transforms(
#         model_name: str,
#         config: str,
#         device: Union[str, torch.device] = 'cpu',
#         cache_dir: Optional[str] = None,
#         force_preprocess_cfg: Optional[Dict[str, Any]] = None,
# ):
#     force_preprocess_cfg = force_preprocess_cfg or {}
#     preprocess_cfg = asdict(PreprocessCfg())
#     with open(config, 'r') as f:
#         config = json.load(f)

#     checkpoint_path = os.path.join(cache_dir, 'open_clip_pytorch_model.bin')
#     preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
#     model_cfg = config['model_cfg']

#     if isinstance(device, str):
#         device = torch.device(device)
#     print(f'Loaded {model_name} model config.')

#     # load pretrained weights for HF text model IFF no CLIP weights being loaded
#     model_cfg['text_cfg']['hf_model_pretrained'] = False

#     model = CustomTextCLIP(**model_cfg)
#     model.to(device=device)

#     print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
#     load_checkpoint(model, checkpoint_path)

#     # set image preprocessing configuration in model attributes for convenience
#     if getattr(model.visual, 'image_size', None) is not None:
#         # use image_size set on model creation (via config or force_image_size arg)
#         force_preprocess_cfg['size'] = model.visual.image_size
#     set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

#     pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)

#     preprocess_train = image_transform_v2(
#         pp_cfg,
#         is_train=True,
#         aug_cfg=None,
#     )
#     preprocess_val = image_transform_v2(
#         pp_cfg,
#         is_train=False,
#     )

#     return model, preprocess_train, preprocess_val


def get_my_tokenizer(
        config: str,
        context_length: Optional[int] = None,
        **kwargs,
):
    """Build a tokenizer from a local open_clip config JSON.

    Mirrors open_clip.get_tokenizer(), but reads the text config from the given
    JSON file instead of resolving a model name against the built-in configs.
    """
    with open(config, 'r') as f:
        config = json.load(f)

    text_config = config['model_cfg']['text_cfg']
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer
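
# Usage sketch (assumes a local copy of the model config such as the
# ./ckpt/BiomedCLIP/open_clip_config.json path referenced elsewhere in this
# file; the prompt text is illustrative only):
#
#   tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
#   tokens = tokenizer(["CD45 protein"])  # LongTensor of shape [1, context_length]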


class BiomedCLIPTextEncoder(nn.Module):
    """Frozen BiomedCLIP text encoder for embedding protein-name prompts."""

    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # self.model, _, _ = create_model_and_transforms(
        #     model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', 
        #     # config='./ckpt/BiomedCLIP/open_clip_config.json', 
        #     cache_dir='./ckpt/BiomedCLIP/'
        #     )
        self.model, _, _ = create_model_and_transforms('hf-hub:hsiangyualex/biomedclip4imc')
        self.model.eval()
        self.model.to(device)
        for param in self.model.parameters():
            param.requires_grad = False
        # self.tokenizer = get_my_tokenizer(config='./ckpt/BiomedCLIP/open_clip_config.json')
        self.tokenizer = get_tokenizer('hf-hub:hsiangyualex/biomedclip4imc')
        self.device = device
    
    @torch.no_grad()
    def forward(self, prompts):
        """Encode protein names into BiomedCLIP text embeddings.

        Args:
            prompts: an iterable of protein names (strings).

        Returns:
            Tensor of text features with one row per prompt.
        """
        prompts = [f"An imaging mass cytometry staining image of {prompt} protein." for prompt in prompts]
        prompts = self.tokenizer(prompts).to(self.device)
        text_features = self.model.encode_text(prompts).detach()
        return text_features
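

if __name__ == "__main__":
    # Minimal smoke test, assuming the hf-hub:hsiangyualex/biomedclip4imc
    # checkpoint referenced above can be downloaded; the protein names below
    # are illustrative only.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = BiomedCLIPTextEncoder(device)
    features = encoder(["CD45", "E-cadherin"])
    print(features.shape)  # one embedding row per prompt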