davidberenstein1957 committed
Commit 34046e2 · 1 Parent(s): 9fa4df6

refactor: improve code formatting and organization across multiple API and benchmark files

api/__init__.py CHANGED
@@ -1,31 +1,32 @@
 from typing import Type
 
+from api.aws import AWSBedrockAPI
 from api.baseline import BaselineAPI
+from api.fal import FalAPI
 from api.fireworks import FireworksAPI
 from api.flux import FluxAPI
 from api.pruna import PrunaAPI
 from api.pruna_dev import PrunaDevAPI
 from api.replicate import ReplicateAPI
 from api.together import TogetherAPI
-from api.fal import FalAPI
-from api.aws import AWSBedrockAPI
 
 __all__ = [
-    'create_api',
-    'FluxAPI',
-    'BaselineAPI',
-    'FireworksAPI',
-    'PrunaAPI',
-    'ReplicateAPI',
-    'TogetherAPI',
-    'FalAPI',
-    'PrunaDevAPI',
+    "create_api",
+    "FluxAPI",
+    "BaselineAPI",
+    "FireworksAPI",
+    "PrunaAPI",
+    "ReplicateAPI",
+    "TogetherAPI",
+    "FalAPI",
+    "PrunaDevAPI",
 ]
 
+
 def create_api(api_type: str) -> FluxAPI:
     """
     Factory function to create API instances.
-
+
     Args:
         api_type (str): The type of API to create. Must be one of:
             - "baseline"
@@ -35,10 +36,10 @@ def create_api(api_type: str) -> FluxAPI:
             - "together"
             - "fal"
             - "aws"
-
+
     Returns:
         FluxAPI: An instance of the requested API implementation
-
+
     Raises:
         ValueError: If an invalid API type is provided
     """
@@ -47,7 +48,7 @@ def create_api(api_type: str) -> FluxAPI:
    if api_type.startswith("pruna_"):
        speed_mode = api_type[6:]  # Remove "pruna_" prefix
        return PrunaAPI(speed_mode)
-
+
    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
@@ -56,8 +57,10 @@ def create_api(api_type: str) -> FluxAPI:
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }
-
+
    if api_type not in api_map:
-        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")
-
+        raise ValueError(
+            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
+        )
+
    return api_map[api_type]()
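
For context, a minimal usage sketch of the refactored factory (not part of the commit). The "baseline" key and the FluxAPI interface come from the files in this commit; the prompt, output path, and the "pruna_fast" speed mode are placeholder assumptions:

    from pathlib import Path

    from api import create_api

    # Registry lookup: returns an instance of the mapped FluxAPI subclass.
    api = create_api("baseline")

    # "pruna_<mode>" strings bypass the registry and construct PrunaAPI(<mode>).
    pruna_api = create_api("pruna_fast")  # "fast" is a hypothetical speed mode

    # Every implementation exposes generate_image(prompt, save_path) -> float.
    elapsed = api.generate_image("a red bicycle", Path("images/demo.png"))
    print(f"{api.name}: {elapsed:.2f}s")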
api/aws.py CHANGED
@@ -1,9 +1,8 @@
-import os
-import time
 import base64
 import json
+import os
+import time
 from pathlib import Path
-from typing import Any
 
 import boto3
 from dotenv import load_dotenv
@@ -45,23 +44,20 @@ class AWSBedrockAPI(FluxAPI):
         try:
             # Convert request to JSON and invoke the model
             request = json.dumps(native_request)
-            response = self._client.invoke_model(
-                modelId=self._model_id,
-                body=request
-            )
-
+            response = self._client.invoke_model(modelId=self._model_id, body=request)
+
             # Process the response
             model_response = json.loads(response["body"].read())
             if not model_response.get("images"):
                 raise Exception("No images returned from AWS Bedrock API")
-
+
             # Save the image
             base64_image_data = model_response["images"][0]
             self._save_image_from_base64(base64_image_data, save_path)
-
+
         except Exception as e:
             raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
-
+
         end_time = time.time()
         return end_time - start_time
 
@@ -70,4 +66,4 @@ class AWSBedrockAPI(FluxAPI):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         image_data = base64.b64decode(base64_data)
         with open(save_path, "wb") as f:
-            f.write(image_data)
+            f.write(image_data)
api/baseline.py CHANGED
@@ -12,6 +12,7 @@ class BaselineAPI(FluxAPI):
     """
     As our baseline, we use the Replicate API with go_fast=False.
     """
+
     def __init__(self):
         load_dotenv()
         self._api_key = os.getenv("REPLICATE_API_TOKEN")
@@ -24,6 +25,7 @@ class BaselineAPI(FluxAPI):
 
     def generate_image(self, prompt: str, save_path: Path) -> float:
         import replicate
+
         start_time = time.time()
         result = replicate.run(
             "black-forest-labs/flux-dev",
@@ -39,15 +41,15 @@ class BaselineAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result and len(result) > 0:
             self._save_image_from_result(result[0], save_path)
         else:
             raise Exception("No result returned from Replicate API")
-
+
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         with open(save_path, "wb") as f:
-            f.write(result.read())
+            f.write(result.read())
api/fal.py CHANGED
@@ -30,10 +30,10 @@ class FalAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         url = result["images"][0]["url"]
         self._save_image_from_url(url, save_path)
-
+
         return end_time - start_time
 
     def _save_image_from_url(self, url: str, save_path: Path):
api/fireworks.py CHANGED
@@ -23,7 +23,7 @@ class FireworksAPI(FluxAPI):
 
     def generate_image(self, prompt: str, save_path: Path) -> float:
         start_time = time.time()
-
+
         headers = {
             "Content-Type": "application/json",
             "Accept": "image/jpeg",
@@ -39,12 +39,12 @@ class FireworksAPI(FluxAPI):
         result = requests.post(self._url, headers=headers, json=data)
 
         end_time = time.time()
-
+
         if result.status_code == 200:
             self._save_image_from_result(result, save_path)
         else:
             raise Exception(f"Error: {result.status_code} {result.text}")
-
+
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
api/flux.py CHANGED
@@ -14,7 +14,7 @@ class FluxAPI(ABC):
     def name(self) -> str:
         """
         The name of the API implementation.
-
+
         Returns:
             str: The name of the specific API implementation
         """
@@ -24,11 +24,11 @@ class FluxAPI(ABC):
     def generate_image(self, prompt: str, save_path: Path) -> float:
         """
         Generate an image based on the prompt and save it to the specified path.
-
+
         Args:
             prompt (str): The text prompt to generate the image from
             save_path (Path): The path where the generated image should be saved
-
+
         Returns:
             float: The time taken for the API call in seconds
         """
api/pruna.py CHANGED
@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
@@ -12,7 +12,9 @@ from api.flux import FluxAPI
 class PrunaAPI(FluxAPI):
     def __init__(self, speed_mode: str):
         self._speed_mode = speed_mode
-        self._speed_mode_name = speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
+        self._speed_mode_name = (
+            speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
+        )
         load_dotenv()
         self._api_key = os.getenv("REPLICATE_API_TOKEN")
         if not self._api_key:
@@ -38,7 +40,7 @@ class PrunaAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result:
             self._save_image_from_result(result, save_path)
         else:
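
The wrapped _speed_mode_name expression is unchanged in behavior: it keeps only the first whitespace-separated token of the speed mode, lower-cased. A small illustration with hypothetical speed-mode strings:

    # Hypothetical inputs; the actual speed-mode strings are defined elsewhere.
    for speed_mode in ["Fast mode", "extra-fast"]:
        print(speed_mode.split(" ")[0].strip().lower().replace(" ", "_"))
    # -> fast
    # -> extra-fast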
api/pruna_dev.py CHANGED
@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
@@ -36,7 +36,7 @@ class PrunaDevAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result:
             self._save_image_from_result(result, save_path)
         else:
@@ -46,4 +46,4 @@ class PrunaDevAPI(FluxAPI):
     def _save_image_from_result(self, result: Any, save_path: Path):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         with open(save_path, "wb") as f:
-            f.write(result.read())
+            f.write(result.read())
api/replicate.py CHANGED
@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
api/replicate_wan.py ADDED
@@ -0,0 +1,48 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import replicate
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class ReplicateAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "replicate_go_fast"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "go_fast": True,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/together.py CHANGED
@@ -33,10 +33,10 @@ class TogetherAPI(FluxAPI):
             response_format="b64_json",
         )
         end_time = time.time()
-        if result and hasattr(result, 'data') and len(result.data) > 0:
+        if result and hasattr(result, "data") and len(result.data) > 0:
             self._save_image_from_result(result, save_path)
         else:
-            raise Exception("No result returned from Together API")
+            raise Exception("No result returned from Together API")
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
benchmark/__init__.py CHANGED
@@ -7,10 +7,14 @@ from benchmark.hps import HPSPrompts
 from benchmark.parti import PartiPrompts
 
 
-def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts]:
+def create_benchmark(
+    benchmark_type: str,
+) -> Type[
+    DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts
+]:
     """
     Factory function to create benchmark instances.
-
+
     Args:
         benchmark_type (str): The type of benchmark to create. Must be one of:
             - "draw_bench"
@@ -18,10 +22,10 @@ def create_benchmark(benchmark_type: str) -> Type[DrawBenchP
             - "geneval"
             - "hps"
             - "parti"
-
+
     Returns:
         An instance of the requested benchmark implementation
-
+
     Raises:
         ValueError: If an invalid benchmark type is provided
     """
@@ -32,8 +36,10 @@ def create_benchmark(benchmark_type: str) -> Type[DrawBenchP
        "hps": HPSPrompts,
        "parti": PartiPrompts,
    }
-
+
    if benchmark_type not in benchmark_map:
-        raise ValueError(f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}")
-
+        raise ValueError(
+            f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}"
+        )
+
    return benchmark_map[benchmark_type]()
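
For context, a small usage sketch of the benchmark factory (not part of the commit). It assumes the iteration protocol used by most of the benchmark classes below, which yield (prompt, relative output path) pairs; GenEvalPrompts yields metadata entries instead:

    from benchmark import create_benchmark

    bench = create_benchmark("parti")
    print(bench.name, bench.size)

    # Iterate prompts and the relative paths the generated images should be saved to.
    for prompt, rel_path in bench:
        print(prompt, rel_path)
        break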
benchmark/genai_bench.py CHANGED
@@ -8,8 +8,8 @@ class GenAIBenchPrompts:
     def __init__(self):
         super().__init__()
         self._download_genai_bench_files()
-        prompts_path = Path('downloads/genai_bench/prompts.txt')
-        with open(prompts_path, 'r') as f:
+        prompts_path = Path("downloads/genai_bench/prompts.txt")
+        with open(prompts_path, "r") as f:
             self.prompts = [line.strip() for line in f if line.strip()]
 
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
@@ -17,13 +17,13 @@ class GenAIBenchPrompts:
             yield prompt, Path(f"{i}.png")
 
     def _download_genai_bench_files(self) -> None:
-        folder_name = Path('downloads/genai_bench')
+        folder_name = Path("downloads/genai_bench")
         folder_name.mkdir(parents=True, exist_ok=True)
         prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
         prompts_path = folder_name / "prompts.txt"
         if not prompts_path.exists():
             response = requests.get(prompts_url)
-            with open(prompts_path, 'w') as f:
+            with open(prompts_path, "w") as f:
                 f.write(response.text)
 
     @property
benchmark/geneval.py CHANGED
@@ -9,32 +9,32 @@ class GenEvalPrompts:
     def __init__(self):
         super().__init__()
         self._download_geneval_file()
-        metadata_path = Path('downloads/geneval/evaluation_metadata.jsonl')
+        metadata_path = Path("downloads/geneval/evaluation_metadata.jsonl")
         self.entries: List[Dict[str, Any]] = []
-        with open(metadata_path, 'r') as f:
+        with open(metadata_path, "r") as f:
             for line in f:
                 if line.strip():
                     self.entries.append(json.loads(line))
-
+
     def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
         for i, entry in enumerate(self.entries):
             folder_name = f"{i:05d}"
             yield entry, folder_name
 
     def _download_geneval_file(self) -> None:
-        folder_name = Path('downloads/geneval')
+        folder_name = Path("downloads/geneval")
         folder_name.mkdir(parents=True, exist_ok=True)
         metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
         metadata_path = folder_name / "evaluation_metadata.jsonl"
         if not metadata_path.exists():
             response = requests.get(metadata_url)
-            with open(metadata_path, 'w') as f:
+            with open(metadata_path, "w") as f:
                 f.write(response.text)
-
+
     @property
     def name(self) -> str:
         return "geneval"
-
+
     @property
     def size(self) -> int:
         return len(self.entries)
benchmark/hps.py CHANGED
@@ -9,13 +9,18 @@ import huggingface_hub
 class HPSPrompts:
     def __init__(self):
         super().__init__()
-        self.hps_prompt_files = ['anime.json', 'concept-art.json', 'paintings.json', 'photo.json']
+        self.hps_prompt_files = [
+            "anime.json",
+            "concept-art.json",
+            "paintings.json",
+            "photo.json",
+        ]
         self._download_benchmark_prompts()
         self.prompts: Dict[str, str] = {}
         self._size = 0
         for file in self.hps_prompt_files:
-            category = file.replace('.json', '')
-            with open(os.path.join('downloads/hps', file), 'r') as f:
+            category = file.replace(".json", "")
+            with open(os.path.join("downloads/hps", file), "r") as f:
                 prompts = json.load(f)
                 for i, prompt in enumerate(prompts):
                     if i == 100:
@@ -23,24 +28,26 @@ class HPSPrompts:
                     filename = f"{category}_{i:03d}.png"
                     self.prompts[filename] = prompt
                     self._size += 1
-
+
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for filename, prompt in self.prompts.items():
             yield prompt, Path(filename)
-
+
     @property
     def name(self) -> str:
         return "hps"
-
+
     @property
     def size(self) -> int:
         return self._size
 
     def _download_benchmark_prompts(self) -> None:
-        folder_name = Path('downloads/hps')
+        folder_name = Path("downloads/hps")
         folder_name.mkdir(parents=True, exist_ok=True)
         for file in self.hps_prompt_files:
-            file_name = huggingface_hub.hf_hub_download("zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset")
+            file_name = huggingface_hub.hf_hub_download(
+                "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
+            )
             if not os.path.exists(os.path.join(folder_name, file)):
                 os.symlink(file_name, os.path.join(folder_name, file))
 
@@ -14,11 +14,11 @@ class PartiPrompts:
14
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
15
  for i, prompt in enumerate(self.prompts):
16
  yield prompt, Path(f"{i}.png")
17
-
18
  @property
19
  def name(self) -> str:
20
  return "parti"
21
-
22
  @property
23
  def size(self) -> int:
24
  return len(self.prompts)
 
14
  def __iter__(self) -> Iterator[Tuple[str, Path]]:
15
  for i, prompt in enumerate(self.prompts):
16
  yield prompt, Path(f"{i}.png")
17
+
18
  @property
19
  def name(self) -> str:
20
  return "parti"
21
+
22
  @property
23
  def size(self) -> int:
24
  return len(self.prompts)
evaluate.py CHANGED
@@ -1,59 +1,65 @@
 import argparse
 import json
+import warnings
 from pathlib import Path
 from typing import Dict
-import warnings
 
-from benchmark import create_benchmark
-from benchmark.metrics import create_metric
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+from benchmark import create_benchmark
+from benchmark.metrics import create_metric
 
 warnings.filterwarnings("ignore", category=FutureWarning)
 
 
-def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
+def evaluate_benchmark(
+    benchmark_type: str, api_type: str, images_dir: Path = Path("images")
+) -> Dict:
     """
     Evaluate a benchmark's images using its specific metrics.
-
+
     Args:
         benchmark_type (str): Type of benchmark to evaluate
        api_type (str): Type of API used to generate images
        images_dir (Path): Base directory containing generated images
-
+
     Returns:
         Dict containing evaluation results
     """
    benchmark = create_benchmark(benchmark_type)
-
+
    benchmark_dir = images_dir / api_type / benchmark_type
    metadata_file = benchmark_dir / "metadata.jsonl"
-
+
    if not metadata_file.exists():
-        raise FileNotFoundError(f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first.")
-
+        raise FileNotFoundError(
+            f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
+        )
+
    metadata = []
    with open(metadata_file, "r") as f:
        for line in f:
            metadata.append(json.loads(line))
-
-    metrics = {metric_type: create_metric(metric_type) for metric_type in benchmark.metrics}
-
+
+    metrics = {
+        metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
+    }
+
    results = {
        "api": api_type,
        "benchmark": benchmark_type,
        "metrics": {metric: 0.0 for metric in benchmark.metrics},
-        "total_images": len(metadata)
+        "total_images": len(metadata),
    }
    inference_times = []
-
+
    for entry in tqdm(metadata):
        image_path = benchmark_dir / entry["filepath"]
        if not image_path.exists():
            continue
-
+
        for metric_type, metric in metrics.items():
            try:
                if metric_type == "vqa":
@@ -64,26 +70,30 @@ def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Pa
                results["metrics"][metric_type] += score[metric_type]
            except Exception as e:
                print(f"Error computing {metric_type} for {image_path}: {str(e)}")
-
+
        inference_times.append(entry["inference_time"])
-
+
    for metric in results["metrics"]:
        results["metrics"][metric] /= len(metadata)
    results["median_inference_time"] = np.median(inference_times).item()
-
+
    return results
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Evaluate generated images using benchmark-specific metrics")
+    parser = argparse.ArgumentParser(
+        description="Evaluate generated images using benchmark-specific metrics"
+    )
    parser.add_argument("api_type", help="Type of API to evaluate")
-    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to evaluate")
-
+    parser.add_argument(
+        "benchmarks", nargs="+", help="List of benchmark types to evaluate"
+    )
+
    args = parser.parse_args()
-
+
    results_dir = Path("evaluation_results")
    results_dir.mkdir(exist_ok=True)
-
+
    results_file = results_dir / f"{args.api_type}.jsonl"
    existing_results = set()
 
@@ -97,15 +107,15 @@ def main():
        if benchmark_type in existing_results:
            print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
            continue
-
+
        try:
            print(f"Evaluating {args.api_type}/{benchmark_type}")
            results = evaluate_benchmark(benchmark_type, args.api_type)
-
+
            # Append results to file
            with open(results_file, "a") as f:
                f.write(json.dumps(results) + "\n")
-
+
        except Exception as e:
            print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
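
For context, a sketch of how the reformatted script would be invoked (not part of the commit). It assumes sample.py has already written images and a metadata.jsonl under images/<api_type>/<benchmark_type>/; the "baseline" and "parti" arguments are example values:

    # Hypothetical CLI invocation: one API, one or more benchmarks.
    #   python evaluate.py baseline parti hps
    # Equivalent direct call from Python:
    from evaluate import evaluate_benchmark

    results = evaluate_benchmark("parti", "baseline")
    print(results["metrics"], results["median_inference_time"])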