davidberenstein1957's picture
refactor: improve code formatting and organization across multiple API and benchmark files
34046e2
raw
history blame
1.76 kB
import json
import os
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import huggingface_hub
class HPSPrompts:
    """Prompt set for the "hps" benchmark, built from the HPDv2 benchmark files.

    On construction, four benchmark JSON files are fetched from the
    ``zhwang/HPDv2`` dataset repository (``benchmark`` subfolder) on the
    Hugging Face Hub. Each is symlinked into ``downloads/hps`` (relative to
    the current working directory). The first ``max_prompts_per_category``
    prompts of each file are then loaded.

    Each prompt is keyed by an output image filename of the form
    ``"{category}_{index:03d}.png"``. ``category`` is the JSON file name
    without its ``.json`` suffix, e.g. ``"anime_000.png"`` or
    ``"concept-art_042.png"``.
    """

    # Local folder (relative to the CWD) that holds links to the cached files.
    _DOWNLOAD_DIR = Path("downloads/hps")

    def __init__(self, max_prompts_per_category: Optional[int] = 100):
        """Download the benchmark files and load their prompts.

        Args:
            max_prompts_per_category: Maximum number of prompts taken from
                each category file, in file order. ``None`` loads every
                prompt. Defaults to 100.
        """
        super().__init__()
        self.hps_prompt_files = [
            "anime.json",
            "concept-art.json",
            "paintings.json",
            "photo.json",
        ]
        self._download_benchmark_prompts()
        # Maps output image filename -> prompt text.
        self.prompts: Dict[str, str] = {}
        limit = max_prompts_per_category
        for file in self.hps_prompt_files:
            category = Path(file).stem
            # Read as UTF-8 explicitly instead of relying on the platform's
            # locale encoding, which differs across operating systems.
            with open(self._DOWNLOAD_DIR / file, "r", encoding="utf-8") as f:
                # NOTE(review): assumes each file holds a JSON list of prompt
                # strings -- confirm against the upstream dataset.
                prompts = json.load(f)
            for i, prompt in enumerate(prompts):
                if limit is not None and i >= limit:
                    break
                self.prompts[f"{category}_{i:03d}.png"] = prompt
        self._size = len(self.prompts)

    def __iter__(self) -> Iterator[Tuple[str, Path]]:
        """Yield ``(prompt, Path(filename))`` pairs in insertion order.

        The order is category-file order, then index within each file.
        """
        for filename, prompt in self.prompts.items():
            yield prompt, Path(filename)

    @property
    def name(self) -> str:
        """Short identifier of this prompt set."""
        return "hps"

    @property
    def size(self) -> int:
        """Number of prompts loaded across all categories."""
        return self._size

    def _download_benchmark_prompts(self) -> None:
        """Fetch each benchmark file from the Hub and link it locally.

        ``hf_hub_download`` returns a local path to the downloaded file, and
        a symlink to that path is placed in ``downloads/hps``. Existing files
        and valid links are left untouched. A dangling link, e.g. one that
        points into a cache that has since been cleared, is replaced.
        """
        folder = self._DOWNLOAD_DIR
        folder.mkdir(parents=True, exist_ok=True)
        for file in self.hps_prompt_files:
            file_name = huggingface_hub.hf_hub_download(
                "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
            )
            link = folder / file
            # os.path.exists() follows symlinks, so a dangling link would look
            # "missing" and os.symlink() would then raise FileExistsError.
            if link.is_symlink() and not link.exists():
                link.unlink()
            if not os.path.lexists(link):
                os.symlink(file_name, link)

    @property
    def metrics(self) -> List[str]:
        """Metric identifiers associated with this prompt set."""
        return ["hps"]