Merge pull request #2 from PrunaAI/feat/david
- .gitignore +2 -1
- README.md +19 -1
- api/__init__.py +21 -18
- api/aws.py +8 -12
- api/baseline.py +5 -3
- api/fal.py +2 -2
- api/fireworks.py +3 -3
- api/flux.py +3 -3
- api/pruna.py +5 -3
- api/pruna_dev.py +3 -3
- api/replicate.py +1 -1
- api/replicate_wan.py +48 -0
- api/together.py +2 -2
- benchmark/__init__.py +13 -7
- benchmark/draw_bench.py +3 -3
- benchmark/genai_bench.py +4 -4
- benchmark/geneval.py +7 -7
- benchmark/hps.py +15 -8
- benchmark/metrics/arniqa.py +13 -4
- benchmark/metrics/clip.py +9 -3
- benchmark/metrics/clip_iqa.py +9 -5
- benchmark/metrics/hps.py +39 -24
- benchmark/metrics/image_reward.py +12 -3
- benchmark/metrics/vqa.py +14 -3
- benchmark/parti.py +2 -2
- environment.yml +1 -1
- evaluate.py +37 -27
- nils_installs.txt +0 -177
- sample.py +75 -60
.gitignore
CHANGED
@@ -175,4 +175,5 @@ cython_debug/
 
 evaluation_results/
 images/
-hf_cache/
+hf_cache/
+*.lock
README.md
CHANGED
@@ -1,10 +1,28 @@
 # InferBench
 Evaluate the quality and efficiency of image gen api's.
 
-
+## Installation
+
+### Install dependencies
 
+Install dependencies with conda like that:
+```
 conda env create -f environment.yml
+```
+
+### Install uv
+
+Install uv with pip like that:
+
+```
+uv venv --python 3.12
+```
+
+```
+uv sync --all-groups
+```
 
+## Usage
 
 Create .env file with all the credentials you will need.
 
api/__init__.py
CHANGED
@@ -1,31 +1,32 @@
 from typing import Type
 
+from api.aws import AWSBedrockAPI
 from api.baseline import BaselineAPI
+from api.fal import FalAPI
 from api.fireworks import FireworksAPI
 from api.flux import FluxAPI
 from api.pruna import PrunaAPI
 from api.pruna_dev import PrunaDevAPI
 from api.replicate import ReplicateAPI
 from api.together import TogetherAPI
-from api.fal import FalAPI
-from api.aws import AWSBedrockAPI
 
 __all__ = [
-
-
-
-
-
-
-
-
-
+    "create_api",
+    "FluxAPI",
+    "BaselineAPI",
+    "FireworksAPI",
+    "PrunaAPI",
+    "ReplicateAPI",
+    "TogetherAPI",
+    "FalAPI",
+    "PrunaDevAPI",
 ]
 
+
 def create_api(api_type: str) -> FluxAPI:
     """
     Factory function to create API instances.
-
+
     Args:
         api_type (str): The type of API to create. Must be one of:
             - "baseline"
@@ -35,10 +36,10 @@ def create_api(api_type: str) -> FluxAPI:
             - "together"
            - "fal"
            - "aws"
-
+
    Returns:
        FluxAPI: An instance of the requested API implementation
-
+
    Raises:
        ValueError: If an invalid API type is provided
    """
@@ -47,7 +48,7 @@ def create_api(api_type: str) -> FluxAPI:
    if api_type.startswith("pruna_"):
        speed_mode = api_type[6:]  # Remove "pruna_" prefix
        return PrunaAPI(speed_mode)
-
+
    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
@@ -56,8 +57,10 @@ def create_api(api_type: str) -> FluxAPI:
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }
-
+
    if api_type not in api_map:
-        raise ValueError(
-
+        raise ValueError(
+            f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'"
+        )
+
    return api_map[api_type]()
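
For reviewers trying the factory out, a minimal sketch of how `create_api` is meant to be called after this change (the `pruna_juiced` speed mode is a made-up example value; only the routing logic is taken from the diff):

```python
from api import create_api

# Plain keys go through api_map; anything prefixed "pruna_" is routed to
# PrunaAPI with the remainder of the string as its speed mode.
baseline = create_api("baseline")   # -> BaselineAPI()
pruna = create_api("pruna_juiced")  # -> PrunaAPI("juiced")

try:
    create_api("unknown")           # raises with the list of valid keys
except ValueError as err:
    print(err)
```
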
api/aws.py
CHANGED
@@ -1,9 +1,8 @@
-import os
-import time
 import base64
 import json
+import os
+import time
 from pathlib import Path
-from typing import Any
 
 import boto3
 from dotenv import load_dotenv
@@ -45,23 +44,20 @@ class AWSBedrockAPI(FluxAPI):
         try:
             # Convert request to JSON and invoke the model
             request = json.dumps(native_request)
-            response = self._client.invoke_model(
-
-                body=request
-            )
-
+            response = self._client.invoke_model(modelId=self._model_id, body=request)
+
             # Process the response
             model_response = json.loads(response["body"].read())
             if not model_response.get("images"):
                 raise Exception("No images returned from AWS Bedrock API")
-
+
             # Save the image
             base64_image_data = model_response["images"][0]
             self._save_image_from_base64(base64_image_data, save_path)
-
+
         except Exception as e:
             raise Exception(f"Error generating image with AWS Bedrock: {str(e)}")
-
+
         end_time = time.time()
         return end_time - start_time
 
@@ -70,4 +66,4 @@ class AWSBedrockAPI(FluxAPI):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         image_data = base64.b64decode(base64_data)
         with open(save_path, "wb") as f:
-            f.write(image_data)
+            f.write(image_data)
api/baseline.py
CHANGED
@@ -12,6 +12,7 @@ class BaselineAPI(FluxAPI):
     """
     As our baseline, we use the Replicate API with go_fast=False.
     """
+
     def __init__(self):
         load_dotenv()
         self._api_key = os.getenv("REPLICATE_API_TOKEN")
@@ -24,6 +25,7 @@ class BaselineAPI(FluxAPI):
 
     def generate_image(self, prompt: str, save_path: Path) -> float:
         import replicate
+
         start_time = time.time()
         result = replicate.run(
             "black-forest-labs/flux-dev",
@@ -39,15 +41,15 @@ class BaselineAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result and len(result) > 0:
             self._save_image_from_result(result[0], save_path)
         else:
             raise Exception("No result returned from Replicate API")
-
+
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         with open(save_path, "wb") as f:
-            f.write(result.read())
+            f.write(result.read())
api/fal.py
CHANGED
@@ -30,10 +30,10 @@ class FalAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         url = result["images"][0]["url"]
         self._save_image_from_url(url, save_path)
-
+
         return end_time - start_time
 
     def _save_image_from_url(self, url: str, save_path: Path):
api/fireworks.py
CHANGED
@@ -23,7 +23,7 @@ class FireworksAPI(FluxAPI):
 
     def generate_image(self, prompt: str, save_path: Path) -> float:
         start_time = time.time()
-
+
         headers = {
             "Content-Type": "application/json",
             "Accept": "image/jpeg",
@@ -39,12 +39,12 @@ class FireworksAPI(FluxAPI):
         result = requests.post(self._url, headers=headers, json=data)
 
         end_time = time.time()
-
+
         if result.status_code == 200:
             self._save_image_from_result(result, save_path)
         else:
             raise Exception(f"Error: {result.status_code} {result.text}")
-
+
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
api/flux.py
CHANGED
@@ -14,7 +14,7 @@ class FluxAPI(ABC):
     def name(self) -> str:
         """
         The name of the API implementation.
-
+
         Returns:
             str: The name of the specific API implementation
         """
@@ -24,11 +24,11 @@ class FluxAPI(ABC):
     def generate_image(self, prompt: str, save_path: Path) -> float:
         """
         Generate an image based on the prompt and save it to the specified path.
-
+
         Args:
             prompt (str): The text prompt to generate the image from
             save_path (Path): The path where the generated image should be saved
-
+
         Returns:
             float: The time taken for the API call in seconds
         """
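
Since `FluxAPI` is the abstract base every backend in this PR implements, a sketch of what a new provider would look like (the `ExampleAPI` class and its name string are hypothetical; the property/timing pattern mirrors the concrete backends elsewhere in this diff):

```python
import time
from pathlib import Path

from api.flux import FluxAPI


class ExampleAPI(FluxAPI):  # hypothetical backend, not part of this PR
    @property
    def name(self) -> str:
        return "example"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        start_time = time.time()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # A real backend would call its provider here and write the PNG bytes.
        save_path.write_bytes(b"")
        return time.time() - start_time
```
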
api/pruna.py
CHANGED
@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
@@ -12,7 +12,9 @@ from api.flux import FluxAPI
 class PrunaAPI(FluxAPI):
     def __init__(self, speed_mode: str):
         self._speed_mode = speed_mode
-        self._speed_mode_name =
+        self._speed_mode_name = (
+            speed_mode.split(" ")[0].strip().lower().replace(" ", "_")
+        )
         load_dotenv()
         self._api_key = os.getenv("REPLICATE_API_TOKEN")
         if not self._api_key:
@@ -38,7 +40,7 @@ class PrunaAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result:
             self._save_image_from_result(result, save_path)
         else:
api/pruna_dev.py
CHANGED
@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
@@ -36,7 +36,7 @@ class PrunaDevAPI(FluxAPI):
             },
         )
         end_time = time.time()
-
+
         if result:
             self._save_image_from_result(result, save_path)
         else:
@@ -46,4 +46,4 @@ class PrunaDevAPI(FluxAPI):
     def _save_image_from_result(self, result: Any, save_path: Path):
         save_path.parent.mkdir(parents=True, exist_ok=True)
         with open(save_path, "wb") as f:
-            f.write(result.read())
+            f.write(result.read())
api/replicate.py
CHANGED
@@ -3,8 +3,8 @@ import time
 from pathlib import Path
 from typing import Any
 
-from dotenv import load_dotenv
 import replicate
+from dotenv import load_dotenv
 
 from api.flux import FluxAPI
 
api/replicate_wan.py
ADDED
@@ -0,0 +1,48 @@
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import replicate
+from dotenv import load_dotenv
+
+from api.flux import FluxAPI
+
+
+class ReplicateAPI(FluxAPI):
+    def __init__(self):
+        load_dotenv()
+        self._api_key = os.getenv("REPLICATE_API_TOKEN")
+        if not self._api_key:
+            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+    @property
+    def name(self) -> str:
+        return "replicate_go_fast"
+
+    def generate_image(self, prompt: str, save_path: Path) -> float:
+        start_time = time.time()
+        result = replicate.run(
+            "black-forest-labs/flux-dev",
+            input={
+                "seed": 0,
+                "prompt": prompt,
+                "go_fast": True,
+                "guidance": 3.5,
+                "num_outputs": 1,
+                "aspect_ratio": "1:1",
+                "output_format": "png",
+                "num_inference_steps": 28,
+            },
+        )
+        end_time = time.time()
+        if result and len(result) > 0:
+            self._save_image_from_result(result[0], save_path)
+        else:
+            raise Exception("No result returned from Replicate API")
+        return end_time - start_time
+
+    def _save_image_from_result(self, result: Any, save_path: Path):
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(save_path, "wb") as f:
+            f.write(result.read())
api/together.py
CHANGED
@@ -33,10 +33,10 @@ class TogetherAPI(FluxAPI):
             response_format="b64_json",
         )
         end_time = time.time()
-        if result and hasattr(result,
+        if result and hasattr(result, "data") and len(result.data) > 0:
             self._save_image_from_result(result, save_path)
         else:
-            raise Exception("No result returned from Together API")
+            raise Exception("No result returned from Together API")
         return end_time - start_time
 
     def _save_image_from_result(self, result: Any, save_path: Path):
benchmark/__init__.py
CHANGED
@@ -7,10 +7,14 @@ from benchmark.hps import HPSPrompts
 from benchmark.parti import PartiPrompts
 
 
-def create_benchmark(
+def create_benchmark(
+    benchmark_type: str,
+) -> Type[
+    DrawBenchPrompts | GenAIBenchPrompts | GenEvalPrompts | HPSPrompts | PartiPrompts
+]:
     """
     Factory function to create benchmark instances.
-
+
     Args:
         benchmark_type (str): The type of benchmark to create. Must be one of:
             - "draw_bench"
@@ -18,10 +22,10 @@
             - "geneval"
             - "hps"
             - "parti"
-
+
     Returns:
         An instance of the requested benchmark implementation
-
+
     Raises:
         ValueError: If an invalid benchmark type is provided
     """
@@ -32,8 +36,10 @@
         "hps": HPSPrompts,
         "parti": PartiPrompts,
     }
-
+
     if benchmark_type not in benchmark_map:
-        raise ValueError(
-
+        raise ValueError(
+            f"Invalid benchmark type: {benchmark_type}. Must be one of {list(benchmark_map.keys())}"
+        )
+
     return benchmark_map[benchmark_type]()
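
A quick sketch of how the benchmark factory and the prompt iterators fit together, based on the interfaces visible in this PR (note that `geneval` yields metadata dicts and folder names rather than plain prompt/path pairs):

```python
from benchmark import create_benchmark

bench = create_benchmark("parti")
print(bench.name, bench.size)

# Most benchmarks iterate as (prompt, relative image path) pairs.
for prompt, image_path in bench:
    print(prompt, image_path)
    break
```
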
benchmark/draw_bench.py
CHANGED
@@ -7,15 +7,15 @@ from datasets import load_dataset
 class DrawBenchPrompts:
     def __init__(self):
         self.dataset = load_dataset("shunk031/DrawBench")["test"]
-
+
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for i, row in enumerate(self.dataset):
             yield row["prompts"], Path(f"{i}.png")
-
+
     @property
     def name(self) -> str:
         return "draw_bench"
-
+
     @property
     def size(self) -> int:
         return len(self.dataset)
benchmark/genai_bench.py
CHANGED
@@ -8,8 +8,8 @@ class GenAIBenchPrompts:
     def __init__(self):
         super().__init__()
         self._download_genai_bench_files()
-        prompts_path = Path(
-        with open(prompts_path,
+        prompts_path = Path("downloads/genai_bench/prompts.txt")
+        with open(prompts_path, "r") as f:
             self.prompts = [line.strip() for line in f if line.strip()]
 
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
@@ -17,13 +17,13 @@ class GenAIBenchPrompts:
             yield prompt, Path(f"{i}.png")
 
     def _download_genai_bench_files(self) -> None:
-        folder_name = Path(
+        folder_name = Path("downloads/genai_bench")
         folder_name.mkdir(parents=True, exist_ok=True)
         prompts_url = "https://huggingface.co/datasets/zhiqiulin/GenAI-Bench-527/raw/main/prompts.txt"
         prompts_path = folder_name / "prompts.txt"
         if not prompts_path.exists():
             response = requests.get(prompts_url)
-            with open(prompts_path,
+            with open(prompts_path, "w") as f:
                 f.write(response.text)
 
     @property
benchmark/geneval.py
CHANGED
@@ -9,32 +9,32 @@ class GenEvalPrompts:
     def __init__(self):
         super().__init__()
         self._download_geneval_file()
-        metadata_path = Path(
+        metadata_path = Path("downloads/geneval/evaluation_metadata.jsonl")
         self.entries: List[Dict[str, Any]] = []
-        with open(metadata_path,
+        with open(metadata_path, "r") as f:
             for line in f:
                 if line.strip():
                     self.entries.append(json.loads(line))
-
+
     def __iter__(self) -> Iterator[Tuple[Dict[str, Any], Path]]:
         for i, entry in enumerate(self.entries):
             folder_name = f"{i:05d}"
             yield entry, folder_name
 
     def _download_geneval_file(self) -> None:
-        folder_name = Path(
+        folder_name = Path("downloads/geneval")
         folder_name.mkdir(parents=True, exist_ok=True)
         metadata_url = "https://raw.githubusercontent.com/djghosh13/geneval/main/prompts/evaluation_metadata.jsonl"
         metadata_path = folder_name / "evaluation_metadata.jsonl"
         if not metadata_path.exists():
             response = requests.get(metadata_url)
-            with open(metadata_path,
+            with open(metadata_path, "w") as f:
                 f.write(response.text)
-
+
     @property
     def name(self) -> str:
         return "geneval"
-
+
     @property
     def size(self) -> int:
         return len(self.entries)
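
Unlike the other benchmarks, GenEval iterates over full metadata entries; `sample.py` only relies on the `"prompt"` key of each entry and on the zero-padded folder name. A minimal sketch:

```python
from benchmark.geneval import GenEvalPrompts

bench = GenEvalPrompts()
for entry, folder_name in bench:  # folder_name is f"{i:05d}"
    print(folder_name, entry["prompt"])  # other metadata keys pass through untouched
    break
```
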
benchmark/hps.py
CHANGED
@@ -9,13 +9,18 @@ import huggingface_hub
 class HPSPrompts:
     def __init__(self):
         super().__init__()
-        self.hps_prompt_files = [
+        self.hps_prompt_files = [
+            "anime.json",
+            "concept-art.json",
+            "paintings.json",
+            "photo.json",
+        ]
         self._download_benchmark_prompts()
         self.prompts: Dict[str, str] = {}
         self._size = 0
         for file in self.hps_prompt_files:
-            category = file.replace(
-            with open(os.path.join(
+            category = file.replace(".json", "")
+            with open(os.path.join("downloads/hps", file), "r") as f:
                 prompts = json.load(f)
                 for i, prompt in enumerate(prompts):
                     if i == 100:
@@ -23,24 +28,26 @@ class HPSPrompts:
                     filename = f"{category}_{i:03d}.png"
                     self.prompts[filename] = prompt
                     self._size += 1
-
+
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for filename, prompt in self.prompts.items():
             yield prompt, Path(filename)
-
+
     @property
     def name(self) -> str:
         return "hps"
-
+
     @property
     def size(self) -> int:
         return self._size
 
     def _download_benchmark_prompts(self) -> None:
-        folder_name = Path(
+        folder_name = Path("downloads/hps")
         folder_name.mkdir(parents=True, exist_ok=True)
         for file in self.hps_prompt_files:
-            file_name = huggingface_hub.hf_hub_download(
+            file_name = huggingface_hub.hf_hub_download(
+                "zhwang/HPDv2", file, subfolder="benchmark", repo_type="dataset"
+            )
             if not os.path.exists(os.path.join(folder_name, file)):
                 os.symlink(file_name, os.path.join(folder_name, file))
 
benchmark/metrics/arniqa.py
CHANGED
@@ -8,20 +8,29 @@ from torchmetrics.image.arniqa import ARNIQA
 
 class ARNIQAMetric:
     def __init__(self):
-        self.device = torch.device(
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.metric = ARNIQA(
             regressor_dataset="koniq10k",
             reduction="mean",
             normalize=True,
-            autocast=False
+            autocast=False,
         )
         self.metric.to(self.device)
+
     @property
     def name(self) -> str:
         return "arniqa"
-
+
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
-        image_tensor =
+        image_tensor = (
+            torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
+        )
         image_tensor = image_tensor.unsqueeze(0).to(self.device)
         score = self.metric(image_tensor)
         return {"arniqa": score.item()}
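
The same cuda/mps/cpu ternary is now repeated in every metric class touched by this PR; a possible follow-up would be to factor it into a shared helper (the `pick_device` function below is a hypothetical refactor, not part of this change):

```python
import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then CPU -- the order used throughout this PR."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```
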
benchmark/metrics/clip.py
CHANGED
@@ -8,14 +8,20 @@ from torchmetrics.multimodal.clip_score import CLIPScore
 
 class CLIPMetric:
     def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14"):
-        self.device = torch.device(
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")
         self.metric.to(self.device)
-
+
     @property
     def name(self) -> str:
         return "clip"
-
+
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
         image_tensor = image_tensor.to(self.device)
benchmark/metrics/clip_iqa.py
CHANGED
@@ -8,18 +8,22 @@ from torchmetrics.multimodal import CLIPImageQualityAssessment
 
 class CLIPIQAMetric:
     def __init__(self):
-        self.device = torch.device(
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.metric = CLIPImageQualityAssessment(
-            model_name_or_path="clip_iqa",
-            data_range=255.0,
-            prompts=("quality",)
+            model_name_or_path="clip_iqa", data_range=255.0, prompts=("quality",)
         )
         self.metric.to(self.device)
 
     @property
     def name(self) -> str:
         return "clip_iqa"
-
+
     def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
         image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
         image_tensor = image_tensor.unsqueeze(0)
benchmark/metrics/hps.py
CHANGED
@@ -1,26 +1,32 @@
 import os
 from typing import Dict
 
+import huggingface_hub
 import torch
-from PIL import Image
 from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
-import
-from
+from hpsv2.utils import hps_version_map, root_path
+from PIL import Image
 
 
 class HPSMetric:
     def __init__(self):
         self.hps_version = "v2.1"
-        self.device =
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
         self.model_dict = {}
         self._initialize_model()
-
+
     def _initialize_model(self):
         if not self.model_dict:
             model, preprocess_train, preprocess_val = create_model_and_transforms(
-
-
-                precision=
+                "ViT-H-14",
+                "laion2B-s32B-b79K",
+                precision="amp",
                 device=self.device,
                 jit=False,
                 force_quick_gelu=False,
@@ -34,44 +40,53 @@ class HPSMetric:
                 aug_cfg={},
                 output_dict=True,
                 with_score_predictor=False,
-                with_region_predictor=False
+                with_region_predictor=False,
             )
-            self.model_dict[
-            self.model_dict[
-
+            self.model_dict["model"] = model
+            self.model_dict["preprocess_val"] = preprocess_val
+
             # Load checkpoint
             if not os.path.exists(root_path):
                 os.makedirs(root_path)
-            cp = huggingface_hub.hf_hub_download(
-
+            cp = huggingface_hub.hf_hub_download(
+                "xswu/HPSv2", hps_version_map[self.hps_version]
+            )
+
             checkpoint = torch.load(cp, map_location=self.device)
-            model.load_state_dict(checkpoint[
-            self.tokenizer = get_tokenizer(
+            model.load_state_dict(checkpoint["state_dict"])
+            self.tokenizer = get_tokenizer("ViT-H-14")
             model = model.to(self.device)
             model.eval()
-
+
     @property
     def name(self) -> str:
         return "hps"
-
+
     def compute_score(
         self,
         image: Image.Image,
         prompt: str,
     ) -> Dict[str, float]:
-        model = self.model_dict[
-        preprocess_val = self.model_dict[
-
+        model = self.model_dict["model"]
+        preprocess_val = self.model_dict["preprocess_val"]
+
         with torch.no_grad():
             # Process the image
-            image_tensor =
+            image_tensor = (
+                preprocess_val(image)
+                .unsqueeze(0)
+                .to(device=self.device, non_blocking=True)
+            )
             # Process the prompt
             text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
             # Calculate the HPS
             with torch.cuda.amp.autocast():
                 outputs = model(image_tensor, text)
-                image_features, text_features =
+                image_features, text_features = (
+                    outputs["image_features"],
+                    outputs["text_features"],
+                )
                 logits_per_image = image_features @ text_features.T
                 hps_score = torch.diagonal(logits_per_image).cpu().numpy()
-
+
         return {"hps": float(hps_score[0])}
benchmark/metrics/image_reward.py
CHANGED
@@ -3,17 +3,26 @@ import tempfile
 from typing import Dict
 
 import ImageReward as RM
+import torch
 from PIL import Image
 
 
 class ImageRewardMetric:
     def __init__(self):
-        self.
-
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
+
+        self.model = RM.load("ImageReward-v1.0", device=str(self.device))
+
     @property
     def name(self) -> str:
         return "image_reward"
-
+
     def compute_score(
         self,
         image: Image.Image,
benchmark/metrics/vqa.py
CHANGED
@@ -2,15 +2,26 @@ from pathlib import Path
 from typing import Dict
 
 import t2v_metrics
+import torch
+
 
 class VQAMetric:
     def __init__(self):
-        self.
-
+        self.device = torch.device(
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
+        self.metric = t2v_metrics.VQAScore(
+            model="clip-flant5-xxl", device=str(self.device)
+        )
+
     @property
     def name(self) -> str:
         return "vqa_score"
-
+
     def compute_score(
         self,
         image_path: Path,
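
Taken together, the metric classes share an implicit interface; a sketch of it as a `typing.Protocol` (this declaration does not exist in the repo, which relies on duck typing -- and note that `VQAMetric` deviates by taking the image *path*, which is why `evaluate.py` special-cases `metric_type == "vqa"`):

```python
from typing import Dict, Protocol

from PIL import Image


class MetricProtocol(Protocol):  # hypothetical; for documentation only
    @property
    def name(self) -> str: ...

    def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]: ...
```
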
benchmark/parti.py
CHANGED
@@ -14,11 +14,11 @@ class PartiPrompts:
     def __iter__(self) -> Iterator[Tuple[str, Path]]:
         for i, prompt in enumerate(self.prompts):
             yield prompt, Path(f"{i}.png")
-
+
     @property
     def name(self) -> str:
         return "parti"
-
+
     @property
     def size(self) -> int:
         return len(self.prompts)
environment.yml
CHANGED
@@ -12,7 +12,7 @@ dependencies:
   - tqdm
   - pip
   - pip:
-      - datasets
+      - datasets==3.6.0
       - fal-client>=0.5.9
       - hpsv2>=1.2.0
       - huggingface-hub>=0.30.2
evaluate.py
CHANGED
@@ -1,59 +1,65 @@
 import argparse
 import json
+import warnings
 from pathlib import Path
 from typing import Dict
-import warnings
 
-from benchmark import create_benchmark
-from benchmark.metrics import create_metric
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+from benchmark import create_benchmark
+from benchmark.metrics import create_metric
 
 warnings.filterwarnings("ignore", category=FutureWarning)
 
 
-def evaluate_benchmark(
+def evaluate_benchmark(
+    benchmark_type: str, api_type: str, images_dir: Path = Path("images")
+) -> Dict:
     """
     Evaluate a benchmark's images using its specific metrics.
-
+
     Args:
         benchmark_type (str): Type of benchmark to evaluate
         api_type (str): Type of API used to generate images
         images_dir (Path): Base directory containing generated images
-
+
     Returns:
         Dict containing evaluation results
     """
     benchmark = create_benchmark(benchmark_type)
-
+
     benchmark_dir = images_dir / api_type / benchmark_type
     metadata_file = benchmark_dir / "metadata.jsonl"
-
+
     if not metadata_file.exists():
-        raise FileNotFoundError(
-
+        raise FileNotFoundError(
+            f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first."
+        )
+
     metadata = []
     with open(metadata_file, "r") as f:
         for line in f:
             metadata.append(json.loads(line))
-
-    metrics = {
-
+
+    metrics = {
+        metric_type: create_metric(metric_type) for metric_type in benchmark.metrics
+    }
+
     results = {
         "api": api_type,
         "benchmark": benchmark_type,
         "metrics": {metric: 0.0 for metric in benchmark.metrics},
-        "total_images": len(metadata)
+        "total_images": len(metadata),
     }
     inference_times = []
-
+
     for entry in tqdm(metadata):
         image_path = benchmark_dir / entry["filepath"]
         if not image_path.exists():
             continue
-
+
         for metric_type, metric in metrics.items():
             try:
                 if metric_type == "vqa":
@@ -64,26 +70,30 @@
                 results["metrics"][metric_type] += score[metric_type]
             except Exception as e:
                 print(f"Error computing {metric_type} for {image_path}: {str(e)}")
-
+
         inference_times.append(entry["inference_time"])
-
+
     for metric in results["metrics"]:
         results["metrics"][metric] /= len(metadata)
     results["median_inference_time"] = np.median(inference_times).item()
-
+
     return results
 
 
 def main():
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
+        description="Evaluate generated images using benchmark-specific metrics"
+    )
     parser.add_argument("api_type", help="Type of API to evaluate")
-    parser.add_argument(
-
+    parser.add_argument(
+        "benchmarks", nargs="+", help="List of benchmark types to evaluate"
+    )
+
     args = parser.parse_args()
-
+
     results_dir = Path("evaluation_results")
     results_dir.mkdir(exist_ok=True)
-
+
     results_file = results_dir / f"{args.api_type}.jsonl"
     existing_results = set()
 
@@ -97,15 +107,15 @@ def main():
         if benchmark_type in existing_results:
             print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
             continue
-
+
         try:
             print(f"Evaluating {args.api_type}/{benchmark_type}")
            results = evaluate_benchmark(benchmark_type, args.api_type)
-
+
            # Append results to file
            with open(results_file, "a") as f:
                f.write(json.dumps(results) + "\n")
-
+
        except Exception as e:
            print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
 
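
Each benchmark evaluated for an API appends one JSON object to `evaluation_results/<api_type>.jsonl`; a sketch of the record shape, with the keys taken from the `results` dict above and purely illustrative values:

```python
# Illustrative only -- the metric names depend on benchmark.metrics,
# and the numbers here are made up.
example_record = {
    "api": "baseline",
    "benchmark": "draw_bench",
    "metrics": {"clip": 0.0, "clip_iqa": 0.0},  # running sums divided by len(metadata)
    "total_images": 200,
    "median_inference_time": 4.2,               # seconds
}
```
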
nils_installs.txt
DELETED
@@ -1,177 +0,0 @@
-accelerate==1.7.0
-aiohappyeyeballs==2.6.1
-aiohttp==3.12.12
-aiosignal==1.3.2
-annotated-types==0.7.0
-antlr4-python3-runtime==4.9.3
-anyio==4.9.0
-args==0.1.0
-asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
-attrs==25.3.0
-beautifulsoup4==4.13.4
-boto3==1.38.33
-botocore==1.38.33
-braceexpand==0.1.7
-Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
-certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1746569525376/work/certifi
-cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725560558132/work
-charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1746214863626/work
-click==8.1.8
-clint==0.5.1
-clip==0.2.0
-colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
-comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
-contourpy==1.3.2
-cycler==0.12.1
-datasets==3.6.0
-debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321241074/work
-decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
-diffusers==0.31.0
-dill==0.3.8
-distro==1.9.0
-einops==0.8.1
-eval_type_backport==0.2.2
-exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
-executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
-fairscale==0.4.13
-fal_client==0.7.0
-filelock==3.18.0
-fire==0.4.0
-fonttools==4.58.2
-frozenlist==1.7.0
-fsspec==2025.3.0
-ftfy==6.3.1
-gdown==5.2.0
-h11==0.16.0
-h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
-hf-xet==1.1.3
-hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
-hpsv2==1.2.0
-httpcore==1.0.9
-httpx==0.28.1
-httpx-sse==0.4.0
-huggingface-hub==0.32.5
-hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
-idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
-image-reward==1.5
-importlib_metadata==8.7.0
-iniconfig==2.1.0
-iopath==0.1.10
-ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
-ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748713870/work
-ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
-jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
-Jinja2==3.1.6
-jiter==0.10.0
-jmespath==1.0.1
-joblib==1.5.1
-jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
-jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
-kiwisolver==1.4.8
-lightning-utilities==0.14.3
-markdown-it-py==3.0.0
-MarkupSafe==3.0.2
-matplotlib==3.10.3
-matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
-mdurl==0.1.2
-mpmath==1.3.0
-multidict==6.4.4
-multiprocess==0.70.16
-nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
-networkx==3.5
-numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1749430504934/work/dist/numpy-2.3.0-cp312-cp312-linux_x86_64.whl#sha256=3c4437a0cbe50dbae872ad4cd8dc5316009165bce459c4ffe2c46cd30aba13d4
-nvidia-cublas-cu12==12.6.4.1
-nvidia-cuda-cupti-cu12==12.6.80
-nvidia-cuda-nvrtc-cu12==12.6.77
-nvidia-cuda-runtime-cu12==12.6.77
-nvidia-cudnn-cu12==9.5.1.17
-nvidia-cufft-cu12==11.3.0.4
-nvidia-cufile-cu12==1.11.1.6
-nvidia-curand-cu12==10.3.7.77
-nvidia-cusolver-cu12==11.7.1.2
-nvidia-cusparse-cu12==12.5.4.2
-nvidia-cusparselt-cu12==0.6.3
-nvidia-nccl-cu12==2.26.2
-nvidia-nvjitlink-cu12==12.6.85
-nvidia-nvtx-cu12==12.6.77
-omegaconf==2.3.0
-open_clip_torch==2.32.0
-openai==1.85.0
-opencv-python==4.11.0
-opencv-python-headless==4.11.0
-packaging==25.0
-pandas==2.3.0
-parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
-pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
-pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
-pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1746646208260/work
-piq==0.8.0
-platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
-pluggy==1.6.0
-portalocker==3.1.1
-prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
-propcache==0.3.2
-protobuf==3.20.3
-psutil==7.0.0
-ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
-pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
-pyarrow==20.0.0
-pycocoevalcap==1.2
-pycocotools==2.0.10
-pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
-pydantic==2.11.5
-pydantic_core==2.33.2
-Pygments==2.19.1
-pyparsing==3.2.3
-PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
-pytest==7.2.0
-pytest-split==0.8.0
-python-dateutil==2.9.0.post0
-python-dotenv @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dotenv_1742948348/work
-pytz==2025.2
-PyYAML==6.0.2
-pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245863/work
-regex==2024.11.6
-replicate==1.0.7
-requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1749498106507/work
-rich==14.0.0
-s3transfer==0.13.0
-safetensors==0.5.3
-scikit-learn==1.7.0
-scipy==1.15.3
-sentencepiece==0.2.0
-setuptools==80.9.0
-shellingham==1.5.4
-six==1.17.0
-sniffio==1.3.1
-soupsieve==2.7
-stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
-sympy==1.14.0
-t2v_metrics==1.2
-tabulate==0.9.0
-termcolor==3.1.0
-threadpoolctl==3.6.0
-tiktoken==0.9.0
-timm==0.6.13
-together==1.5.11
-tokenizers==0.15.2
-torch==2.7.1
-torchmetrics==1.7.2
-torchvision==0.22.1
-tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003300911/work
-tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
-traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
-transformers==4.36.1
-triton==3.3.1
-typer==0.15.4
-typing-inspection==0.4.1
-typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1748959427/work
-tzdata==2025.2
-urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1744323578849/work
-wcwidth==0.2.13
-webdataset==0.2.111
-wheel==0.45.1
-xxhash==3.5.0
-yarl==1.20.1
-zipp==3.23.0
-zstandard==0.23.0
sample.py
CHANGED
@@ -12,19 +12,19 @@ from benchmark import create_benchmark
 def generate_images(api_type: str, benchmarks: List[str]):
     images_dir = Path("images")
     api = create_api(api_type)
-
+
     api_dir = images_dir / api_type
     api_dir.mkdir(parents=True, exist_ok=True)
-
+
     for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
         print(f"\nProcessing benchmark: {benchmark_type}")
-
+
         benchmark = create_benchmark(benchmark_type)
-
+
         if benchmark_type == "geneval":
             benchmark_dir = api_dir / benchmark_type
             benchmark_dir.mkdir(parents=True, exist_ok=True)
-
+
             metadata_file = benchmark_dir / "metadata.jsonl"
             existing_metadata = {}
             if metadata_file.exists():
@@ -32,35 +32,44 @@ def generate_images(api_type: str, benchmarks: List[str]):
                 for line in f:
                     entry = json.loads(line)
                     existing_metadata[entry["filepath"]] = entry
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            with open(metadata_file, "a") as f:
+                for metadata, folder_name in tqdm(
+                    benchmark,
+                    desc=f"Generating images for {benchmark_type}",
+                    leave=False,
+                ):
+                    sample_path = benchmark_dir / folder_name
+                    samples_path = sample_path / "samples"
+                    samples_path.mkdir(parents=True, exist_ok=True)
+                    image_path = samples_path / "0000.png"
+
+                    if image_path.exists():
+                        continue
+
+                    try:
+                        inference_time = api.generate_image(
+                            metadata["prompt"], image_path
+                        )
+
+                        metadata_entry = {
+                            "filepath": str(image_path),
+                            "prompt": metadata["prompt"],
+                            "inference_time": inference_time,
+                        }
+
+                        f.write(json.dumps(metadata_entry) + "\n")
+
+                    except Exception as e:
+                        print(
+                            f"\nError generating image for prompt: {metadata['prompt']}"
+                        )
+                        print(f"Error: {str(e)}")
+                        continue
         else:
             benchmark_dir = api_dir / benchmark_type
             benchmark_dir.mkdir(parents=True, exist_ok=True)
+
             metadata_file = benchmark_dir / "metadata.jsonl"
             existing_metadata = {}
             if metadata_file.exists():
@@ -68,41 +77,47 @@ def generate_images(api_type: str, benchmarks: List[str]):
                 for line in f:
                     entry = json.loads(line)
                     existing_metadata[entry["filepath"]] = entry
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            with open(metadata_file, "a") as f:
+                for prompt, image_path in tqdm(
+                    benchmark,
+                    desc=f"Generating images for {benchmark_type}",
+                    leave=False,
+                ):
+                    if image_path in existing_metadata:
+                        continue
+
+                    full_image_path = benchmark_dir / image_path
+
+                    if full_image_path.exists():
+                        continue
+
+                    try:
+                        inference_time = api.generate_image(prompt, full_image_path)
+
+                        metadata_entry = {
+                            "filepath": str(image_path),
+                            "prompt": prompt,
+                            "inference_time": inference_time,
+                        }
+
+                        f.write(json.dumps(metadata_entry) + "\n")
+
+                    except Exception as e:
+                        print(f"\nError generating image for prompt: {prompt}")
+                        print(f"Error: {str(e)}")
+                        continue
 
 
 def main():
-    parser = argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
+        description="Generate images for specified benchmarks using a given API"
+    )
     parser.add_argument("api_type", help="Type of API to use for image generation")
     parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
-
+
     args = parser.parse_args()
-
+
     generate_images(args.api_type, args.benchmarks)
 
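
For reference, each line that `sample.py` appends to `metadata.jsonl` has the shape built in the loops above (the prompt and timing values below are illustrative):

```python
import json

# One metadata.jsonl line per generated image; "filepath" is relative to the
# benchmark directory, and geneval images land under <folder>/samples/0000.png.
entry = {
    "filepath": "0.png",
    "prompt": "a red cube on a blue sphere",  # illustrative prompt
    "inference_time": 3.14,                   # seconds, as returned by the API wrapper
}
print(json.dumps(entry))
```
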