import os
from typing import Dict, Optional

import torch
from PIL import Image
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
import huggingface_hub
from hpsv2.utils import root_path, hps_version_map


class HPSMetric:
    """Score an (image, prompt) pair with an HPSv2 model.

    A ViT-H-14 open_clip model is built and its weights are replaced by the
    HPS checkpoint for ``hps_version``, downloaded from the ``xswu/HPSv2``
    Hugging Face Hub repository. The score is the dot product between the
    model's image features and text features.
    """

    def __init__(self, hps_version: str = "v2.1", device: Optional[str] = None):
        """Build the model and load the HPS checkpoint.

        Args:
            hps_version: Checkpoint version; must be a key of
                ``hps_version_map``. Defaults to ``"v2.1"``.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.

        Raises:
            ValueError: if ``hps_version`` is not a known checkpoint version.
        """
        if hps_version not in hps_version_map:
            raise ValueError(
                f"unknown hps_version {hps_version!r}; "
                f"expected one of {sorted(hps_version_map)}"
            )
        self.hps_version = hps_version
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Holds 'model' and 'preprocess_val' once initialization succeeds.
        self.model_dict = {}
        self._initialize_model()

    def _initialize_model(self) -> None:
        """Create the model, load the HPS checkpoint and cache model/tokenizer.

        Does nothing when ``self.model_dict`` is already populated.
        """
        if self.model_dict:
            return

        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            'ViT-H-14',
            'laion2B-s32B-b79K',
            precision='amp',
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )

        # Kept for parity with hpsv2; note the download below does not pass
        # cache_dir, so it uses huggingface_hub's default cache location.
        os.makedirs(root_path, exist_ok=True)
        cp = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", hps_version_map[self.hps_version]
        )
        # NOTE(review): torch.load unpickles the downloaded file; only point
        # this at a trusted repository.
        checkpoint = torch.load(cp, map_location=self.device)
        model.load_state_dict(checkpoint['state_dict'])
        del checkpoint  # drop the raw state dict; the weights now live in `model`

        model = model.to(self.device)
        model.eval()

        # Populate the cache only after a successful load, so a failed
        # download/load does not leave an un-loaded model marked as ready.
        self.model_dict['model'] = model
        self.model_dict['preprocess_val'] = preprocess_val
        self.tokenizer = get_tokenizer('ViT-H-14')

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "hps"

    def compute_score(
        self,
        image: Image.Image,
        prompt: str,
    ) -> Dict[str, float]:
        """Compute the HPS score of ``image`` with respect to ``prompt``.

        Args:
            image: Image to score; converted by the model's validation
                preprocessing transform.
            prompt: Text prompt the image is scored against.

        Returns:
            ``{"hps": score}``, where ``score`` is the dot product of the
            model's image and text features.
        """
        model = self.model_dict['model']
        preprocess_val = self.model_dict['preprocess_val']
        # Mixed precision only applies on CUDA; on CPU the context is a no-op
        # (the deprecated torch.cuda.amp.autocast() disabled itself with a
        # warning in that case).
        use_amp = torch.device(self.device).type == 'cuda'

        with torch.no_grad():
            # Batch of one image.
            image_tensor = preprocess_val(image).unsqueeze(0).to(
                device=self.device, non_blocking=True
            )
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)

            with torch.autocast(device_type='cuda', enabled=use_amp):
                outputs = model(image_tensor, text)
                image_features = outputs["image_features"]
                text_features = outputs["text_features"]
                logits_per_image = image_features @ text_features.T
                # Single image / single prompt: the score is the (0, 0) entry.
                hps_score = logits_per_image.diagonal()[0].item()

        return {"hps": float(hps_score)}