Commit: 34046e2
Parent(s): 9fa4df6
refactor: improve code formatting and organization across multiple API and benchmark files
Files changed:
- api/__init__.py +21 -18
- api/aws.py +8 -12
- api/baseline.py +5 -3
- api/fal.py +2 -2
- api/fireworks.py +3 -3
- api/flux.py +3 -3
- api/pruna.py +5 -3
- api/pruna_dev.py +3 -3
- api/replicate.py +1 -1
- api/replicate_wan.py +48 -0
- api/together.py +2 -2
- benchmark/__init__.py +13 -7
- benchmark/genai_bench.py +4 -4
- benchmark/geneval.py +7 -7
- benchmark/hps.py +15 -8
- benchmark/parti.py +2 -2
- evaluate.py +37 -27
api/__init__.py
CHANGED

@@ -1,31 +1,32 @@
 from typing import Type
 
+from api.aws import AWSBedrockAPI
 from api.baseline import BaselineAPI
+from api.fal import FalAPI
 from api.fireworks import FireworksAPI
 from api.flux import FluxAPI
 from api.pruna import PrunaAPI
 from api.pruna_dev import PrunaDevAPI
 from api.replicate import ReplicateAPI
 from api.together import TogetherAPI
-from api.fal import FalAPI
-from api.aws import AWSBedrockAPI
 
 __all__ = [
+    "create_api",
+    "FluxAPI",
+    "BaselineAPI",
+    "FireworksAPI",
+    "PrunaAPI",
+    "ReplicateAPI",
+    "TogetherAPI",
+    "FalAPI",
+    "PrunaDevAPI",
 ]
 
+
 def create_api(api_type: str) -> FluxAPI:
     """
     Factory function to create API instances.
-
+
     Args:
         api_type (str): The type of API to create. Must be one of:
             - "baseline"
@@ -35,10 +36,10 @@ def create_api(api_type: str) -> FluxAPI:
             - "together"
            - "fal"
            - "aws"
-
+
     Returns:
         FluxAPI: An instance of the requested API implementation
-
+
     Raises:
         ValueError: If an invalid API type is provided
     """
@@ -47,7 +48,7 @@ def create_api(api_type: str) -> FluxAPI:
     if api_type.startswith("pruna_"):
         speed_mode = api_type[6:]  # Remove "pruna_" prefix
         return PrunaAPI(speed_mode)
-
+
     api_map: dict[str, Type[FluxAPI]] = {
         "baseline": BaselineAPI,
         "fireworks": FireworksAPI,
@@ -56,8 +57,10 @@ def create_api(api_type: str) -> FluxAPI:
         "fal": FalAPI,
         "aws": AWSBedrockAPI,
     }
-
+
     if api_type not in api_map:
-        raise ValueError(
+        raise ValueError(
+            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
+        )
+
     return api_map[api_type]()
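For reference, a minimal usage sketch of the refactored factory; the prompt, output path, and the pruna speed-mode label below are hypothetical, and a real call assumes the matching provider credentials are configured:

```python
from pathlib import Path

from api import create_api

api = create_api("fal")  # plain lookup through api_map
pruna_api = create_api("pruna_fast mode")  # everything after "pruna_" is forwarded as the speed mode

elapsed = api.generate_image("a red bicycle", Path("out/red_bicycle.png"))
print(f"{api.name}: {elapsed:.2f}s")
```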
api/aws.py
CHANGED

@@ -1,9 +1,8 @@
-import os
-import time
 import base64
 import json
+import os
+import time
 from pathlib import Path
-from typing import Any
 
 import boto3
 from dotenv import load_dotenv
@@ -45,23 +44,20 @@ class AWSBedrockAPI(FluxAPI):
         try:
             # Convert request to JSON and invoke the model
             request = json.dumps(native_request)
-            response = self._client.invoke_model(
-                modelId=self._model_id,
-                body=request
-            )
-
+            response = self._client.invoke_model(modelId=self._model_id, body=request)
+
             # Process the response
             model_response = json.loads(response["body"].read())
             if not model_response.get("images"):
                 raise Exception("No images returned from AWS Bedrock API")
-
+
             # Save the image
             base64_image_data = model_response["images"][0]
             self._save_image_from_base64(base64_image_data, save_path)
-
+
         except Exception as e:
             raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
-
+
         end_time = time.time()
         return end_time - start_time
 
@@ -70,4 +66,4 @@ class AWSBedrockAPI(FluxAPI):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         image_data = base64.b64decode(base64_data)
         with open(save_path, "wb") as f:
-            f.write(image_data)
+            f.write(image_data)
api/baseline.py
CHANGED

@@ -12,6 +12,7 @@ class BaselineAPI(FluxAPI):
     """
     As our baseline, we use the Replicate API with go_fast=False.
     """
+
     def __init__(self):
         load_dotenv()
         self._api_key = os.getenv("REPLICATE_API_TOKEN")
@@ -24,6 +25,7 @@ class BaselineAPI(FluxAPI):
 
     def generate_image(self, prompt: str, save_path: Path) -> float:
         import replicate
+
         start_time = time.time()
         result = replicate.run(
             "black-forest-labs/flux-dev",
@@ -39,15 +41,15 @@ class BaselineAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result and len(result) > 0:
             self._save_image_from_result(result[0], save_path)
         else:
             raise Exception("No result returned from Replicate API")
-
+
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         with open(save_path, "wb") as f:
-            f.write(result.read())
+            f.write(result.read())
api/fal.py
CHANGED

@@ -30,10 +30,10 @@ class FalAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         url = result["images"][0]["url"]
         self._save_image_from_url(url, save_path)
-
+
         return end_time - start_time
 
     def _save_image_from_url(self, url: str, save_path: Path):
api/fireworks.py
CHANGED

@@ -23,7 +23,7 @@ class FireworksAPI(FluxAPI):
 
     def generate_image(self, prompt: str, save_path: Path) -> float:
         start_time = time.time()
-
+
         headers = {
             "Content-Type": "application/json",
             "Accept": "image/jpeg",
@@ -39,12 +39,12 @@ class FireworksAPI(FluxAPI):
         result = requests.post(self._url, headers=headers, json=data)
 
         end_time = time.time()
-
+
         if result.status_code == 200:
             self._save_image_from_result(result, save_path)
         else:
             raise Exception(f"Error: {result.status_code} {result.text}")
-
+
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
api/flux.py
CHANGED

@@ -14,7 +14,7 @@ class FluxAPI(ABC):
     def name(self) -> str:
         """
         The name of the API implementation.
-
+
         Returns:
             str: The name of the specific API implementation
         """
@@ -24,11 +24,11 @@ class FluxAPI(ABC):
     def generate_image(self, prompt: str, save_path: Path) -> float:
         """
         Generate an image based on the prompt and save it to the specified path.
-
+
         Args:
             prompt (str): The text prompt to generate the image from
             save_path (Path): The path where the generated image should be saved
-
+
         Returns:
             float: The time taken for the API call in seconds
         """
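These two abstract members are the whole provider contract. A sketch of a hypothetical subclass, shown only to illustrate what the factory expects (the `@property` on `name` mirrors the concrete implementations in this commit):

```python
import time
from pathlib import Path

from api.flux import FluxAPI


class StubAPI(FluxAPI):
    """Hypothetical provider illustrating the FluxAPI contract."""

    @property
    def name(self) -> str:
        return "stub"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        start_time = time.time()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(b"")  # a real provider writes the image bytes here
        return time.time() - start_time
```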
api/pruna.py
CHANGED

@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
@@ -12,7 +12,9 @@ from api.flux import FluxAPI
 class PrunaAPI(FluxAPI):
     def __init__(self, speed_mode: str):
         self._speed_mode = speed_mode
-        self._speed_mode_name =
+        self._speed_mode_name = (
+            speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
+        )
         load_dotenv()
         self._api_key = os.getenv("REPLICATE_API_TOKEN")
         if not self._api_key:
@@ -38,7 +40,7 @@ class PrunaAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result:
             self._save_image_from_result(result, save_path)
         else:
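The wrapped `_speed_mode_name` expression keeps only the first whitespace-separated token of the label, so the trailing `.replace(" ", "_")` never sees a space; a quick demonstration with a hypothetical label:

```python
speed_mode = "extra juiced"  # hypothetical Pruna speed-mode label
name = speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
print(name)  # -> "extra"
```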
api/pruna_dev.py
CHANGED

@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
@@ -36,7 +36,7 @@ class PrunaDevAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result:
             self._save_image_from_result(result, save_path)
         else:
@@ -46,4 +46,4 @@ class PrunaDevAPI(FluxAPI):
     def _save_image_from_result(self, result: Any, save_path: Path):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         with open(save_path, "wb") as f:
-            f.write(result.read())
+            f.write(result.read())
api/replicate.py
CHANGED

@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
api/replicate_wan.py
ADDED

@@ -0,0 +1,48 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import replicate
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class ReplicateAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "replicate_go_fast"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "go_fast": True,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
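The new module is a go_fast=True twin of api/replicate.py under the same class name. A direct-use sketch (the prompt and path are placeholders; assumes REPLICATE_API_TOKEN is available to load_dotenv):

```python
from pathlib import Path

from api.replicate_wan import ReplicateAPI

api = ReplicateAPI()  # reads REPLICATE_API_TOKEN via load_dotenv()
seconds = api.generate_image(
    "an astronaut riding a horse, photorealistic",
    Path("images/replicate_go_fast/0.png"),
)
print(f"{api.name} finished in {seconds:.2f}s")
```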
api/together.py
CHANGED

@@ -33,10 +33,10 @@ class TogetherAPI(FluxAPI):
             response_format="b64_json",
         )
         end_time = time.time()
-        if result and hasattr(result,
+        if result and hasattr(result, "data") and len(result.data) > 0:
             self._save_image_from_result(result, save_path)
         else:
-            raise Exception("No result returned from Together API")
+            raise Exception("No result returned from Together API")
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
benchmark/__init__.py
CHANGED

@@ -7,10 +7,14 @@ from benchmark.hps import HPSPrompts
 from benchmark.parti import PartiPrompts
 
 
-def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts]:
+def create_benchmark(
+    benchmark_type: str,
+) -> Type[
+    DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts
+]:
     """
     Factory function to create benchmark instances.
-
+
     Args:
         benchmark_type (str): The type of benchmark to create. Must be one of:
             - "draw_bench"
@@ -18,10 +22,10 @@ def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchP
             - "geneval"
             - "hps"
             - "parti"
-
+
     Returns:
         An instance of the requested benchmark implementation
-
+
     Raises:
         ValueError: If an invalid benchmark type is provided
     """
@@ -32,8 +36,10 @@ def create_benchmark(benchmark_type: str) -> Type[DrawBenchPrompts | GenAIBenchP
         "hps": HPSPrompts,
         "parti": PartiPrompts,
     }
-
+
     if benchmark_type not in benchmark_map:
-        raise ValueError(
+        raise ValueError(
+            f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}"
+        )
+
     return benchmark_map[benchmark_type]()
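A usage sketch of the benchmark factory and the iteration protocol the prompt-based benchmarks share (geneval differs: it yields a metadata dict and a folder name rather than a prompt and a file path):

```python
from benchmark import create_benchmark

benchmark = create_benchmark("parti")
print(benchmark.name, benchmark.size)

# Prompt-based benchmarks yield (prompt, relative_output_path) pairs.
for prompt, rel_path in benchmark:
    print(rel_path, "<-", prompt)
    break
```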
benchmark/genai_bench.py
CHANGED

@@ -8,8 +8,8 @@ class GenAIBenchPrompts:
     def __init__(self):
         super().__init__()
         self._download_genai_bench_files()
-        prompts_path = Path(
-        with open(prompts_path,
+        prompts_path = Path("downloads/genai_bench/prompts.txt")
+        with open(prompts_path, "r") as f:
             self.prompts = [line.strip() for line in f if line.strip()]
 
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
@@ -17,13 +17,13 @@ class GenAIBenchPrompts:
             yield prompt, Path(f"{i}.png")
 
     def _download_genai_bench_files(self) -> None:
-        folder_name = Path(
+        folder_name = Path("downloads/genai_bench")
         folder_name.mkdir(parents=True, exist_ok=True)
         prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
         prompts_path = folder_name / "prompts.txt"
         if not prompts_path.exists():
             response = requests.get(prompts_url)
-            with open(prompts_path,
+            with open(prompts_path, "w") as f:
                 f.write(response.text)
 
     @property
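The loader above repeats a fetch-if-missing idiom that geneval uses as well; the same pattern in isolation, as a hypothetical helper (the `raise_for_status` call is an addition here, the repo's loaders do not check the response):

```python
from pathlib import Path

import requests


def fetch_if_missing(url: str, dest: Path) -> Path:
    """Download a text file once; later runs reuse the cached copy."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    if not dest.exists():
        response = requests.get(url)
        response.raise_for_status()  # not checked in the repo's loaders
        dest.write_text(response.text)
    return dest
```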
benchmark/geneval.py
CHANGED

@@ -9,32 +9,32 @@ class GenEvalPrompts:
     def __init__(self):
         super().__init__()
         self._download_geneval_file()
-        metadata_path = Path(
+        metadata_path = Path("downloads/geneval/evaluation_metadata.jsonl")
         self.entries: List[Dict[str, Any]] = []
-        with open(metadata_path,
+        with open(metadata_path, "r") as f:
             for line in f:
                 if line.strip():
                     self.entries.append(json.loads(line))
-
+
     def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
         for i, entry in enumerate(self.entries):
             folder_name = f"{i:05d}"
             yield entry, folder_name
 
     def _download_geneval_file(self) -> None:
-        folder_name = Path(
+        folder_name = Path("downloads/geneval")
         folder_name.mkdir(parents=True, exist_ok=True)
         metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
         metadata_path = folder_name / "evaluation_metadata.jsonl"
         if not metadata_path.exists():
             response = requests.get(metadata_url)
-            with open(metadata_path,
+            with open(metadata_path, "w") as f:
                 f.write(response.text)
-
+
     @property
     def name(self) -> str:
         return "geneval"
-
+
     @property
     def size(self) -> int:
         return len(self.entries)
benchmark/hps.py
CHANGED

@@ -9,13 +9,18 @@ import huggingface_hub
 class HPSPrompts:
     def __init__(self):
         super().__init__()
-        self.hps_prompt_files = [
+        self.hps_prompt_files = [
+            "anime.json",
+            "concept-art.json",
+            "paintings.json",
+            "photo.json",
+        ]
         self._download_benchmark_prompts()
         self.prompts: Dict[str, str] = {}
         self._size = 0
         for file in self.hps_prompt_files:
-            category = file.replace(
-            with open(os.path.join(
+            category = file.replace(".json", "")
+            with open(os.path.join("downloads/hps", file), "r") as f:
                 prompts = json.load(f)
             for i, prompt in enumerate(prompts):
                 if i == 100:
@@ -23,24 +28,26 @@ class HPSPrompts:
                 filename = f"{category}_{i:03d}.png"
                 self.prompts[filename] = prompt
                 self._size += 1
-
+
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for filename, prompt in self.prompts.items():
             yield prompt, Path(filename)
-
+
     @property
     def name(self) -> str:
         return "hps"
-
+
     @property
     def size(self) -> int:
         return self._size
 
     def _download_benchmark_prompts(self) -> None:
-        folder_name = Path(
+        folder_name = Path("downloads/hps")
         folder_name.mkdir(parents=True, exist_ok=True)
         for file in self.hps_prompt_files:
-            file_name = huggingface_hub.hf_hub_download(
+            file_name = huggingface_hub.hf_hub_download(
+                "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
+            )
             if not os.path.exists(os.path.join(folder_name, file)):
                 os.symlink(file_name, os.path.join(folder_name, file))
 
benchmark/parti.py
CHANGED

@@ -14,11 +14,11 @@ class PartiPrompts:
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for i, prompt in enumerate(self.prompts):
             yield prompt, Path(f"{i}.png")
-
+
     @property
     def name(self) -> str:
         return "parti"
-
+
     @property
     def size(self) -> int:
         return len(self.prompts)
evaluate.py
CHANGED

@@ -1,59 +1,65 @@
 import argparse
 import json
+import warnings
 from pathlib import Path
 from typing import Dict
-import warnings
 
-from benchmark import create_benchmark
-from benchmark.metrics import create_metric
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+from benchmark import create_benchmark
+from benchmark.metrics import create_metric
 
 warnings.filterwarnings("ignore", category=FutureWarning)
 
 
-def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
+def evaluate_benchmark(
+    benchmark_type: str, api_type: str, images_dir: Path = Path("images")
+) -> Dict:
     """
     Evaluate a benchmark's images using its specific metrics.
-
+
     Args:
         benchmark_type (str): Type of benchmark to evaluate
         api_type (str): Type of API used to generate images
         images_dir (Path): Base directory containing generated images
-
+
     Returns:
         Dict containing evaluation results
     """
     benchmark = create_benchmark(benchmark_type)
-
+
     benchmark_dir = images_dir / api_type / benchmark_type
     metadata_file = benchmark_dir / "metadata.jsonl"
-
+
     if not metadata_file.exists():
-        raise FileNotFoundError(
+        raise FileNotFoundError(
+            f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
+        )
+
     metadata = []
     with open(metadata_file, "r") as f:
         for line in f:
             metadata.append(json.loads(line))
-
-    metrics = {
+
+    metrics = {
+        metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
+    }
+
     results = {
         "api": api_type,
         "benchmark": benchmark_type,
         "metrics": {metric: 0.0 for metric in benchmark.metrics},
-        "total_images": len(metadata)
+        "total_images": len(metadata),
     }
     inference_times = []
-
+
     for entry in tqdm(metadata):
         image_path = benchmark_dir / entry["filepath"]
         if not image_path.exists():
             continue
-
+
         for metric_type, metric in metrics.items():
             try:
                 if metric_type == "vqa":
@@ -64,26 +70,30 @@ def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Pa
                 results["metrics"][metric_type] += score[metric_type]
             except Exception as e:
                 print(f"Error computing {metric_type} for {image_path}: {str(e)}")
-
+
         inference_times.append(entry["inference_time"])
-
+
     for metric in results["metrics"]:
         results["metrics"][metric] /= len(metadata)
     results["median_inference_time"] = np.median(inference_times).item()
-
+
     return results
 
 
 def main():
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
+        description="Evaluate generated images using benchmark-specific metrics"
+    )
     parser.add_argument("api_type", help="Type of API to evaluate")
-    parser.add_argument(
+    parser.add_argument(
+        "benchmarks", nargs="+", help="List of benchmark types to evaluate"
+    )
+
     args = parser.parse_args()
-
+
     results_dir = Path("evaluation_results")
     results_dir.mkdir(exist_ok=True)
-
+
     results_file = results_dir / f"{args.api_type}.jsonl"
     existing_results = set()
 
@@ -97,15 +107,15 @@ def main():
         if benchmark_type in existing_results:
             print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
             continue
-
+
         try:
             print(f"Evaluating {args.api_type}/{benchmark_type}")
             results = evaluate_benchmark(benchmark_type, args.api_type)
-
+
             # Append results to file
            with open(results_file, "a") as f:
                 f.write(json.dumps(results) + "\n")
-
+
         except Exception as e:
             print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
 