import itertools
import json
import os
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import huggingface_hub


class HPSPrompts:
    """Benchmark prompts from the ``benchmark`` subfolder of the ``zhwang/HPDv2`` Hub dataset.

    One JSON prompt file per category is fetched from the Hugging Face Hub and
    symlinked into a local directory. Up to ``max_prompts_per_category`` prompts
    are taken from each file. Each prompt is keyed by an output image filename
    of the form ``<category>_<index:03d>.png``.
    """

    def __init__(self, root: str = 'downloads/hps', max_prompts_per_category: Optional[int] = 100):
        """Fetch (if needed) and load the benchmark prompts.

        Args:
            root: Local directory that holds the prompt files (symlinks into the
                Hugging Face cache). Defaults to ``'downloads/hps'``.
            max_prompts_per_category: Maximum number of prompts taken from each
                category file. ``None`` takes all of them. Defaults to 100.

        Raises:
            ValueError: If ``max_prompts_per_category`` is negative.
        """
        super().__init__()
        self.hps_prompt_files = ['anime.json', 'concept-art.json', 'paintings.json', 'photo.json']
        # Single source of truth for the prompt directory (download and read use the same path).
        self._root = Path(root)
        self._download_benchmark_prompts()

        self.prompts: Dict[str, str] = {}
        self._size = 0
        for file in self.hps_prompt_files:
            category = file.replace('.json', '')
            # Explicit UTF-8: the platform default encoding is not guaranteed to decode the JSON.
            with open(self._root / file, 'r', encoding='utf-8') as f:
                prompts = json.load(f)
            # islice raises ValueError for a negative limit; None means "no limit".
            for i, prompt in enumerate(itertools.islice(prompts, max_prompts_per_category)):
                filename = f"{category}_{i:03d}.png"
                self.prompts[filename] = prompt
                self._size += 1

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield ``(prompt, Path(output_filename))`` pairs in load order."""
        for filename, prompt in self.prompts.items():
            yield prompt, Path(filename)

    @property
    def name(self) -> str:
        """Identifier of this prompt set."""
        return "hps"

    @property
    def size(self) -> int:
        """Number of prompts loaded."""
        return self._size

    def _download_benchmark_prompts(self) -> None:
        """Ensure every prompt file is reachable inside the local prompt directory.

        A file that already resolves (a regular file or a live symlink) is left
        untouched and no Hub request is made for it. A missing file is
        downloaded into the Hugging Face cache and symlinked into the directory.
        A dangling symlink, for example one left behind after the cache was
        cleared, is replaced.
        """
        folder_name = self._root
        folder_name.mkdir(parents=True, exist_ok=True)
        for file in self.hps_prompt_files:
            link_path = folder_name / file
            if link_path.exists():
                # Usable as-is; it would be kept anyway, so skip the network round-trip.
                continue
            file_name = huggingface_hub.hf_hub_download("zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset")
            if link_path.is_symlink():
                # exists() follows links, so a dangling link reports False but still
                # occupies the name; os.symlink would raise FileExistsError.
                link_path.unlink()
            os.symlink(file_name, link_path)

    @property
    def metrics(self) -> List[str]:
        """Names of the metrics this prompt set is evaluated with."""
        return ["hps"]