# InferBench / evaluate.py
# nifleisch, commit 4f41410: "fix: fix several errors"
# (file-viewer residue: raw | history | blame | 3.85 kB)
import argparse
import json
from pathlib import Path
from typing import Dict
import warnings
from benchmark import create_benchmark
from benchmark.metrics import create_metric
import numpy as np
from PIL import Image
from tqdm import tqdm
# Suppress every FutureWarning for the whole process. Presumably this keeps
# deprecation noise from the metric libraries out of the tqdm progress output.
# NOTE(review): confirm this blanket filter is still wanted.
warnings.filterwarnings("ignore", category=FutureWarning)
def evaluate_benchmark(benchmark_type: str, api_type: str, images_dir: Path = Path("images")) -> Dict:
    """
    Evaluate a benchmark's generated images using the benchmark's own metrics.

    Reads ``<images_dir>/<api_type>/<benchmark_type>/metadata.jsonl``, which is
    expected to be produced by sample.py. The file holds one JSON object per
    line, and each object has at least "filepath", "prompt" and
    "inference_time". Every listed image that exists on disk is scored with
    each metric in ``benchmark.metrics``.

    Each metric is averaged over the images it was actually computed for.
    Images missing from disk and per-image metric errors are excluded from
    that metric's denominator instead of silently counting as a score of 0.
    A metric that could not be computed for any image is reported as ``None``.

    Args:
        benchmark_type (str): Type of benchmark to evaluate
        api_type (str): Type of API used to generate images
        images_dir (Path): Base directory containing generated images

    Returns:
        Dict with the following keys:
        - "api"
        - "benchmark"
        - "metrics": metric name -> mean score, or None
        - "total_images": number of metadata entries
        - "evaluated_images": entries whose image file exists
        - "median_inference_time": median of those entries' "inference_time"

    Raises:
        FileNotFoundError: If the metadata file does not exist.
        ValueError: If none of the images listed in the metadata exist.
            Without this, the result would contain a NaN median, which is
            not valid JSON. It would also be recorded as "already evaluated".
    """
    benchmark = create_benchmark(benchmark_type)
    benchmark_dir = images_dir / api_type / benchmark_type
    metadata_file = benchmark_dir / "metadata.jsonl"

    if not metadata_file.exists():
        raise FileNotFoundError(f"No metadata file found for {api_type}/{benchmark_type}. Please run sample.py first.")

    with open(metadata_file, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
        metadata = [json.loads(line) for line in f if line.strip()]

    metrics = {metric_type: create_metric(metric_type) for metric_type in benchmark.metrics}
    # Track sums and success counts per metric, so failed computations don't
    # drag the mean toward zero.
    score_sums = {metric_type: 0.0 for metric_type in benchmark.metrics}
    score_counts = {metric_type: 0 for metric_type in benchmark.metrics}
    inference_times = []

    for entry in tqdm(metadata):
        image_path = benchmark_dir / entry["filepath"]
        if not image_path.exists():
            continue
        for metric_type, metric in metrics.items():
            try:
                if metric_type == "vqa":
                    # The VQA metric receives the file path rather than a loaded image.
                    score = metric.compute_score(image_path, entry["prompt"])
                else:
                    # The context manager closes the file handle Image.open keeps.
                    with Image.open(image_path) as image:
                        score = metric.compute_score(image, entry["prompt"])
                # compute_score is assumed to return a mapping keyed by metric name.
                score_sums[metric_type] += score[metric_type]
                score_counts[metric_type] += 1
            except Exception as e:
                print(f"Error computing {metric_type} for {image_path}: {str(e)}")
        inference_times.append(entry["inference_time"])

    if not inference_times:
        raise ValueError(f"No images found for {api_type}/{benchmark_type} in {benchmark_dir}.")

    mean_scores = {
        metric_type: (score_sums[metric_type] / score_counts[metric_type] if score_counts[metric_type] else None)
        for metric_type in benchmark.metrics
    }

    return {
        "api": api_type,
        "benchmark": benchmark_type,
        "metrics": mean_scores,
        "total_images": len(metadata),
        "evaluated_images": len(inference_times),
        "median_inference_time": np.median(inference_times).item(),
    }
def main():
    """
    Command-line entry point: evaluate one API's images on one or more benchmarks.

    Results are appended as JSON lines to ``evaluation_results/<api_type>.jsonl``.
    A benchmark is skipped if it already has a line in that file or was
    evaluated earlier in this run. An error for a single benchmark is
    printed and does not stop the remaining benchmarks.
    """
    parser = argparse.ArgumentParser(description="Evaluate generated images using benchmark-specific metrics")
    parser.add_argument("api_type", help="Type of API to evaluate")
    parser.add_argument("benchmarks", nargs="+", help="List of benchmark types to evaluate")
    args = parser.parse_args()

    results_dir = Path("evaluation_results")
    results_dir.mkdir(exist_ok=True)
    results_file = results_dir / f"{args.api_type}.jsonl"

    existing_results = set()
    if results_file.exists():
        with open(results_file, "r", encoding="utf-8") as f:
            # Skip blank lines, which json.loads would reject.
            existing_results = {json.loads(line)["benchmark"] for line in f if line.strip()}

    for benchmark_type in args.benchmarks:
        if benchmark_type in existing_results:
            print(f"Skipping {args.api_type}/{benchmark_type} - already evaluated")
            continue
        try:
            print(f"Evaluating {args.api_type}/{benchmark_type}")
            results = evaluate_benchmark(benchmark_type, args.api_type)
            # Append results to file
            with open(results_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(results) + "\n")
            # Record it, so a benchmark repeated on the command line is not
            # evaluated (and appended) a second time.
            existing_results.add(benchmark_type)
        except Exception as e:
            # Best-effort: report and continue with the next benchmark.
            print(f"Error evaluating {args.api_type}/{benchmark_type}: {str(e)}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()