davidberenstein1957 committed
Commit 8c2887a (unverified)
2 Parent(s): f5f9b38 34046e2

Merge pull request #2 from PrunaAI/feat/david

.gitignore CHANGED
@@ -175,4 +175,5 @@ cython_debug/
175
 
176
  evaluation_results/
177
  images/
178
- hf_cache/
 
 
175
 
176
  evaluation_results/
177
  images/
178
+ hf_cache/
179
+ *.lock
README.md CHANGED
@@ -1,10 +1,28 @@
1
  # InferBench
2
  Evaluate the quality and efficiency of image generation APIs.
3
 
4
- Install dependencies with conda like that:
 
 
5
 
 
 
6
  conda env create -f environment.yml
7
 
 
8
 
9
  Create .env file with all the credentials you will need.
10
 
 
1
  # InferBench
2
  Evaluate the quality and efficiency of image generation APIs.
3
 
4
+ ## Installation
5
+
6
+ ### Install dependencies
7
 
8
+ Install dependencies with conda as follows:
9
+ ```
10
  conda env create -f environment.yml
11
+ ```
12
+
13
+ ### Set up with uv
14
+
15
+ Create a virtual environment and sync the dependencies with uv:
16
+
17
+ ```
18
+ uv venv --python 3.12
19
+ ```
20
+
21
+ ```
22
+ uv sync --all-groups
23
+ ```
24
 
25
+ ## Usage
26
 
27
  Create .env file with all the credentials you will need.
28
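For reference, a minimal `.env` sketch for the usage step above. Only `REPLICATE_API_TOKEN` is confirmed by the code in this commit; treat the other variable names as assumptions to verify against each provider's SDK:

```
# Confirmed by this repository's code (Replicate, Pruna and baseline APIs):
REPLICATE_API_TOKEN=your-replicate-token
# Assumed names for the remaining providers; verify against their SDKs:
FAL_KEY=your-fal-key
TOGETHER_API_KEY=your-together-key
FIREWORKS_API_KEY=your-fireworks-key
AWS_ACCESS_KEY_ID=your-aws-access-key-id
AWS_SECRET_ACCESS_KEY=your-aws-secret-access-key
```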
 
api/__init__.py CHANGED
@@ -1,31 +1,32 @@
1
  from typing import Type
2
 
 
3
  from api.baseline import BaselineAPI
 
4
  from api.fireworks import FireworksAPI
5
  from api.flux import FluxAPI
6
  from api.pruna import PrunaAPI
7
  from api.pruna_dev import PrunaDevAPI
8
  from api.replicate import ReplicateAPI
9
  from api.together import TogetherAPI
10
- from api.fal import FalAPI
11
- from api.aws import AWSBedrockAPI
12
 
13
  __all__ = [
14
- 'create_api',
15
- 'FluxAPI',
16
- 'BaselineAPI',
17
- 'FireworksAPI',
18
- 'PrunaAPI',
19
- 'ReplicateAPI',
20
- 'TogetherAPI',
21
- 'FalAPI',
22
- 'PrunaDevAPI',
23
  ]
24
 
 
25
  def create_api(api_type: str) -> FluxAPI:
26
  """
27
  Factory function to create API instances.
28
-
29
  Args:
30
  api_type (str): The type of API to create. Must be one of:
31
  - "baseline"
@@ -35,10 +36,10 @@ def create_api(api_type: str) -> FluxAPI:
35
  - "together"
36
  - "fal"
37
  - "aws"
38
-
39
  Returns:
40
  FluxAPI: An instance of the requested API implementation
41
-
42
  Raises:
43
  ValueError: If an invalid API type is provided
44
  """
@@ -47,7 +48,7 @@ def create_api(api_type: str) -> FluxAPI:
47
  if api_type.startswith("pruna_"):
48
  speed_mode = api_type[6:] # Remove "pruna_" prefix
49
  return PrunaAPI(speed_mode)
50
-
51
  api_map: dict[str, Type[FluxAPI]] = {
52
  "baseline": BaselineAPI,
53
  "fireworks": FireworksAPI,
@@ -56,8 +57,10 @@ def create_api(api_type: str) -> FluxAPI:
56
  "fal": FalAPI,
57
  "aws": AWSBedrockAPI,
58
  }
59
-
60
  if api_type not in api_map:
61
- raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")
62
-
 
 
63
  return api_map[api_type]()
 
1
  from typing import Type
2
 
3
+ from api.aws import AWSBedrockAPI
4
  from api.baseline import BaselineAPI
5
+ from api.fal import FalAPI
6
  from api.fireworks import FireworksAPI
7
  from api.flux import FluxAPI
8
  from api.pruna import PrunaAPI
9
  from api.pruna_dev import PrunaDevAPI
10
  from api.replicate import ReplicateAPI
11
  from api.together import TogetherAPI
 
 
12
 
13
  __all__ = [
14
+ "create_api",
15
+ "FluxAPI",
16
+ "BaselineAPI",
17
+ "FireworksAPI",
18
+ "PrunaAPI",
19
+ "ReplicateAPI",
20
+ "TogetherAPI",
21
+ "FalAPI",
22
+ "PrunaDevAPI",
23
  ]
24
 
25
+
26
  def create_api(api_type: str) -> FluxAPI:
27
  """
28
  Factory function to create API instances.
29
+
30
  Args:
31
  api_type (str): The type of API to create. Must be one of:
32
  - "baseline"
 
36
  - "together"
37
  - "fal"
38
  - "aws"
39
+
40
  Returns:
41
  FluxAPI: An instance of the requested API implementation
42
+
43
  Raises:
44
  ValueError: If an invalid API type is provided
45
  """
 
48
  if api_type.startswith("pruna_"):
49
  speed_mode = api_type[6:] # Remove "pruna_" prefix
50
  return PrunaAPI(speed_mode)
51
+
52
  api_map: dict[str, Type[FluxAPI]] = {
53
  "baseline": BaselineAPI,
54
  "fireworks": FireworksAPI,
 
57
  "fal": FalAPI,
58
  "aws": AWSBedrockAPI,
59
  }
60
+
61
  if api_type not in api_map:
62
+ raise ValueError(
63
+ f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
64
+ )
65
+
66
  return api_map[api_type]()
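A minimal usage sketch of the `create_api` factory above; the prompt, output path, and the `pruna_` speed-mode string are illustrative assumptions, and the credentials each implementation needs are expected in `.env`:

```
from pathlib import Path

from api import create_api

# Create one of the hosted API wrappers and time a single generation.
api = create_api("baseline")            # any key from api_map, e.g. "fal" or "aws"
seconds = api.generate_image(
    "a watercolor painting of a lighthouse",   # illustrative prompt
    Path("images/baseline/demo/0.png"),        # illustrative output path
)
print(f"{api.name}: generated in {seconds:.2f}s")

# Anything prefixed with "pruna_" is routed to PrunaAPI with the rest as speed mode.
fast_api = create_api("pruna_fast")     # the mode string here is an assumption
```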
api/aws.py CHANGED
@@ -1,9 +1,8 @@
1
- import os
2
- import time
3
  import base64
4
  import json
 
 
5
  from pathlib import Path
6
- from typing import Any
7
 
8
  import boto3
9
  from dotenv import load_dotenv
@@ -45,23 +44,20 @@ class AWSBedrockAPI(FluxAPI):
45
  try:
46
  # Convert request to JSON and invoke the model
47
  request = json.dumps(native_request)
48
- response = self._client.invoke_model(
49
- modelId=self._model_id,
50
- body=request
51
- )
52
-
53
  # Process the response
54
  model_response = json.loads(response["body"].read())
55
  if not model_response.get("images"):
56
  raise Exception("No images returned from AWS Bedrock API")
57
-
58
  # Save the image
59
  base64_image_data = model_response["images"][0]
60
  self._save_image_from_base64(base64_image_data, save_path)
61
-
62
  except Exception as e:
63
  raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
64
-
65
  end_time = time.time()
66
  return end_time - start_time
67
 
@@ -70,4 +66,4 @@ class AWSBedrockAPI(FluxAPI):
70
  save_path.parent.mkdir(parents=True, exist_ok=True)
71
  image_data = base64.b64decode(base64_data)
72
  with open(save_path, "wb") as f:
73
- f.write(image_data)
 
 
 
1
  import base64
2
  import json
3
+ import os
4
+ import time
5
  from pathlib import Path
 
6
 
7
  import boto3
8
  from dotenv import load_dotenv
 
44
  try:
45
  # Convert request to JSON and invoke the model
46
  request = json.dumps(native_request)
47
+ response = self._client.invoke_model(modelId=self._model_id, body=request)
48
+
 
 
 
49
  # Process the response
50
  model_response = json.loads(response["body"].read())
51
  if not model_response.get("images"):
52
  raise Exception("No images returned from AWS Bedrock API")
53
+
54
  # Save the image
55
  base64_image_data = model_response["images"][0]
56
  self._save_image_from_base64(base64_image_data, save_path)
57
+
58
  except Exception as e:
59
  raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
60
+
61
  end_time = time.time()
62
  return end_time - start_time
63
 
 
66
  save_path.parent.mkdir(parents=True, exist_ok=True)
67
  image_data = base64.b64decode(base64_data)
68
  with open(save_path, "wb") as f:
69
+ f.write(image_data)
api/baseline.py CHANGED
@@ -12,6 +12,7 @@ class BaselineAPI(FluxAPI):
12
  """
13
  As our baseline, we use the Replicate API with go_fast=False.
14
  """
 
15
  def __init__(self):
16
  load_dotenv()
17
  self._api_key = os.getenv("REPLICATE_API_TOKEN")
@@ -24,6 +25,7 @@ class BaselineAPI(FluxAPI):
24
 
25
  def generate_image(self, prompt: str, save_path: Path) -> float:
26
  import replicate
 
27
  start_time = time.time()
28
  result = replicate.run(
29
  "black-forest-labs/flux-dev",
@@ -39,15 +41,15 @@ class BaselineAPI(FluxAPI):
39
  },
40
  )
41
  end_time = time.time()
42
-
43
  if result and len(result) > 0:
44
  self._save_image_from_result(result[0], save_path)
45
  else:
46
  raise Exception("No result returned from Replicate API")
47
-
48
  return end_time - start_time
49
 
50
  def _save_image_from_result(self, result: Any, save_path: Path):
51
  save_path.parent.mkdir(parents=True, exist_ok=True)
52
  with open(save_path, "wb") as f:
53
- f.write(result.read())
 
12
  """
13
  As our baseline, we use the Replicate API with go_fast=False.
14
  """
15
+
16
  def __init__(self):
17
  load_dotenv()
18
  self._api_key = os.getenv("REPLICATE_API_TOKEN")
 
25
 
26
  def generate_image(self, prompt: str, save_path: Path) -> float:
27
  import replicate
28
+
29
  start_time = time.time()
30
  result = replicate.run(
31
  "black-forest-labs/flux-dev",
 
41
  },
42
  )
43
  end_time = time.time()
44
+
45
  if result and len(result) > 0:
46
  self._save_image_from_result(result[0], save_path)
47
  else:
48
  raise Exception("No result returned from Replicate API")
49
+
50
  return end_time - start_time
51
 
52
  def _save_image_from_result(self, result: Any, save_path: Path):
53
  save_path.parent.mkdir(parents=True, exist_ok=True)
54
  with open(save_path, "wb") as f:
55
+ f.write(result.read())
api/fal.py CHANGED
@@ -30,10 +30,10 @@ class FalAPI(FluxAPI):
30
  },
31
  )
32
  end_time = time.time()
33
-
34
  url = result["images"][0]["url"]
35
  self._save_image_from_url(url, save_path)
36
-
37
  return end_time - start_time
38
 
39
  def _save_image_from_url(self, url: str, save_path: Path):
 
30
  },
31
  )
32
  end_time = time.time()
33
+
34
  url = result["images"][0]["url"]
35
  self._save_image_from_url(url, save_path)
36
+
37
  return end_time - start_time
38
 
39
  def _save_image_from_url(self, url: str, save_path: Path):
api/fireworks.py CHANGED
@@ -23,7 +23,7 @@ class FireworksAPI(FluxAPI):
23
 
24
  def generate_image(self, prompt: str, save_path: Path) -> float:
25
  start_time = time.time()
26
-
27
  headers = {
28
  "Content-Type": "application/json",
29
  "Accept": "image/jpeg",
@@ -39,12 +39,12 @@ class FireworksAPI(FluxAPI):
39
  result = requests.post(self._url, headers=headers, json=data)
40
 
41
  end_time = time.time()
42
-
43
  if result.status_code == 200:
44
  self._save_image_from_result(result, save_path)
45
  else:
46
  raise Exception(f"Error: {result.status_code} {result.text}")
47
-
48
  return end_time - start_time
49
 
50
  def _save_image_from_result(self, result: Any, save_path: Path):
 
23
 
24
  def generate_image(self, prompt: str, save_path: Path) -> float:
25
  start_time = time.time()
26
+
27
  headers = {
28
  "Content-Type": "application/json",
29
  "Accept": "image/jpeg",
 
39
  result = requests.post(self._url, headers=headers, json=data)
40
 
41
  end_time = time.time()
42
+
43
  if result.status_code == 200:
44
  self._save_image_from_result(result, save_path)
45
  else:
46
  raise Exception(f"Error: {result.status_code} {result.text}")
47
+
48
  return end_time - start_time
49
 
50
  def _save_image_from_result(self, result: Any, save_path: Path):
api/flux.py CHANGED
@@ -14,7 +14,7 @@ class FluxAPI(ABC):
14
  def name(self) -> str:
15
  """
16
  The name of the API implementation.
17
-
18
  Returns:
19
  str: The name of the specific API implementation
20
  """
@@ -24,11 +24,11 @@ class FluxAPI(ABC):
24
  def generate_image(self, prompt: str, save_path: Path) -> float:
25
  """
26
  Generate an image based on the prompt and save it to the specified path.
27
-
28
  Args:
29
  prompt (str): The text prompt to generate the image from
30
  save_path (Path): The path where the generated image should be saved
31
-
32
  Returns:
33
  float: The time taken for the API call in seconds
34
  """
 
14
  def name(self) -> str:
15
  """
16
  The name of the API implementation.
17
+
18
  Returns:
19
  str: The name of the specific API implementation
20
  """
 
24
  def generate_image(self, prompt: str, save_path: Path) -> float:
25
  """
26
  Generate an image based on the prompt and save it to the specified path.
27
+
28
  Args:
29
  prompt (str): The text prompt to generate the image from
30
  save_path (Path): The path where the generated image should be saved
31
+
32
  Returns:
33
  float: The time taken for the API call in seconds
34
  """
api/pruna.py CHANGED
@@ -3,8 +3,8 @@ import time
3
  from pathlib import Path
4
  from typing import Any
5
 
6
- from dotenv import load_dotenv
7
  import replicate
 
8
 
9
  from api.flux import FluxAPI
10
 
@@ -12,7 +12,9 @@ from api.flux import FluxAPI
12
  class PrunaAPI(FluxAPI):
13
  def __init__(self, speed_mode: str):
14
  self._speed_mode = speed_mode
15
- self._speed_mode_name = speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
 
 
16
  load_dotenv()
17
  self._api_key = os.getenv("REPLICATE_API_TOKEN")
18
  if not self._api_key:
@@ -38,7 +40,7 @@ class PrunaAPI(FluxAPI):
38
  },
39
  )
40
  end_time = time.time()
41
-
42
  if result:
43
  self._save_image_from_result(result, save_path)
44
  else:
 
3
  from pathlib import Path
4
  from typing import Any
5
 
 
6
  import replicate
7
+ from dotenv import load_dotenv
8
 
9
  from api.flux import FluxAPI
10
 
 
12
  class PrunaAPI(FluxAPI):
13
  def __init__(self, speed_mode: str):
14
  self._speed_mode = speed_mode
15
+ self._speed_mode_name = (
16
+ speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
17
+ )
18
  load_dotenv()
19
  self._api_key = os.getenv("REPLICATE_API_TOKEN")
20
  if not self._api_key:
 
40
  },
41
  )
42
  end_time = time.time()
43
+
44
  if result:
45
  self._save_image_from_result(result, save_path)
46
  else:
api/pruna_dev.py CHANGED
@@ -3,8 +3,8 @@ import time
3
  from pathlib import Path
4
  from typing import Any
5
 
6
- from dotenv import load_dotenv
7
  import replicate
 
8
 
9
  from api.flux import FluxAPI
10
 
@@ -36,7 +36,7 @@ class PrunaDevAPI(FluxAPI):
36
  },
37
  )
38
  end_time = time.time()
39
-
40
  if result:
41
  self._save_image_from_result(result, save_path)
42
  else:
@@ -46,4 +46,4 @@ class PrunaDevAPI(FluxAPI):
46
  def _save_image_from_result(self, result: Any, save_path: Path):
47
  save_path.parent.mkdir(parents=True, exist_ok=True)
48
  with open(save_path, "wb") as f:
49
- f.write(result.read())
 
3
  from pathlib import Path
4
  from typing import Any
5
 
 
6
  import replicate
7
+ from dotenv import load_dotenv
8
 
9
  from api.flux import FluxAPI
10
 
 
36
  },
37
  )
38
  end_time = time.time()
39
+
40
  if result:
41
  self._save_image_from_result(result, save_path)
42
  else:
 
46
  def _save_image_from_result(self, result: Any, save_path: Path):
47
  save_path.parent.mkdir(parents=True, exist_ok=True)
48
  with open(save_path, "wb") as f:
49
+ f.write(result.read())
api/replicate.py CHANGED
@@ -3,8 +3,8 @@ import time
3
  from pathlib import Path
4
  from typing import Any
5
 
6
- from dotenv import load_dotenv
7
  import replicate
 
8
 
9
  from api.flux import FluxAPI
10
 
 
3
  from pathlib import Path
4
  from typing import Any
5
 
 
6
  import replicate
7
+ from dotenv import load_dotenv
8
 
9
  from api.flux import FluxAPI
10
 
api/replicate_wan.py ADDED
@@ -0,0 +1,48 @@

1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import replicate
7
+ from dotenv import load_dotenv
8
+
9
+ from api.flux import FluxAPI
10
+
11
+
12
+ class ReplicateAPI(FluxAPI):
13
+ def __init__(self):
14
+ load_dotenv()
15
+ self._api_key = os.getenv("REPLICATE_API_TOKEN")
16
+ if not self._api_key:
17
+ raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
18
+
19
+ @property
20
+ def name(self) -> str:
21
+ return "replicate_go_fast"
22
+
23
+ def generate_image(self, prompt: str, save_path: Path) -> float:
24
+ start_time = time.time()
25
+ result = replicate.run(
26
+ "black-forest-labs/flux-dev",
27
+ input={
28
+ "seed": 0,
29
+ "prompt": prompt,
30
+ "go_fast": True,
31
+ "guidance": 3.5,
32
+ "num_outputs": 1,
33
+ "aspect_ratio": "1:1",
34
+ "output_format": "png",
35
+ "num_inference_steps": 28,
36
+ },
37
+ )
38
+ end_time = time.time()
39
+ if result and len(result) > 0:
40
+ self._save_image_from_result(result[0], save_path)
41
+ else:
42
+ raise Exception("No result returned from Replicate API")
43
+ return end_time - start_time
44
+
45
+ def _save_image_from_result(self, result: Any, save_path: Path):
46
+ save_path.parent.mkdir(parents=True, exist_ok=True)
47
+ with open(save_path, "wb") as f:
48
+ f.write(result.read())
api/together.py CHANGED
@@ -33,10 +33,10 @@ class TogetherAPI(FluxAPI):
33
  response_format="b64_json",
34
  )
35
  end_time = time.time()
36
- if result and hasattr(result, 'data') and len(result.data) > 0:
37
  self._save_image_from_result(result, save_path)
38
  else:
39
- raise Exception("No result returned from Together API")
40
  return end_time - start_time
41
 
42
  def _save_image_from_result(self, result: Any, save_path: Path):
 
33
  response_format="b64_json",
34
  )
35
  end_time = time.time()
36
+ if result and hasattr(result, "data") and len(result.data) > 0:
37
  self._save_image_from_result(result, save_path)
38
  else:
39
+ raise Exception("No result returned from Together API")
40
  return end_time - start_time
41
 
42
  def _save_image_from_result(self, result: Any, save_path: Path):
benchmark/__init__.py CHANGED
@@ -7,10 +7,14 @@ from benchmark.hps import HPSPrompts
7
  from benchmark.parti import PartiPrompts
8
 
9
 
10
- def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts]:
 
 
 
 
11
  """
12
  Factory function to create benchmark instances.
13
-
14
  Args:
15
  benchmark_type (str): The type of benchmark to create. Must be one of:
16
  - "draw_bench"
@@ -18,10 +22,10 @@ def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchP
18
  - "geneval"
19
  - "hps"
20
  - "parti"
21
-
22
  Returns:
23
  An instance of the requested benchmark implementation
24
-
25
  Raises:
26
  ValueError: If an invalid benchmark type is provided
27
  """
@@ -32,8 +36,10 @@ def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchP
32
  "hps": HPSPrompts,
33
  "parti": PartiPrompts,
34
  }
35
-
36
  if benchmark_type not in benchmark_map:
37
- raise ValueError(f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}")
38
-
 
 
39
  return benchmark_map[benchmark_type]()
 
7
  from benchmark.parti import PartiPrompts
8
 
9
 
10
+ def create_benchmark(
11
+ benchmark_type: str,
12
+ ) -> Type[
13
+ DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts
14
+ ]:
15
  """
16
  Factory function to create benchmark instances.
17
+
18
  Args:
19
  benchmark_type (str): The type of benchmark to create. Must be one of:
20
  - "draw_bench"
 
22
  - "geneval"
23
  - "hps"
24
  - "parti"
25
+
26
  Returns:
27
  An instance of the requested benchmark implementation
28
+
29
  Raises:
30
  ValueError: If an invalid benchmark type is provided
31
  """
 
36
  "hps": HPSPrompts,
37
  "parti": PartiPrompts,
38
  }
39
+
40
  if benchmark_type not in benchmark_map:
41
+ raise ValueError(
42
+ f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}"
43
+ )
44
+
45
  return benchmark_map[benchmark_type]()
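A minimal sketch of how the benchmark factory above is consumed, matching the iteration pattern in sample.py; note that "geneval" is the exception and yields a metadata dict plus a folder name instead of a plain prompt:

```
from benchmark import create_benchmark

# Prompt benchmarks are iterable and yield (prompt, relative image path).
bench = create_benchmark("parti")        # any key from benchmark_map
print(bench.name, bench.size)

for i, (prompt, rel_path) in enumerate(bench):
    print(rel_path, prompt)
    if i == 2:                           # just peek at the first few prompts
        break
```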
benchmark/draw_bench.py CHANGED
@@ -7,15 +7,15 @@ from datasets import load_dataset
7
  class DrawBenchPrompts:
8
  def __init__(self):
9
  self.dataset = load_dataset("shunk031/DrawBench")["test"]
10
-
11
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
12
  for i, row in enumerate(self.dataset):
13
  yield row["prompts"], Path(f"{i}.png")
14
-
15
  @property
16
  def name(self) -> str:
17
  return "draw_bench"
18
-
19
  @property
20
  def size(self) -> int:
21
  return len(self.dataset)
 
7
  class DrawBenchPrompts:
8
  def __init__(self):
9
  self.dataset = load_dataset("shunk031/DrawBench")["test"]
10
+
11
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
12
  for i, row in enumerate(self.dataset):
13
  yield row["prompts"], Path(f"{i}.png")
14
+
15
  @property
16
  def name(self) -> str:
17
  return "draw_bench"
18
+
19
  @property
20
  def size(self) -> int:
21
  return len(self.dataset)
benchmark/genai_bench.py CHANGED
@@ -8,8 +8,8 @@ class GenAIBenchPrompts:
8
  def __init__(self):
9
  super().__init__()
10
  self._download_genai_bench_files()
11
- prompts_path = Path('downloads/genai_bench/prompts.txt')
12
- with open(prompts_path, 'r') as f:
13
  self.prompts = [line.strip() for line in f if line.strip()]
14
 
15
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
@@ -17,13 +17,13 @@ class GenAIBenchPrompts:
17
  yield prompt, Path(f"{i}.png")
18
 
19
  def _download_genai_bench_files(self) -> None:
20
- folder_name = Path('downloads/genai_bench')
21
  folder_name.mkdir(parents=True, exist_ok=True)
22
  prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
23
  prompts_path = folder_name / "prompts.txt"
24
  if not prompts_path.exists():
25
  response = requests.get(prompts_url)
26
- with open(prompts_path, 'w') as f:
27
  f.write(response.text)
28
 
29
  @property
 
8
  def __init__(self):
9
  super().__init__()
10
  self._download_genai_bench_files()
11
+ prompts_path = Path("downloads/genai_bench/prompts.txt")
12
+ with open(prompts_path, "r") as f:
13
  self.prompts = [line.strip() for line in f if line.strip()]
14
 
15
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
 
17
  yield prompt, Path(f"{i}.png")
18
 
19
  def _download_genai_bench_files(self) -> None:
20
+ folder_name = Path("downloads/genai_bench")
21
  folder_name.mkdir(parents=True, exist_ok=True)
22
  prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
23
  prompts_path = folder_name / "prompts.txt"
24
  if not prompts_path.exists():
25
  response = requests.get(prompts_url)
26
+ with open(prompts_path, "w") as f:
27
  f.write(response.text)
28
 
29
  @property
benchmark/geneval.py CHANGED
@@ -9,32 +9,32 @@ class GenEvalPrompts:
9
  def __init__(self):
10
  super().__init__()
11
  self._download_geneval_file()
12
- metadata_path = Path('downloads/geneval/evaluation_metadata.jsonl')
13
  self.entries: List[Dict[str, Any]] = []
14
- with open(metadata_path, 'r') as f:
15
  for line in f:
16
  if line.strip():
17
  self.entries.append(json.loads(line))
18
-
19
  def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
20
  for i, entry in enumerate(self.entries):
21
  folder_name = f"{i:05d}"
22
  yield entry, folder_name
23
 
24
  def _download_geneval_file(self) -> None:
25
- folder_name = Path('downloads/geneval')
26
  folder_name.mkdir(parents=True, exist_ok=True)
27
  metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
28
  metadata_path = folder_name / "evaluation_metadata.jsonl"
29
  if not metadata_path.exists():
30
  response = requests.get(metadata_url)
31
- with open(metadata_path, 'w') as f:
32
  f.write(response.text)
33
-
34
  @property
35
  def name(self) -> str:
36
  return "geneval"
37
-
38
  @property
39
  def size(self) -> int:
40
  return len(self.entries)
 
9
  def __init__(self):
10
  super().__init__()
11
  self._download_geneval_file()
12
+ metadata_path = Path("downloads/geneval/evaluation_metadata.jsonl")
13
  self.entries: List[Dict[str, Any]] = []
14
+ with open(metadata_path, "r") as f:
15
  for line in f:
16
  if line.strip():
17
  self.entries.append(json.loads(line))
18
+
19
  def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
20
  for i, entry in enumerate(self.entries):
21
  folder_name = f"{i:05d}"
22
  yield entry, folder_name
23
 
24
  def _download_geneval_file(self) -> None:
25
+ folder_name = Path("downloads/geneval")
26
  folder_name.mkdir(parents=True, exist_ok=True)
27
  metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
28
  metadata_path = folder_name / "evaluation_metadata.jsonl"
29
  if not metadata_path.exists():
30
  response = requests.get(metadata_url)
31
+ with open(metadata_path, "w") as f:
32
  f.write(response.text)
33
+
34
  @property
35
  def name(self) -> str:
36
  return "geneval"
37
+
38
  @property
39
  def size(self) -> int:
40
  return len(self.entries)
benchmark/hps.py CHANGED
@@ -9,13 +9,18 @@ import huggingface_hub
9
  class HPSPrompts:
10
  def __init__(self):
11
  super().__init__()
12
- self.hps_prompt_files = ['anime.json', 'concept-art.json', 'paintings.json', 'photo.json']
 
 
 
 
 
13
  self._download_benchmark_prompts()
14
  self.prompts: Dict[str, str] = {}
15
  self._size = 0
16
  for file in self.hps_prompt_files:
17
- category = file.replace('.json', '')
18
- with open(os.path.join('downloads/hps', file), 'r') as f:
19
  prompts = json.load(f)
20
  for i, prompt in enumerate(prompts):
21
  if i == 100:
@@ -23,24 +28,26 @@ class HPSPrompts:
23
  filename = f"{category}_{i:03d}.png"
24
  self.prompts[filename] = prompt
25
  self._size += 1
26
-
27
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
28
  for filename, prompt in self.prompts.items():
29
  yield prompt, Path(filename)
30
-
31
  @property
32
  def name(self) -> str:
33
  return "hps"
34
-
35
  @property
36
  def size(self) -> int:
37
  return self._size
38
 
39
  def _download_benchmark_prompts(self) -> None:
40
- folder_name = Path('downloads/hps')
41
  folder_name.mkdir(parents=True, exist_ok=True)
42
  for file in self.hps_prompt_files:
43
- file_name = huggingface_hub.hf_hub_download("zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset")
 
 
44
  if not os.path.exists(os.path.join(folder_name, file)):
45
  os.symlink(file_name, os.path.join(folder_name, file))
46
 
 
9
  class HPSPrompts:
10
  def __init__(self):
11
  super().__init__()
12
+ self.hps_prompt_files = [
13
+ "anime.json",
14
+ "concept-art.json",
15
+ "paintings.json",
16
+ "photo.json",
17
+ ]
18
  self._download_benchmark_prompts()
19
  self.prompts: Dict[str, str] = {}
20
  self._size = 0
21
  for file in self.hps_prompt_files:
22
+ category = file.replace(".json", "")
23
+ with open(os.path.join("downloads/hps", file), "r") as f:
24
  prompts = json.load(f)
25
  for i, prompt in enumerate(prompts):
26
  if i == 100:
 
28
  filename = f"{category}_{i:03d}.png"
29
  self.prompts[filename] = prompt
30
  self._size += 1
31
+
32
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
33
  for filename, prompt in self.prompts.items():
34
  yield prompt, Path(filename)
35
+
36
  @property
37
  def name(self) -> str:
38
  return "hps"
39
+
40
  @property
41
  def size(self) -> int:
42
  return self._size
43
 
44
  def _download_benchmark_prompts(self) -> None:
45
+ folder_name = Path("downloads/hps")
46
  folder_name.mkdir(parents=True, exist_ok=True)
47
  for file in self.hps_prompt_files:
48
+ file_name = huggingface_hub.hf_hub_download(
49
+ "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
50
+ )
51
  if not os.path.exists(os.path.join(folder_name, file)):
52
  os.symlink(file_name, os.path.join(folder_name, file))
53
 
benchmark/metrics/arniqa.py CHANGED
@@ -8,20 +8,29 @@ from torchmetrics.image.arniqa import ARNIQA
8
 
9
  class ARNIQAMetric:
10
  def __init__(self):
11
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
12
  self.metric = ARNIQA(
13
  regressor_dataset="koniq10k",
14
  reduction="mean",
15
  normalize=True,
16
- autocast=False
17
  )
18
  self.metric.to(self.device)
 
19
  @property
20
  def name(self) -> str:
21
  return "arniqa"
22
-
23
  def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
24
- image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
 
 
25
  image_tensor = image_tensor.unsqueeze(0).to(self.device)
26
  score = self.metric(image_tensor)
27
  return {"arniqa": score.item()}
 
8
 
9
  class ARNIQAMetric:
10
  def __init__(self):
11
+ self.device = torch.device(
12
+ "cuda"
13
+ if torch.cuda.is_available()
14
+ else "mps"
15
+ if torch.backends.mps.is_available()
16
+ else "cpu"
17
+ )
18
  self.metric = ARNIQA(
19
  regressor_dataset="koniq10k",
20
  reduction="mean",
21
  normalize=True,
22
+ autocast=False,
23
  )
24
  self.metric.to(self.device)
25
+
26
  @property
27
  def name(self) -> str:
28
  return "arniqa"
29
+
30
  def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
31
+ image_tensor = (
32
+ torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
33
+ )
34
  image_tensor = image_tensor.unsqueeze(0).to(self.device)
35
  score = self.metric(image_tensor)
36
  return {"arniqa": score.item()}
benchmark/metrics/clip.py CHANGED
@@ -8,14 +8,20 @@ from torchmetrics.multimodal.clip_score import CLIPScore
8
 
9
  class CLIPMetric:
10
  def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
11
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
12
  self.metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
13
  self.metric.to(self.device)
14
-
15
  @property
16
  def name(self) -> str:
17
  return "clip"
18
-
19
  def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
20
  image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
21
  image_tensor = image_tensor.to(self.device)
 
8
 
9
  class CLIPMetric:
10
  def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
11
+ self.device = torch.device(
12
+ "cuda"
13
+ if torch.cuda.is_available()
14
+ else "mps"
15
+ if torch.backends.mps.is_available()
16
+ else "cpu"
17
+ )
18
  self.metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
19
  self.metric.to(self.device)
20
+
21
  @property
22
  def name(self) -> str:
23
  return "clip"
24
+
25
  def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
26
  image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
27
  image_tensor = image_tensor.to(self.device)
benchmark/metrics/clip_iqa.py CHANGED
@@ -8,18 +8,22 @@ from torchmetrics.multimodal import CLIPImageQualityAssessment
8
 
9
  class CLIPIQAMetric:
10
  def __init__(self):
11
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
12
  self.metric = CLIPImageQualityAssessment(
13
- model_name_or_path="clip_iqa",
14
- data_range=255.0,
15
- prompts=("quality",)
16
  )
17
  self.metric.to(self.device)
18
 
19
  @property
20
  def name(self) -> str:
21
  return "clip_iqa"
22
-
23
  def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
24
  image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
25
  image_tensor = image_tensor.unsqueeze(0)
 
8
 
9
  class CLIPIQAMetric:
10
  def __init__(self):
11
+ self.device = torch.device(
12
+ "cuda"
13
+ if torch.cuda.is_available()
14
+ else "mps"
15
+ if torch.backends.mps.is_available()
16
+ else "cpu"
17
+ )
18
  self.metric = CLIPImageQualityAssessment(
19
+ model_name_or_path="clip_iqa", data_range=255.0, prompts=("quality",)
 
 
20
  )
21
  self.metric.to(self.device)
22
 
23
  @property
24
  def name(self) -> str:
25
  return "clip_iqa"
26
+
27
  def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
28
  image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
29
  image_tensor = image_tensor.unsqueeze(0)
benchmark/metrics/hps.py CHANGED
@@ -1,26 +1,32 @@
1
  import os
2
  from typing import Dict
3
 
 
4
  import torch
5
- from PIL import Image
6
  from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
7
- import huggingface_hub
8
- from hpsv2.utils import root_path, hps_version_map
9
 
10
 
11
  class HPSMetric:
12
  def __init__(self):
13
  self.hps_version = "v2.1"
14
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
15
  self.model_dict = {}
16
  self._initialize_model()
17
-
18
  def _initialize_model(self):
19
  if not self.model_dict:
20
  model, preprocess_train, preprocess_val = create_model_and_transforms(
21
- 'ViT-H-14',
22
- 'laion2B-s32B-b79K',
23
- precision='amp',
24
  device=self.device,
25
  jit=False,
26
  force_quick_gelu=False,
@@ -34,44 +40,53 @@ class HPSMetric:
34
  aug_cfg={},
35
  output_dict=True,
36
  with_score_predictor=False,
37
- with_region_predictor=False
38
  )
39
- self.model_dict['model'] = model
40
- self.model_dict['preprocess_val'] = preprocess_val
41
-
42
  # Load checkpoint
43
  if not os.path.exists(root_path):
44
  os.makedirs(root_path)
45
- cp = huggingface_hub.hf_hub_download("xswu/HPSv2", hps_version_map[self.hps_version])
46
-
 
 
47
  checkpoint = torch.load(cp, map_location=self.device)
48
- model.load_state_dict(checkpoint['state_dict'])
49
- self.tokenizer = get_tokenizer('ViT-H-14')
50
  model = model.to(self.device)
51
  model.eval()
52
-
53
  @property
54
  def name(self) -> str:
55
  return "hps"
56
-
57
  def compute_score(
58
  self,
59
  image: Image.Image,
60
  prompt: str,
61
  ) -> Dict[str, float]:
62
- model = self.model_dict['model']
63
- preprocess_val = self.model_dict['preprocess_val']
64
-
65
  with torch.no_grad():
66
  # Process the image
67
- image_tensor = preprocess_val(image).unsqueeze(0).to(device=self.device, non_blocking=True)
 
 
 
 
68
  # Process the prompt
69
  text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
70
  # Calculate the HPS
71
  with torch.cuda.amp.autocast():
72
  outputs = model(image_tensor, text)
73
- image_features, text_features = outputs["image_features"], outputs["text_features"]
 
 
 
74
  logits_per_image = image_features @ text_features.T
75
  hps_score = torch.diagonal(logits_per_image).cpu().numpy()
76
-
77
  return {"hps": float(hps_score[0])}
 
1
  import os
2
  from typing import Dict
3
 
4
+ import huggingface_hub
5
  import torch
 
6
  from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
7
+ from hpsv2.utils import hps_version_map, root_path
8
+ from PIL import Image
9
 
10
 
11
  class HPSMetric:
12
  def __init__(self):
13
  self.hps_version = "v2.1"
14
+ self.device = torch.device(
15
+ "cuda"
16
+ if torch.cuda.is_available()
17
+ else "mps"
18
+ if torch.backends.mps.is_available()
19
+ else "cpu"
20
+ )
21
  self.model_dict = {}
22
  self._initialize_model()
23
+
24
  def _initialize_model(self):
25
  if not self.model_dict:
26
  model, preprocess_train, preprocess_val = create_model_and_transforms(
27
+ "ViT-H-14",
28
+ "laion2B-s32B-b79K",
29
+ precision="amp",
30
  device=self.device,
31
  jit=False,
32
  force_quick_gelu=False,
 
40
  aug_cfg={},
41
  output_dict=True,
42
  with_score_predictor=False,
43
+ with_region_predictor=False,
44
  )
45
+ self.model_dict["model"] = model
46
+ self.model_dict["preprocess_val"] = preprocess_val
47
+
48
  # Load checkpoint
49
  if not os.path.exists(root_path):
50
  os.makedirs(root_path)
51
+ cp = huggingface_hub.hf_hub_download(
52
+ "xswu/HPSv2", hps_version_map[self.hps_version]
53
+ )
54
+
55
  checkpoint = torch.load(cp, map_location=self.device)
56
+ model.load_state_dict(checkpoint["state_dict"])
57
+ self.tokenizer = get_tokenizer("ViT-H-14")
58
  model = model.to(self.device)
59
  model.eval()
60
+
61
  @property
62
  def name(self) -> str:
63
  return "hps"
64
+
65
  def compute_score(
66
  self,
67
  image: Image.Image,
68
  prompt: str,
69
  ) -> Dict[str, float]:
70
+ model = self.model_dict["model"]
71
+ preprocess_val = self.model_dict["preprocess_val"]
72
+
73
  with torch.no_grad():
74
  # Process the image
75
+ image_tensor = (
76
+ preprocess_val(image)
77
+ .unsqueeze(0)
78
+ .to(device=self.device, non_blocking=True)
79
+ )
80
  # Process the prompt
81
  text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
82
  # Calculate the HPS
83
  with torch.cuda.amp.autocast():
84
  outputs = model(image_tensor, text)
85
+ image_features, text_features = (
86
+ outputs["image_features"],
87
+ outputs["text_features"],
88
+ )
89
  logits_per_image = image_features @ text_features.T
90
  hps_score = torch.diagonal(logits_per_image).cpu().numpy()
91
+
92
  return {"hps": float(hps_score[0])}
benchmark/metrics/image_reward.py CHANGED
@@ -3,17 +3,26 @@ import tempfile
3
  from typing import Dict
4
 
5
  import ImageReward as RM
 
6
  from PIL import Image
7
 
8
 
9
  class ImageRewardMetric:
10
  def __init__(self):
11
- self.model = RM.load("ImageReward-v1.0")
12
-
 
 
 
 
 
 
 
 
13
  @property
14
  def name(self) -> str:
15
  return "image_reward"
16
-
17
  def compute_score(
18
  self,
19
  image: Image.Image,
 
3
  from typing import Dict
4
 
5
  import ImageReward as RM
6
+ import torch
7
  from PIL import Image
8
 
9
 
10
  class ImageRewardMetric:
11
  def __init__(self):
12
+ self.device = torch.device(
13
+ "cuda"
14
+ if torch.cuda.is_available()
15
+ else "mps"
16
+ if torch.backends.mps.is_available()
17
+ else "cpu"
18
+ )
19
+
20
+ self.model = RM.load("ImageReward-v1.0", device=str(self.device))
21
+
22
  @property
23
  def name(self) -> str:
24
  return "image_reward"
25
+
26
  def compute_score(
27
  self,
28
  image: Image.Image,
benchmark/metrics/vqa.py CHANGED
@@ -2,15 +2,26 @@ from pathlib import Path
2
  from typing import Dict
3
 
4
  import t2v_metrics
 
 
5
 
6
  class VQAMetric:
7
  def __init__(self):
8
- self.metric = t2v_metrics.VQAScore(model="clip-flant5-xxl")
9
-
 
 
 
 
 
 
 
 
 
10
  @property
11
  def name(self) -> str:
12
  return "vqa_score"
13
-
14
  def compute_score(
15
  self,
16
  image_path: Path,
 
2
  from typing import Dict
3
 
4
  import t2v_metrics
5
+ import torch
6
+
7
 
8
  class VQAMetric:
9
  def __init__(self):
10
+ self.device = torch.device(
11
+ "cuda"
12
+ if torch.cuda.is_available()
13
+ else "mps"
14
+ if torch.backends.mps.is_available()
15
+ else "cpu"
16
+ )
17
+ self.metric = t2v_metrics.VQAScore(
18
+ model="clip-flant5-xxl", device=str(self.device)
19
+ )
20
+
21
  @property
22
  def name(self) -> str:
23
  return "vqa_score"
24
+
25
  def compute_score(
26
  self,
27
  image_path: Path,
benchmark/parti.py CHANGED
@@ -14,11 +14,11 @@ class PartiPrompts:
14
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
15
  for i, prompt in enumerate(self.prompts):
16
  yield prompt, Path(f"{i}.png")
17
-
18
  @property
19
  def name(self) -> str:
20
  return "parti"
21
-
22
  @property
23
  def size(self) -> int:
24
  return len(self.prompts)
 
14
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
15
  for i, prompt in enumerate(self.prompts):
16
  yield prompt, Path(f"{i}.png")
17
+
18
  @property
19
  def name(self) -> str:
20
  return "parti"
21
+
22
  @property
23
  def size(self) -> int:
24
  return len(self.prompts)
environment.yml CHANGED
@@ -12,7 +12,7 @@ dependencies:
12
  - tqdm
13
  - pip
14
  - pip:
15
- - datasets>=3.5.0
16
  - fal-client>=0.5.9
17
  - hpsv2>=1.2.0
18
  - huggingface-hub>=0.30.2
 
12
  - tqdm
13
  - pip
14
  - pip:
15
+ - datasets==3.6.0
16
  - fal-client>=0.5.9
17
  - hpsv2>=1.2.0
18
  - huggingface-hub>=0.30.2
evaluate.py CHANGED
@@ -1,59 +1,65 @@
1
  import argparse
2
  import json
 
3
  from pathlib import Path
4
  from typing import Dict
5
- import warnings
6
 
7
- from benchmark import create_benchmark
8
- from benchmark.metrics import create_metric
9
  import numpy as np
10
  from PIL import Image
11
  from tqdm import tqdm
12
 
 
 
13
 
14
  warnings.filterwarnings("ignore", category=FutureWarning)
15
 
16
 
17
- def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
 
 
18
  """
19
  Evaluate a benchmark's images using its specific metrics.
20
-
21
  Args:
22
  benchmark_type (str): Type of benchmark to evaluate
23
  api_type (str): Type of API used to generate images
24
  images_dir (Path): Base directory containing generated images
25
-
26
  Returns:
27
  Dict containing evaluation results
28
  """
29
  benchmark = create_benchmark(benchmark_type)
30
-
31
  benchmark_dir = images_dir / api_type / benchmark_type
32
  metadata_file = benchmark_dir / "metadata.jsonl"
33
-
34
  if not metadata_file.exists():
35
- raise FileNotFoundError(f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first.")
36
-
 
 
37
  metadata = []
38
  with open(metadata_file, "r") as f:
39
  for line in f:
40
  metadata.append(json.loads(line))
41
-
42
- metrics = {metric_type: create_metric(metric_type) for metric_type in benchmark.metrics}
43
-
 
 
44
  results = {
45
  "api": api_type,
46
  "benchmark": benchmark_type,
47
  "metrics": {metric: 0.0 for metric in benchmark.metrics},
48
- "total_images": len(metadata)
49
  }
50
  inference_times = []
51
-
52
  for entry in tqdm(metadata):
53
  image_path = benchmark_dir / entry["filepath"]
54
  if not image_path.exists():
55
  continue
56
-
57
  for metric_type, metric in metrics.items():
58
  try:
59
  if metric_type == "vqa":
@@ -64,26 +70,30 @@ def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Pa
64
  results["metrics"][metric_type] += score[metric_type]
65
  except Exception as e:
66
  print(f"Error computing {metric_type} for {image_path}: {str(e)}")
67
-
68
  inference_times.append(entry["inference_time"])
69
-
70
  for metric in results["metrics"]:
71
  results["metrics"][metric] /= len(metadata)
72
  results["median_inference_time"] = np.median(inference_times).item()
73
-
74
  return results
75
 
76
 
77
  def main():
78
- parser = argparse.ArgumentParser(description="Evaluate generated images using benchmark-specific metrics")
 
 
79
  parser.add_argument("api_type", help="Type of API to evaluate")
80
- parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to evaluate")
81
-
 
 
82
  args = parser.parse_args()
83
-
84
  results_dir = Path("evaluation_results")
85
  results_dir.mkdir(exist_ok=True)
86
-
87
  results_file = results_dir / f"{args.api_type}.jsonl"
88
  existing_results = set()
89
 
@@ -97,15 +107,15 @@ def main():
97
  if benchmark_type in existing_results:
98
  print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
99
  continue
100
-
101
  try:
102
  print(f"Evaluating {args.api_type}/{benchmark_type}")
103
  results = evaluate_benchmark(benchmark_type, args.api_type)
104
-
105
  # Append results to file
106
  with open(results_file, "a") as f:
107
  f.write(json.dumps(results) + "\n")
108
-
109
  except Exception as e:
110
  print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
111
 
 
1
  import argparse
2
  import json
3
+ import warnings
4
  from pathlib import Path
5
  from typing import Dict
 
6
 
 
 
7
  import numpy as np
8
  from PIL import Image
9
  from tqdm import tqdm
10
 
11
+ from benchmark import create_benchmark
12
+ from benchmark.metrics import create_metric
13
 
14
  warnings.filterwarnings("ignore", category=FutureWarning)
15
 
16
 
17
+ def evaluate_benchmark(
18
+ benchmark_type: str, api_type: str, images_dir: Path = Path("images")
19
+ ) -> Dict:
20
  """
21
  Evaluate a benchmark's images using its specific metrics.
22
+
23
  Args:
24
  benchmark_type (str): Type of benchmark to evaluate
25
  api_type (str): Type of API used to generate images
26
  images_dir (Path): Base directory containing generated images
27
+
28
  Returns:
29
  Dict containing evaluation results
30
  """
31
  benchmark = create_benchmark(benchmark_type)
32
+
33
  benchmark_dir = images_dir / api_type / benchmark_type
34
  metadata_file = benchmark_dir / "metadata.jsonl"
35
+
36
  if not metadata_file.exists():
37
+ raise FileNotFoundError(
38
+ f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
39
+ )
40
+
41
  metadata = []
42
  with open(metadata_file, "r") as f:
43
  for line in f:
44
  metadata.append(json.loads(line))
45
+
46
+ metrics = {
47
+ metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
48
+ }
49
+
50
  results = {
51
  "api": api_type,
52
  "benchmark": benchmark_type,
53
  "metrics": {metric: 0.0 for metric in benchmark.metrics},
54
+ "total_images": len(metadata),
55
  }
56
  inference_times = []
57
+
58
  for entry in tqdm(metadata):
59
  image_path = benchmark_dir / entry["filepath"]
60
  if not image_path.exists():
61
  continue
62
+
63
  for metric_type, metric in metrics.items():
64
  try:
65
  if metric_type == "vqa":
 
70
  results["metrics"][metric_type] += score[metric_type]
71
  except Exception as e:
72
  print(f"Error computing {metric_type} for {image_path}: {str(e)}")
73
+
74
  inference_times.append(entry["inference_time"])
75
+
76
  for metric in results["metrics"]:
77
  results["metrics"][metric] /= len(metadata)
78
  results["median_inference_time"] = np.median(inference_times).item()
79
+
80
  return results
81
 
82
 
83
  def main():
84
+ parser = argparse.ArgumentParser(
85
+ description="Evaluate generated images using benchmark-specific metrics"
86
+ )
87
  parser.add_argument("api_type", help="Type of API to evaluate")
88
+ parser.add_argument(
89
+ "benchmarks", nargs="+", help="List of benchmark types to evaluate"
90
+ )
91
+
92
  args = parser.parse_args()
93
+
94
  results_dir = Path("evaluation_results")
95
  results_dir.mkdir(exist_ok=True)
96
+
97
  results_file = results_dir / f"{args.api_type}.jsonl"
98
  existing_results = set()
99
 
 
107
  if benchmark_type in existing_results:
108
  print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
109
  continue
110
+
111
  try:
112
  print(f"Evaluating {args.api_type}/{benchmark_type}")
113
  results = evaluate_benchmark(benchmark_type, args.api_type)
114
+
115
  # Append results to file
116
  with open(results_file, "a") as f:
117
  f.write(json.dumps(results) + "\n")
118
+
119
  except Exception as e:
120
  print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
121
 
nils_installs.txt DELETED
@@ -1,177 +0,0 @@
1
- accelerate==1.7.0
2
- aiohappyeyeballs==2.6.1
3
- aiohttp==3.12.12
4
- aiosignal==1.3.2
5
- annotated-types==0.7.0
6
- antlr4-python3-runtime==4.9.3
7
- anyio==4.9.0
8
- args==0.1.0
9
- asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
10
- attrs==25.3.0
11
- beautifulsoup4==4.13.4
12
- boto3==1.38.33
13
- botocore==1.38.33
14
- braceexpand==0.1.7
15
- Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
16
- certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
17
- cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725560558132/work
18
- charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
19
- click==8.1.8
20
- clint==0.5.1
21
- clip==0.2.0
22
- colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
23
- comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
24
- contourpy==1.3.2
25
- cycler==0.12.1
26
- datasets==3.6.0
27
- debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321241074/work
28
- decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
29
- diffusers==0.31.0
30
- dill==0.3.8
31
- distro==1.9.0
32
- einops==0.8.1
33
- eval_type_backport==0.2.2
34
- exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
35
- executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
36
- fairscale==0.4.13
37
- fal_client==0.7.0
38
- filelock==3.18.0
39
- fire==0.4.0
40
- fonttools==4.58.2
41
- frozenlist==1.7.0
42
- fsspec==2025.3.0
43
- ftfy==6.3.1
44
- gdown==5.2.0
45
- h11==0.16.0
46
- h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
47
- hf-xet==1.1.3
48
- hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
49
- hpsv2==1.2.0
50
- httpcore==1.0.9
51
- httpx==0.28.1
52
- httpx-sse==0.4.0
53
- huggingface-hub==0.32.5
54
- hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
55
- idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
56
- image-reward==1.5
57
- importlib_metadata==8.7.0
58
- iniconfig==2.1.0
59
- iopath==0.1.10
60
- ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
61
- ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748713870/work
62
- ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
63
- jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
64
- Jinja2==3.1.6
65
- jiter==0.10.0
66
- jmespath==1.0.1
67
- joblib==1.5.1
68
- jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
69
- jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
70
- kiwisolver==1.4.8
71
- lightning-utilities==0.14.3
72
- markdown-it-py==3.0.0
73
- MarkupSafe==3.0.2
74
- matplotlib==3.10.3
75
- matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
76
- mdurl==0.1.2
77
- mpmath==1.3.0
78
- multidict==6.4.4
79
- multiprocess==0.70.16
80
- nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
81
- networkx==3.5
82
- numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1749430504934/work/dist/numpy-2.3.0-cp312-cp312-linux_x86_64.whl#sha256=3c4437a0cbe50dbae872ad4cd8dc5316009165bce459c4ffe2c46cd30aba13d4
83
- nvidia-cublas-cu12==12.6.4.1
84
- nvidia-cuda-cupti-cu12==12.6.80
85
- nvidia-cuda-nvrtc-cu12==12.6.77
86
- nvidia-cuda-runtime-cu12==12.6.77
87
- nvidia-cudnn-cu12==9.5.1.17
88
- nvidia-cufft-cu12==11.3.0.4
89
- nvidia-cufile-cu12==1.11.1.6
90
- nvidia-curand-cu12==10.3.7.77
91
- nvidia-cusolver-cu12==11.7.1.2
92
- nvidia-cusparse-cu12==12.5.4.2
93
- nvidia-cusparselt-cu12==0.6.3
94
- nvidia-nccl-cu12==2.26.2
95
- nvidia-nvjitlink-cu12==12.6.85
96
- nvidia-nvtx-cu12==12.6.77
97
- omegaconf==2.3.0
98
- open_clip_torch==2.32.0
99
- openai==1.85.0
100
- opencv-python==4.11.0
101
- opencv-python-headless==4.11.0
102
- packaging==25.0
103
- pandas==2.3.0
104
- parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
105
- pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
106
- pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
107
- pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1746646208260/work
108
- piq==0.8.0
109
- platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
110
- pluggy==1.6.0
111
- portalocker==3.1.1
112
- prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
113
- propcache==0.3.2
114
- protobuf==3.20.3
115
- psutil==7.0.0
116
- ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
117
- pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
118
- pyarrow==20.0.0
119
- pycocoevalcap==1.2
120
- pycocotools==2.0.10
121
- pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
122
- pydantic==2.11.5
123
- pydantic_core==2.33.2
124
- Pygments==2.19.1
125
- pyparsing==3.2.3
126
- PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
127
- pytest==7.2.0
128
- pytest-split==0.8.0
129
- python-dateutil==2.9.0.post0
130
- python-dotenv @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dotenv_1742948348/work
131
- pytz==2025.2
132
- PyYAML==6.0.2
133
- pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245863/work
134
- regex==2024.11.6
135
- replicate==1.0.7
136
- requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1749498106507/work
137
- rich==14.0.0
138
- s3transfer==0.13.0
139
- safetensors==0.5.3
140
- scikit-learn==1.7.0
141
- scipy==1.15.3
142
- sentencepiece==0.2.0
143
- setuptools==80.9.0
144
- shellingham==1.5.4
145
- six==1.17.0
146
- sniffio==1.3.1
147
- soupsieve==2.7
148
- stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
149
- sympy==1.14.0
150
- t2v_metrics==1.2
151
- tabulate==0.9.0
152
- termcolor==3.1.0
153
- threadpoolctl==3.6.0
154
- tiktoken==0.9.0
155
- timm==0.6.13
156
- together==1.5.11
157
- tokenizers==0.15.2
158
- torch==2.7.1
159
- torchmetrics==1.7.2
160
- torchvision==0.22.1
161
- tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003300911/work
162
- tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
163
- traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
164
- transformers==4.36.1
165
- triton==3.3.1
166
- typer==0.15.4
167
- typing-inspection==0.4.1
168
- typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1748959427/work
169
- tzdata==2025.2
170
- urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
171
- wcwidth==0.2.13
172
- webdataset==0.2.111
173
- wheel==0.45.1
174
- xxhash==3.5.0
175
- yarl==1.20.1
176
- zipp==3.23.0
177
- zstandard==0.23.0
sample.py CHANGED
@@ -12,19 +12,19 @@ from benchmark import create_benchmark
12
  def generate_images(api_type: str, benchmarks: List[str]):
13
  images_dir = Path("images")
14
  api = create_api(api_type)
15
-
16
  api_dir = images_dir / api_type
17
  api_dir.mkdir(parents=True, exist_ok=True)
18
-
19
  for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
20
  print(f"\nProcessing benchmark: {benchmark_type}")
21
-
22
  benchmark = create_benchmark(benchmark_type)
23
-
24
  if benchmark_type == "geneval":
25
  benchmark_dir = api_dir / benchmark_type
26
  benchmark_dir.mkdir(parents=True, exist_ok=True)
27
-
28
  metadata_file = benchmark_dir / "metadata.jsonl"
29
  existing_metadata = {}
30
  if metadata_file.exists():
@@ -32,35 +32,44 @@ def generate_images(api_type: str, benchmarks: List[str]):
32
  for line in f:
33
  entry = json.loads(line)
34
  existing_metadata[entry["filepath"]] = entry
35
-
36
- for metadata, folder_name in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
37
- sample_path = benchmark_dir / folder_name
38
- samples_path = sample_path / "samples"
39
- samples_path.mkdir(parents=True, exist_ok=True)
40
- image_path = samples_path / "0000.png"
41
-
42
- if image_path.exists():
43
- continue
44
-
45
- try:
46
- inference_time = api.generate_image(metadata["prompt"], image_path)
47
-
48
- metadata_entry = {
49
- "filepath": str(image_path),
50
- "prompt": metadata["prompt"],
51
- "inference_time": inference_time
52
- }
53
-
54
- existing_metadata[str(image_path)] = metadata_entry
55
-
56
- except Exception as e:
57
- print(f"\nError generating image for prompt: {metadata['prompt']}")
58
- print(f"Error: {str(e)}")
59
- continue
 
 
 
 
 
 
 
 
 
60
  else:
61
  benchmark_dir = api_dir / benchmark_type
62
  benchmark_dir.mkdir(parents=True, exist_ok=True)
63
-
64
  metadata_file = benchmark_dir / "metadata.jsonl"
65
  existing_metadata = {}
66
  if metadata_file.exists():
@@ -68,41 +77,47 @@ def generate_images(api_type: str, benchmarks: List[str]):
68
  for line in f:
69
  entry = json.loads(line)
70
  existing_metadata[entry["filepath"]] = entry
71
-
72
- for prompt, image_path in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
73
- full_image_path = benchmark_dir / image_path
74
-
75
- if full_image_path.exists():
76
- continue
77
-
78
- try:
79
- inference_time = api.generate_image(prompt, full_image_path)
80
-
81
- metadata_entry = {
82
- "filepath": str(image_path),
83
- "prompt": prompt,
84
- "inference_time": inference_time
85
- }
86
-
87
- existing_metadata[str(image_path)] = metadata_entry
88
-
89
- except Exception as e:
90
- print(f"\nError generating image for prompt: {prompt}")
91
- print(f"Error: {str(e)}")
92
- continue
93
-
94
- with open(metadata_file, "w") as f:
95
- for entry in existing_metadata.values():
96
- f.write(json.dumps(entry) + "\n")
 
 
 
 
97
 
98
 
99
  def main():
100
- parser = argparse.ArgumentParser(description="Generate images for specified benchmarks using a given API")
 
 
101
  parser.add_argument("api_type", help="Type of API to use for image generation")
102
  parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
103
-
104
  args = parser.parse_args()
105
-
106
  generate_images(args.api_type, args.benchmarks)
107
 
108
 
 
12
  def generate_images(api_type: str, benchmarks: List[str]):
13
  images_dir = Path("images")
14
  api = create_api(api_type)
15
+
16
  api_dir = images_dir / api_type
17
  api_dir.mkdir(parents=True, exist_ok=True)
18
+
19
  for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
20
  print(f"\nProcessing benchmark: {benchmark_type}")
21
+
22
  benchmark = create_benchmark(benchmark_type)
23
+
24
  if benchmark_type == "geneval":
25
  benchmark_dir = api_dir / benchmark_type
26
  benchmark_dir.mkdir(parents=True, exist_ok=True)
27
+
28
  metadata_file = benchmark_dir / "metadata.jsonl"
29
  existing_metadata = {}
30
  if metadata_file.exists():
 
32
  for line in f:
33
  entry = json.loads(line)
34
  existing_metadata[entry["filepath"]] = entry
35
+
36
+ with open(metadata_file, "a") as f:
37
+ for metadata, folder_name in tqdm(
38
+ benchmark,
39
+ desc=f"Generating images for {benchmark_type}",
40
+ leave=False,
41
+ ):
42
+ sample_path = benchmark_dir / folder_name
43
+ samples_path = sample_path / "samples"
44
+ samples_path.mkdir(parents=True, exist_ok=True)
45
+ image_path = samples_path / "0000.png"
46
+
47
+ if image_path.exists():
48
+ continue
49
+
50
+ try:
51
+ inference_time = api.generate_image(
52
+ metadata["prompt"], image_path
53
+ )
54
+
55
+ metadata_entry = {
56
+ "filepath": str(image_path),
57
+ "prompt": metadata["prompt"],
58
+ "inference_time": inference_time,
59
+ }
60
+
61
+ f.write(json.dumps(metadata_entry) + "\n")
62
+
63
+ except Exception as e:
64
+ print(
65
+ f"\nError generating image for prompt: {metadata['prompt']}"
66
+ )
67
+ print(f"Error: {str(e)}")
68
+ continue
69
  else:
70
  benchmark_dir = api_dir / benchmark_type
71
  benchmark_dir.mkdir(parents=True, exist_ok=True)
72
+
73
  metadata_file = benchmark_dir / "metadata.jsonl"
74
  existing_metadata = {}
75
  if metadata_file.exists():
 
77
  for line in f:
78
  entry = json.loads(line)
79
  existing_metadata[entry["filepath"]] = entry
80
+
81
+ with open(metadata_file, "a") as f:
82
+ for prompt, image_path in tqdm(
83
+ benchmark,
84
+ desc=f"Generating images for {benchmark_type}",
85
+ leave=False,
86
+ ):
87
+ if image_path in existing_metadata:
88
+ continue
89
+
90
+ full_image_path = benchmark_dir / image_path
91
+
92
+ if full_image_path.exists():
93
+ continue
94
+
95
+ try:
96
+ inference_time = api.generate_image(prompt, full_image_path)
97
+
98
+ metadata_entry = {
99
+ "filepath": str(image_path),
100
+ "prompt": prompt,
101
+ "inference_time": inference_time,
102
+ }
103
+
104
+ f.write(json.dumps(metadata_entry) + "\n")
105
+
106
+ except Exception as e:
107
+ print(f"\nError generating image for prompt: {prompt}")
108
+ print(f"Error: {str(e)}")
109
+ continue
110
 
111
 
112
  def main():
113
+ parser = argparse.ArgumentParser(
114
+ description="Generate images for specified benchmarks using a given API"
115
+ )
116
  parser.add_argument("api_type", help="Type of API to use for image generation")
117
  parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
118
+
119
  args = parser.parse_args()
120
+
121
  generate_images(args.api_type, args.benchmarks)
122
 
123
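Taken together, an end-to-end run with the two entry points touched by this PR might look as follows; the API and benchmark names are just one valid combination from the factories above, and the baseline API expects `REPLICATE_API_TOKEN` in `.env`:

```
# Generate images with one API across two benchmarks, then score them.
python sample.py baseline parti hps
python evaluate.py baseline parti hps
# sample.py writes images and metadata under images/baseline/<benchmark>/,
# evaluate.py appends one JSON line per benchmark to evaluation_results/baseline.jsonl.
```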