nifleisch committed on
Commit 4f41410 · 1 Parent(s): 2c50826

fix: fix several errors

.env.example ADDED
@@ -0,0 +1,4 @@
+ FIREWORKS_API_TOKEN=your_fireworks_api_token_here
+ REPLICATE_API_TOKEN=your_replicate_api_token_here
+ FAL_KEY=your_fal_key_here
+ TOGETHER_API_KEY=your_together_api_key_here
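For context, the API wrappers load these variables at runtime with python-dotenv, just as the new api/pruna_dev.py below does for REPLICATE_API_TOKEN. A minimal sketch of that pattern (the check_required_keys helper is illustrative and not part of the repo):

    import os

    from dotenv import load_dotenv

    # Illustrative helper: fail fast if any key listed in .env.example is missing.
    def check_required_keys() -> None:
        load_dotenv()  # loads variables from a local .env file into the environment
        for key in ("FIREWORKS_API_TOKEN", "REPLICATE_API_TOKEN", "FAL_KEY", "TOGETHER_API_KEY"):
            if not os.getenv(key):
                raise ValueError(f"{key} not found in environment variables")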
.gitignore CHANGED
@@ -172,3 +172,7 @@ cython_debug/
 
  # PyPI configuration file
  .pypirc
+
+ evaluation_results/
+ images/
+ hf_cache/
README.md CHANGED
@@ -1,2 +1,10 @@
  # InferBench-
  Evaluate the quality and efficiency of image gen api's.
+
+ Install dependencies with conda as follows:
+
+ conda env create -f environment.yml
+
+
+ Create a .env file with all the API keys you will need (see .env.example).
+ python sample.py replicate draw_bench genai_bench geneval hps parti
api/__init__.py CHANGED
@@ -1,13 +1,26 @@
- from typing import Optional, Type
+ from typing import Type
 
  from api.baseline import BaselineAPI
  from api.fireworks import FireworksAPI
  from api.flux import FluxAPI
  from api.pruna import PrunaAPI
+ from api.pruna_dev import PrunaDevAPI
  from api.replicate import ReplicateAPI
  from api.together import TogetherAPI
  from api.fal import FalAPI
 
+ __all__ = [
+     'create_api',
+     'FluxAPI',
+     'BaselineAPI',
+     'FireworksAPI',
+     'PrunaAPI',
+     'ReplicateAPI',
+     'TogetherAPI',
+     'FalAPI',
+     'PrunaDevAPI',
+ ]
+
  def create_api(api_type: str) -> FluxAPI:
      """
      Factory function to create API instances.
@@ -27,6 +40,8 @@ def create_api(api_type: str) -> FluxAPI:
      Raises:
          ValueError: If an invalid API type is provided
      """
+     if api_type == "pruna_dev":
+         return PrunaDevAPI()
      if api_type.startswith("pruna_"):
          speed_mode = api_type[6:]  # Remove "pruna_" prefix
          return PrunaAPI(speed_mode)
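A usage sketch (not part of this commit): the factory now recognizes the new backend before the generic "pruna_" prefix handling. The prompt and save path below are illustrative; generate_image's signature follows api/pruna_dev.py further down.

    from pathlib import Path

    from api import create_api

    # "pruna_dev" resolves to PrunaDevAPI; other "pruna_<speed_mode>" strings
    # still go through PrunaAPI(speed_mode) as before.
    api = create_api("pruna_dev")
    inference_time = api.generate_image(
        "a photo of an astronaut riding a horse",   # illustrative prompt
        Path("images/pruna_dev/example.png"),       # illustrative save path
    )
    print(f"{api.name}: {inference_time:.2f}s")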
api/fal.py CHANGED
@@ -7,7 +7,7 @@ import fal_client
  import requests
  from PIL import Image
 
- from flux import FluxAPI
+ from api.flux import FluxAPI
 
 
  class FalAPI(FluxAPI):
api/pruna_dev.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ from dotenv import load_dotenv
+ import replicate
+
+ from api.flux import FluxAPI
+
+
+ class PrunaDevAPI(FluxAPI):
+     def __init__(self):
+         load_dotenv()
+         self._api_key = os.getenv("REPLICATE_API_TOKEN")
+         if not self._api_key:
+             raise ValueError("REPLICATE_API_TOKEN not found in environment variables")
+
+     @property
+     def name(self) -> str:
+         return "pruna_dev"
+
+     def generate_image(self, prompt: str, save_path: Path) -> float:
+         start_time = time.time()
+         result = replicate.run(
+             "prunaai/flux.1-dev:938a4eb31a87d65fb7b23fc300fb5b7ab88a36844bb26e54e1d1dec7acf4eefe",
+             input={
+                 "seed": 0,
+                 "prompt": prompt,
+                 "guidance": 3.5,
+                 "num_outputs": 1,
+                 "aspect_ratio": "1:1",
+                 "output_format": "png",
+                 "speed_mode": "Juiced 🔥 (default)",
+                 "num_inference_steps": 28,
+             },
+         )
+         end_time = time.time()
+
+         if result:
+             self._save_image_from_result(result, save_path)
+         else:
+             raise Exception("No result returned from Replicate API")
+         return end_time - start_time
+
+     def _save_image_from_result(self, result: Any, save_path: Path):
+         save_path.parent.mkdir(parents=True, exist_ok=True)
+         with open(save_path, "wb") as f:
+             f.write(result.read())
benchmark/hps.py CHANGED
@@ -15,7 +15,7 @@ class HPSPrompts:
          self._size = 0
          for file in self.hps_prompt_files:
              category = file.replace('.json', '')
-             with open(os.path.join('datasets/hps', file), 'r') as f:
+             with open(os.path.join('downloads/hps', file), 'r') as f:
                  prompts = json.load(f)
              for i, prompt in enumerate(prompts):
                  if i == 100:
@@ -26,7 +26,7 @@ class HPSPrompts:
 
      def __iter__(self) -> Iterator[Tuple[str, Path]]:
          for filename, prompt in self.prompts.items():
-             yield prompt, filename
+             yield prompt, Path(filename)
 
      @property
      def name(self) -> str:
benchmark/metrics/__init__.py CHANGED
@@ -6,7 +6,7 @@ from benchmark.metrics.clip_iqa import CLIPIQAMetric
  from benchmark.metrics.image_reward import ImageRewardMetric
  from benchmark.metrics.sharpness import SharpnessMetric
  from benchmark.metrics.vqa import VQAMetric
-
+ from benchmark.metrics.hps import HPSMetric
 
  def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
      """
@@ -20,7 +20,7 @@ def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
          - "image_reward"
          - "sharpness"
          - "vqa"
-
+         - "hps"
      Returns:
          An instance of the requested metric implementation
 
@@ -34,6 +34,7 @@ def create_metric(metric_type: str) -> Type[ARNIQAMetric | CLIPMetric | CLIPIQAMetric | ImageRewardMetric | SharpnessMetric | VQAMetric]:
          "image_reward": ImageRewardMetric,
          "sharpness": SharpnessMetric,
          "vqa": VQAMetric,
+         "hps": HPSMetric,
      }
 
      if metric_type not in metric_map:
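A usage sketch (not part of this commit), assuming HPSMetric follows the same compute_score(image, prompt) interface as the other metrics registered here; the image path and prompt are illustrative:

    from PIL import Image

    from benchmark.metrics import create_metric

    metric = create_metric("hps")                 # now maps to HPSMetric
    image = Image.open("images/example.png")      # illustrative path
    scores = metric.compute_score(image, "a red cube next to a blue sphere")
    print(scores)  # expected to contain an "hps" key, by analogy with the other metrics (assumed)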
benchmark/metrics/arniqa.py CHANGED
@@ -8,19 +8,20 @@ from torchmetrics.image.arniqa import ARNIQA
 
  class ARNIQAMetric:
      def __init__(self):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          self.metric = ARNIQA(
              regressor_dataset="koniq10k",
              reduction="mean",
              normalize=True,
              autocast=False
          )
-
+         self.metric.to(self.device)
      @property
      def name(self) -> str:
          return "arniqa"
 
      def compute_score(self, image: Image.Image, prompt: str) -> Dict[str, float]:
          image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
-         image_tensor = image_tensor.unsqueeze(0)
+         image_tensor = image_tensor.unsqueeze(0).to(self.device)
          score = self.metric(image_tensor)
          return {"arniqa": score.item()}
benchmark/metrics/clip_iqa.py CHANGED
@@ -12,7 +12,7 @@ class CLIPIQAMetric:
          self.metric = CLIPImageQualityAssessment(
              model_name_or_path="clip_iqa",
              data_range=255.0,
-             prompts=["quality"]
+             prompts=("quality",)
          )
          self.metric.to(self.device)
 
benchmark/metrics/vqa.py CHANGED
@@ -1,9 +1,7 @@
- import os
- import tempfile
+ from pathlib import Path
  from typing import Dict
 
  import t2v_metrics
- from PIL import Image
 
  class VQAMetric:
      def __init__(self):
@@ -15,11 +13,8 @@ class VQAMetric:
 
      def compute_score(
          self,
-         image: Image.Image,
+         image_path: Path,
          prompt: str,
      ) -> Dict[str, float]:
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-             image.save(tmp.name)
-             score = self.metric(images=[tmp.name], texts=[prompt])
-             os.unlink(tmp.name)
-         return {"vqa_score": score[0][0].item()}
+         score = self.metric(images=[str(image_path)], texts=[prompt])
+         return {"vqa": score[0][0].item()}
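A usage sketch (not part of this commit) of the updated interface, which now takes a file path instead of a PIL image; the path and prompt are illustrative:

    from pathlib import Path

    from benchmark.metrics.vqa import VQAMetric

    metric = VQAMetric()
    scores = metric.compute_score(Path("images/example.png"), "a dog riding a skateboard")
    print(scores)  # {"vqa": ...}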
 
 
 
environment.yml ADDED
@@ -0,0 +1,27 @@
+ name: inferbench
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.12
+   - numpy
+   - opencv
+   - pillow
+   - python-dotenv
+   - requests
+   - tqdm
+   - pip
+   - pip:
+     - datasets>=3.5.0
+     - fal-client>=0.5.9
+     - hpsv2>=1.2.0
+     - huggingface-hub>=0.30.2
+     - image-reward>=1.5
+     - replicate>=1.0.4
+     - t2v-metrics>=1.2
+     - together>=1.5.5
+     - torch>=2.7.0
+     - torchmetrics>=1.7.1
+     - clip
+     - diffusers<=0.31
+     - piq>=0.8.0
evaluate.py CHANGED
@@ -2,10 +2,16 @@ import argparse
  import json
  from pathlib import Path
  from typing import Dict
+ import warnings
 
  from benchmark import create_benchmark
  from benchmark.metrics import create_metric
+ import numpy as np
  from PIL import Image
+ from tqdm import tqdm
+
+
+ warnings.filterwarnings("ignore", category=FutureWarning)
 
 
  def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
@@ -39,29 +45,31 @@ def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
          "api": api_type,
          "benchmark": benchmark_type,
          "metrics": {metric: 0.0 for metric in benchmark.metrics},
-         "avg_inference_time": 0.0,
          "total_images": len(metadata)
      }
+     inference_times = []
 
-     for entry in metadata:
+     for entry in tqdm(metadata):
          image_path = benchmark_dir / entry["filepath"]
          if not image_path.exists():
              continue
 
-         image = Image.open(image_path)
-
          for metric_type, metric in metrics.items():
              try:
-                 score = metric.compute_score(image, entry["prompt"])
+                 if metric_type == "vqa":
+                     score = metric.compute_score(image_path, entry["prompt"])
+                 else:
+                     image = Image.open(image_path)
+                     score = metric.compute_score(image, entry["prompt"])
                  results["metrics"][metric_type] += score[metric_type]
              except Exception as e:
                  print(f"Error computing {metric_type} for {image_path}: {str(e)}")
 
-         results["avg_inference_time"] += entry["inference_time"]
+         inference_times.append(entry["inference_time"])
 
      for metric in results["metrics"]:
          results["metrics"][metric] /= len(metadata)
-     results["avg_inference_time"] /= len(metadata)
+     results["median_inference_time"] = np.median(inference_times).item()
 
      return results
 
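To illustrate the new aggregation (a sketch, not part of this commit): per-metric scores are still averaged over all metadata entries, while the reported inference time is now the median of the per-image times, which is less sensitive to a few slow calls than the old running mean. The numbers and the "clip" metric name below are made up:

    import numpy as np

    # Illustrative per-image records, shaped like the metadata entries above.
    entries = [
        {"inference_time": 1.9, "clip": 0.31},
        {"inference_time": 2.4, "clip": 0.28},
        {"inference_time": 8.7, "clip": 0.30},  # one slow outlier
    ]

    clip_avg = sum(e["clip"] for e in entries) / len(entries)
    median_time = np.median([e["inference_time"] for e in entries]).item()

    # Median is 2.4s; a mean would have reported ~4.3s because of the outlier.
    print({"metrics": {"clip": round(clip_avg, 3)}, "median_inference_time": median_time})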
pyproject.toml DELETED
@@ -1,22 +0,0 @@
- [project]
- name = "inferbench"
- version = "0.1.0"
- requires-python = ">=3.12"
- dependencies = [
-     "datasets>=3.5.0",
-     "fal-client>=0.5.9",
-     "hpsv2>=1.2.0",
-     "huggingface-hub>=0.30.2",
-     "image-reward>=1.5",
-     "numpy>=2.2.5",
-     "opencv-python>=4.11.0.86",
-     "pillow>=11.2.1",
-     "python-dotenv>=1.1.0",
-     "replicate>=1.0.4",
-     "requests>=2.32.3",
-     "t2v-metrics>=1.2",
-     "together>=1.5.5",
-     "torch>=2.7.0",
-     "torchmetrics>=1.7.1",
-     "tqdm>=4.67.1",
- ]
uv.lock DELETED
The diff for this file is too large to render. See raw diff