davidberenstein1957 committed
Commit 199a7d9 · 1 Parent(s): f5f9b38

chore: update .gitignore, environment.yml, and README; remove nils_installs.txt

.gitignore CHANGED
@@ -175,4 +175,5 @@ cython_debug/
 
 evaluation_results/
 images/
-hf_cache/
+hf_cache/
+*.lock
README.md CHANGED
@@ -1,10 +1,28 @@
 # InferBench
 Evaluate the quality and efficiency of image gen api's.
 
-Install dependencies with conda like that:
+## Installation
+
+### Install dependencies
 
+Install dependencies with conda like that:
+```
 conda env create -f environment.yml
+```
+
+### Install uv
+
+Install uv with pip like that:
+
+```
+uv venv --python 3.12
+```
+
+```
+uv sync --all-groups
+```
 
+## Usage
 
 Create .env file with all the credentials you will need.
 
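A note on the new Usage section: it only says to create a .env file, without listing the keys. Not part of the commit, but a minimal sketch of loading those credentials with python-dotenv (which appears in the removed nils_installs.txt pin list below); `FAL_KEY` is a hypothetical placeholder, since the README does not enumerate which credentials each API backend expects:

```python
# Minimal sketch, assuming python-dotenv is installed.
# FAL_KEY is a hypothetical placeholder variable name.
import os

from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from ./.env into the process environment
api_key = os.getenv("FAL_KEY")
if api_key is None:
    raise RuntimeError("Missing credential in .env")
```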
benchmark/draw_bench.py CHANGED
@@ -6,16 +6,16 @@ from datasets import load_dataset
 
 class DrawBenchPrompts:
     def __init__(self):
-        self.dataset = load_dataset("shunk031/DrawBench")["test"]
-
+        self.dataset = load_dataset("shunk031/DrawBench", split="test[:5]")
+
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for i, row in enumerate(self.dataset):
             yield row["prompts"], Path(f"{i}.png")
-
+
     @property
     def name(self) -> str:
         return "draw_bench"
-
+
     @property
     def size(self) -> int:
         return len(self.dataset)
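Beyond the formatting churn, this hunk changes behavior: `split="test[:5]"` uses the datasets library's slice syntax, so the benchmark now iterates only the first five DrawBench prompts and `size` reports 5. A hedged smoke-test sketch of how the class is consumed; the import path is inferred from the repository layout and needs the repo root on `sys.path` plus Hub access:

```python
# Sketch only: assumes `benchmark` is importable and the Hub is reachable.
from pathlib import Path

from benchmark.draw_bench import DrawBenchPrompts

bench = DrawBenchPrompts()
print(bench.name, bench.size)  # "draw_bench", 5 under the "test[:5]" slice
for prompt, image_path in bench:  # yields (str, Path) pairs like ("...", Path("0.png"))
    assert isinstance(image_path, Path)
```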
benchmark/metrics/arniqa.py CHANGED
@@ -8,20 +8,29 @@ from torchmetrics.image.arniqa import ARNIQA
 
 class ARNIQAMetric:
     def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.metric = ARNIQA(
             regressor_dataset="koniq10k",
             reduction="mean",
             normalize=True,
-            autocast=False
+            autocast=False,
         )
         self.metric.to(self.device)
+
     @property
     def name(self) -> str:
         return "arniqa"
-
+
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+        image_tensor = (
+            torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+        )
         image_tensor = image_tensor.unsqueeze(0).to(self.device)
         score = self.metric(image_tensor)
         return {"arniqa": score.item()}
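The cuda → mps → cpu chain introduced here is repeated verbatim in clip.py, clip_iqa.py, hps.py, image_reward.py, and vqa.py below. Not part of the commit, but a sketch of the shared helper it could be factored into, with the same preference order written as plain if-statements instead of the chained conditional expression:

```python
import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, then CPU, matching the
    device selection duplicated across the metric classes."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():  # Apple Silicon GPU backend
        return torch.device("mps")
    return torch.device("cpu")
```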
benchmark/metrics/clip.py CHANGED
@@ -8,14 +8,20 @@ from torchmetrics.multimodal.clip_score import CLIPScore
 
 class CLIPMetric:
     def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
         self.metric.to(self.device)
-
+
     @property
     def name(self) -> str:
         return "clip"
-
+
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
         image_tensor = image_tensor.to(self.device)
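One pre-existing wart this hunk leaves untouched: the constructor still hard-codes the checkpoint name, so the `model_name_or_path` parameter is silently ignored. Not part of the commit, but a sketch of the one-line pass-through fix:

```python
from torchmetrics.multimodal.clip_score import CLIPScore


class CLIPMetric:
    def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
        # Pass the argument through instead of repeating the literal,
        # so callers can actually select a different CLIP checkpoint.
        self.metric = CLIPScore(model_name_or_path=model_name_or_path)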
benchmark/metrics/clip_iqa.py CHANGED
@@ -8,18 +8,22 @@ from torchmetrics.multimodal import CLIPImageQualityAssessment
 
 class CLIPIQAMetric:
     def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.metric = CLIPImageQualityAssessment(
-            model_name_or_path="clip_iqa",
-            data_range=255.0,
-            prompts=("quality",)
+            model_name_or_path="clip_iqa", data_range=255.0, prompts=("quality",)
         )
         self.metric.to(self.device)
 
     @property
     def name(self) -> str:
         return "clip_iqa"
-
+
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
         image_tensor = image_tensor.unsqueeze(0)
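Worth noting while reading this hunk: `data_range=255.0` declares that incoming tensors hold raw 0-255 pixel values, which is why `compute_score` converts the image to float without dividing by 255 (contrast ARNIQA above, which normalizes to 0-1). A hedged usage sketch; the image path is hypothetical and the returned dict key is assumed to follow the `name` property:

```python
# Hedged sketch: downloads the CLIP-IQA weights on first use.
from PIL import Image

from benchmark.metrics.clip_iqa import CLIPIQAMetric

metric = CLIPIQAMetric()
image = Image.open("images/example/0.png").convert("RGB")  # hypothetical path
print(metric.compute_score(image, prompt=""))  # presumably {"clip_iqa": ...}
```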
benchmark/metrics/hps.py CHANGED
@@ -1,26 +1,32 @@
 import os
 from typing import Dict
 
+import huggingface_hub
 import torch
-from PIL import Image
 from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
-import huggingface_hub
-from hpsv2.utils import root_path, hps_version_map
+from hpsv2.utils import hps_version_map, root_path
+from PIL import Image
 
 
 class HPSMetric:
     def __init__(self):
         self.hps_version = "v2.1"
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.model_dict = {}
         self._initialize_model()
-
+
     def _initialize_model(self):
         if not self.model_dict:
             model, preprocess_train, preprocess_val = create_model_and_transforms(
-                'ViT-H-14',
-                'laion2B-s32B-b79K',
-                precision='amp',
+                "ViT-H-14",
+                "laion2B-s32B-b79K",
+                precision="amp",
                 device=self.device,
                 jit=False,
                 force_quick_gelu=False,
@@ -34,44 +40,53 @@ class HPSMetric:
                 aug_cfg={},
                 output_dict=True,
                 with_score_predictor=False,
-                with_region_predictor=False
+                with_region_predictor=False,
             )
-            self.model_dict['model'] = model
-            self.model_dict['preprocess_val'] = preprocess_val
-
+            self.model_dict["model"] = model
+            self.model_dict["preprocess_val"] = preprocess_val
+
             # Load checkpoint
             if not os.path.exists(root_path):
                 os.makedirs(root_path)
-            cp = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[self.hps_version])
-
+            cp = huggingface_hub.hf_hub_download(
+                "xswu/HPSv2", hps_version_map[self.hps_version]
+            )
+
             checkpoint = torch.load(cp, map_location=self.device)
-            model.load_state_dict(checkpoint['state_dict'])
-            self.tokenizer = get_tokenizer('ViT-H-14')
+            model.load_state_dict(checkpoint["state_dict"])
+            self.tokenizer = get_tokenizer("ViT-H-14")
             model = model.to(self.device)
             model.eval()
-
+
     @property
     def name(self) -> str:
         return "hps"
-
+
     def compute_score(
         self,
         image: Image.Image,
         prompt: str,
     ) -> Dict[str, float]:
-        model = self.model_dict['model']
-        preprocess_val = self.model_dict['preprocess_val']
-
+        model = self.model_dict["model"]
+        preprocess_val = self.model_dict["preprocess_val"]
+
        with torch.no_grad():
            # Process the image
-            image_tensor = preprocess_val(image).unsqueeze(0).to(device=self.device, non_blocking=True)
+            image_tensor = (
+                preprocess_val(image)
+                .unsqueeze(0)
+                .to(device=self.device, non_blocking=True)
+            )
            # Process the prompt
            text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
            # Calculate the HPS
            with torch.cuda.amp.autocast():
                outputs = model(image_tensor, text)
-                image_features, text_features = outputs["image_features"], outputs["text_features"]
+                image_features, text_features = (
+                    outputs["image_features"],
+                    outputs["text_features"],
+                )
            logits_per_image = image_features @ text_features.T
            hps_score = torch.diagonal(logits_per_image).cpu().numpy()
-
+
        return {"hps": float(hps_score[0])}
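One wrinkle the commit leaves in place: `compute_score` still wraps the forward pass in `torch.cuda.amp.autocast()`, which is CUDA-specific, even though `self.device` may now be mps or cpu. A hedged sketch of the device-generic alternative; `torch.autocast` accepts a `device_type` string, and mps support depends on the torch build, so treat the mps branch as an assumption:

```python
import torch

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# torch.autocast is the device-agnostic context manager; the CUDA-only
# torch.cuda.amp.autocast() has no effect for tensors on other devices.
with torch.autocast(device_type=device.type):
    x = torch.ones(2, 2) @ torch.ones(2, 2)  # stand-in for the model forward
```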
benchmark/metrics/image_reward.py CHANGED
@@ -3,17 +3,26 @@ import tempfile
 from typing import Dict
 
 import ImageReward as RM
+import torch
 from PIL import Image
 
 
 class ImageRewardMetric:
     def __init__(self):
-        self.model = RM.load("ImageReward-v1.0")
-
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
+
+        self.model = RM.load("ImageReward-v1.0", device=str(self.device))
+
     @property
     def name(self) -> str:
         return "image_reward"
-
+
     def compute_score(
         self,
         image: Image.Image,
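The device is forwarded as a string here, presumably because `RM.load` expects a device name rather than a `torch.device`. A hedged usage sketch; the `score(prompt, image_paths)` call follows the ImageReward project README and should be treated as an assumption if the pinned version differs, and the image path is hypothetical:

```python
# Hedged sketch: downloads the ImageReward-v1.0 checkpoint on first use.
import ImageReward as RM

model = RM.load("ImageReward-v1.0", device="cpu")
print(model.score("a photo of a red cube", ["images/example/0.png"]))  # hypothetical path
```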
benchmark/metrics/vqa.py CHANGED
@@ -2,15 +2,26 @@ from pathlib import Path
 from typing import Dict
 
 import t2v_metrics
+import torch
+
 
 class VQAMetric:
     def __init__(self):
-        self.metric = t2v_metrics.VQAScore(model="clip-flant5-xxl")
-
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
+        self.metric = t2v_metrics.VQAScore(
+            model="clip-flant5-xxl", device=str(self.device)
+        )
+
     @property
     def name(self) -> str:
         return "vqa_score"
-
+
     def compute_score(
         self,
         image_path: Path,
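As with ImageReward, the device is passed as a string. A hedged sketch of scoring one image directly with t2v_metrics, following its README; the `images=`/`texts=` keywords and the file path are assumptions, and clip-flant5-xxl is a multi-billion-parameter model, so a GPU is effectively required:

```python
# Hedged sketch per the t2v_metrics README; verify kwargs against the pinned version.
import t2v_metrics

scorer = t2v_metrics.VQAScore(model="clip-flant5-xxl", device="cuda")
scores = scorer(images=["images/example/0.png"], texts=["a photo of a red cube"])
print(scores)
```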
environment.yml CHANGED
@@ -12,7 +12,7 @@ dependencies:
   - tqdm
   - pip
   - pip:
-      - datasets>=3.5.0
+      - datasets==3.6.0
       - fal-client>=0.5.9
       - hpsv2>=1.2.0
       - huggingface-hub>=0.30.2
nils_installs.txt DELETED
@@ -1,177 +0,0 @@
-accelerate==1.7.0
-aiohappyeyeballs==2.6.1
-aiohttp==3.12.12
-aiosignal==1.3.2
-annotated-types==0.7.0
-antlr4-python3-runtime==4.9.3
-anyio==4.9.0
-args==0.1.0
-asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
-attrs==25.3.0
-beautifulsoup4==4.13.4
-boto3==1.38.33
-botocore==1.38.33
-braceexpand==0.1.7
-Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
-certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
-cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725560558132/work
-charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
-click==8.1.8
-clint==0.5.1
-clip==0.2.0
-colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
-comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
-contourpy==1.3.2
-cycler==0.12.1
-datasets==3.6.0
-debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321241074/work
-decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
-diffusers==0.31.0
-dill==0.3.8
-distro==1.9.0
-einops==0.8.1
-eval_type_backport==0.2.2
-exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
-executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
-fairscale==0.4.13
-fal_client==0.7.0
-filelock==3.18.0
-fire==0.4.0
-fonttools==4.58.2
-frozenlist==1.7.0
-fsspec==2025.3.0
-ftfy==6.3.1
-gdown==5.2.0
-h11==0.16.0
-h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
-hf-xet==1.1.3
-hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
-hpsv2==1.2.0
-httpcore==1.0.9
-httpx==0.28.1
-httpx-sse==0.4.0
-huggingface-hub==0.32.5
-hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
-idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
-image-reward==1.5
-importlib_metadata==8.7.0
-iniconfig==2.1.0
-iopath==0.1.10
-ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
-ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748713870/work
-ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
-jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
-Jinja2==3.1.6
-jiter==0.10.0
-jmespath==1.0.1
-joblib==1.5.1
-jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
-jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
-kiwisolver==1.4.8
-lightning-utilities==0.14.3
-markdown-it-py==3.0.0
-MarkupSafe==3.0.2
-matplotlib==3.10.3
-matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
-mdurl==0.1.2
-mpmath==1.3.0
-multidict==6.4.4
-multiprocess==0.70.16
-nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
-networkx==3.5
-numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1749430504934/work/dist/numpy-2.3.0-cp312-cp312-linux_x86_64.whl#sha256=3c4437a0cbe50dbae872ad4cd8dc5316009165bce459c4ffe2c46cd30aba13d4
-nvidia-cublas-cu12==12.6.4.1
-nvidia-cuda-cupti-cu12==12.6.80
-nvidia-cuda-nvrtc-cu12==12.6.77
-nvidia-cuda-runtime-cu12==12.6.77
-nvidia-cudnn-cu12==9.5.1.17
-nvidia-cufft-cu12==11.3.0.4
-nvidia-cufile-cu12==1.11.1.6
-nvidia-curand-cu12==10.3.7.77
-nvidia-cusolver-cu12==11.7.1.2
-nvidia-cusparse-cu12==12.5.4.2
-nvidia-cusparselt-cu12==0.6.3
-nvidia-nccl-cu12==2.26.2
-nvidia-nvjitlink-cu12==12.6.85
-nvidia-nvtx-cu12==12.6.77
-omegaconf==2.3.0
-open_clip_torch==2.32.0
-openai==1.85.0
-opencv-python==4.11.0
-opencv-python-headless==4.11.0
-packaging==25.0
-pandas==2.3.0
-parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
-pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
-pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
-pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1746646208260/work
-piq==0.8.0
-platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
-pluggy==1.6.0
-portalocker==3.1.1
-prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
-propcache==0.3.2
-protobuf==3.20.3
-psutil==7.0.0
-ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
-pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
-pyarrow==20.0.0
-pycocoevalcap==1.2
-pycocotools==2.0.10
-pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
-pydantic==2.11.5
-pydantic_core==2.33.2
-Pygments==2.19.1
-pyparsing==3.2.3
-PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
-pytest==7.2.0
-pytest-split==0.8.0
-python-dateutil==2.9.0.post0
-python-dotenv @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dotenv_1742948348/work
-pytz==2025.2
-PyYAML==6.0.2
-pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245863/work
-regex==2024.11.6
-replicate==1.0.7
-requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1749498106507/work
-rich==14.0.0
-s3transfer==0.13.0
-safetensors==0.5.3
-scikit-learn==1.7.0
-scipy==1.15.3
-sentencepiece==0.2.0
-setuptools==80.9.0
-shellingham==1.5.4
-six==1.17.0
-sniffio==1.3.1
-soupsieve==2.7
-stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
-sympy==1.14.0
-t2v_metrics==1.2
-tabulate==0.9.0
-termcolor==3.1.0
-threadpoolctl==3.6.0
-tiktoken==0.9.0
-timm==0.6.13
-together==1.5.11
-tokenizers==0.15.2
-torch==2.7.1
-torchmetrics==1.7.2
-torchvision==0.22.1
-tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003300911/work
-tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
-traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
-transformers==4.36.1
-triton==3.3.1
-typer==0.15.4
-typing-inspection==0.4.1
-typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1748959427/work
-tzdata==2025.2
-urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
-wcwidth==0.2.13
-webdataset==0.2.111
-wheel==0.45.1
-xxhash==3.5.0
-yarl==1.20.1
-zipp==3.23.0
-zstandard==0.23.0
sample.py CHANGED
@@ -12,19 +12,19 @@ from benchmark import create_benchmark
 def generate_images(api_type: str, benchmarks: List[str]):
     images_dir = Path("images")
     api = create_api(api_type)
-
+
     api_dir = images_dir / api_type
     api_dir.mkdir(parents=True, exist_ok=True)
-
+
     for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
         print(f"\nProcessing benchmark: {benchmark_type}")
-
+
         benchmark = create_benchmark(benchmark_type)
-
+
         if benchmark_type == "geneval":
             benchmark_dir = api_dir / benchmark_type
             benchmark_dir.mkdir(parents=True, exist_ok=True)
-
+
             metadata_file = benchmark_dir / "metadata.jsonl"
             existing_metadata = {}
             if metadata_file.exists():
@@ -32,27 +32,29 @@ def generate_images(api_type: str, benchmarks: List[str]):
                 for line in f:
                     entry = json.loads(line)
                     existing_metadata[entry["filepath"]] = entry
-
-            for metadata, folder_name in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
+
+            for metadata, folder_name in tqdm(
+                benchmark, desc=f"Generating images for {benchmark_type}", leave=False
+            ):
                 sample_path = benchmark_dir / folder_name
                 samples_path = sample_path / "samples"
                 samples_path.mkdir(parents=True, exist_ok=True)
                 image_path = samples_path / "0000.png"
-
+
                 if image_path.exists():
                     continue
-
+
                 try:
                     inference_time = api.generate_image(metadata["prompt"], image_path)
-
+
                     metadata_entry = {
                         "filepath": str(image_path),
                         "prompt": metadata["prompt"],
-                        "inference_time": inference_time
+                        "inference_time": inference_time,
                     }
-
+
                     existing_metadata[str(image_path)] = metadata_entry
-
+
                 except Exception as e:
                     print(f"\nError generating image for prompt: {metadata['prompt']}")
                     print(f"Error: {str(e)}")
@@ -60,7 +62,7 @@ def generate_images(api_type: str, benchmarks: List[str]):
         else:
             benchmark_dir = api_dir / benchmark_type
             benchmark_dir.mkdir(parents=True, exist_ok=True)
-
+
             metadata_file = benchmark_dir / "metadata.jsonl"
             existing_metadata = {}
             if metadata_file.exists():
@@ -68,41 +70,45 @@ def generate_images(api_type: str, benchmarks: List[str]):
                 for line in f:
                     entry = json.loads(line)
                     existing_metadata[entry["filepath"]] = entry
-
-            for prompt, image_path in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
+
+            for prompt, image_path in tqdm(
+                benchmark, desc=f"Generating images for {benchmark_type}", leave=False
+            ):
                 full_image_path = benchmark_dir / image_path
-
+
                 if full_image_path.exists():
                     continue
-
+
                 try:
                     inference_time = api.generate_image(prompt, full_image_path)
-
+
                     metadata_entry = {
                         "filepath": str(image_path),
                        "prompt": prompt,
-                        "inference_time": inference_time
+                        "inference_time": inference_time,
                    }
-
+
                    existing_metadata[str(image_path)] = metadata_entry
-
+
                except Exception as e:
                    print(f"\nError generating image for prompt: {prompt}")
                    print(f"Error: {str(e)}")
                    continue
-
+
        with open(metadata_file, "w") as f:
            for entry in existing_metadata.values():
                f.write(json.dumps(entry) + "\n")
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Generate images for specified benchmarks using a given API")
+    parser = argparse.ArgumentParser(
+        description="Generate images for specified benchmarks using a given API"
+    )
     parser.add_argument("api_type", help="Type of API to use for image generation")
     parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
-
+
     args = parser.parse_args()
-
+
     generate_images(args.api_type, args.benchmarks)
 
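The metadata bookkeeping in sample.py is the load-bearing pattern here: entries are keyed by filepath so interrupted runs resume without regenerating images, and the whole JSONL file is rewritten at the end of each benchmark. A minimal standalone sketch of that read-merge-rewrite cycle, with hypothetical paths and values:

```python
# Sketch of the metadata.jsonl pattern from sample.py; paths are hypothetical.
import json
from pathlib import Path

metadata_file = Path("images/example_api/draw_bench/metadata.jsonl")
metadata_file.parent.mkdir(parents=True, exist_ok=True)

# Read any existing entries, keyed by filepath (last write wins).
existing = {}
if metadata_file.exists():
    with open(metadata_file) as f:
        for line in f:
            entry = json.loads(line)
            existing[entry["filepath"]] = entry

# Merge in a new record, then rewrite the whole file.
existing["0.png"] = {"filepath": "0.png", "prompt": "a red cube", "inference_time": 1.2}
with open(metadata_file, "w") as f:
    for entry in existing.values():
        f.write(json.dumps(entry) + "\n")
```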