Spaces:
Running
Running
import argparse | |
import json | |
from pathlib import Path | |
from typing import List | |
from tqdm import tqdm | |
from api import create_api | |
from benchmark import create_benchmark | |
def _load_metadata(metadata_file: Path) -> dict:
    """Load an existing ``metadata.jsonl`` into a dict keyed by each record's ``"filepath"``.

    Returns an empty dict if the file does not exist. Blank lines are ignored;
    when several lines share a ``"filepath"``, the last one wins.
    """
    entries = {}
    if metadata_file.exists():
        with open(metadata_file, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                entry = json.loads(line)
                entries[entry["filepath"]] = entry
    return entries


def _write_metadata(metadata_file: Path, entries: dict) -> None:
    """Rewrite ``metadata_file`` with one JSON record per line.

    The records are first written to a sibling ``.tmp`` file, which then
    replaces the target. An interruption mid-write therefore cannot leave a
    truncated metadata file behind.
    """
    tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        for entry in entries.values():
            f.write(json.dumps(entry) + "\n")
    tmp_file.replace(metadata_file)


def _generate_one(api, prompt, image_path: Path, record_key: str, metadata: dict) -> None:
    """Generate one image at ``image_path`` and record it in ``metadata`` under ``record_key``.

    Failures are printed and skipped, so one bad prompt does not abort the
    whole benchmark.
    """
    try:
        inference_time = api.generate_image(prompt, image_path)
    except Exception as e:
        print(f"\nError generating image for prompt: {prompt}")
        print(f"Error: {str(e)}")
        return
    metadata[record_key] = {
        "filepath": record_key,
        "prompt": prompt,
        "inference_time": inference_time,
    }


def generate_images(api_type: str, benchmarks: List[str]):
    """Generate images for each benchmark with the given API, skipping images already on disk.

    Images go under ``images/<api_type>/<benchmark_type>/``. Each benchmark
    directory holds a ``metadata.jsonl`` with one record per generated image
    (``filepath``, ``prompt``, ``inference_time``). Records from earlier runs
    are loaded first and kept.

    Args:
        api_type: Name passed to ``create_api`` to select the generation backend;
            also used as the output subdirectory name.
        benchmarks: Benchmark names passed to ``create_benchmark``, processed in order.
    """
    images_dir = Path("images")
    api = create_api(api_type)

    api_dir = images_dir / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")
        benchmark = create_benchmark(benchmark_type)

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)
        metadata_file = benchmark_dir / "metadata.jsonl"
        existing_metadata = _load_metadata(metadata_file)

        # Always persist metadata, even if this loop is interrupted (e.g. Ctrl-C).
        # Images that are already on disk are skipped on the next run, so their
        # records would otherwise never be written.
        try:
            if benchmark_type == "geneval":
                # geneval layout: <benchmark_dir>/<folder_name>/samples/0000.png
                for metadata, folder_name in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
                    samples_path = benchmark_dir / folder_name / "samples"
                    samples_path.mkdir(parents=True, exist_ok=True)
                    image_path = samples_path / "0000.png"
                    if image_path.exists():
                        continue
                    # NOTE(review): geneval records store the full path (including
                    # images/<api_type>/geneval/), while other benchmarks store the path
                    # relative to the benchmark dir. This is kept for compatibility with
                    # existing metadata files; confirm which form downstream readers expect.
                    _generate_one(api, metadata["prompt"], image_path, str(image_path), existing_metadata)
            else:
                # Other benchmarks yield (prompt, path relative to benchmark_dir).
                for prompt, image_path in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
                    full_image_path = benchmark_dir / image_path
                    if full_image_path.exists():
                        continue
                    _generate_one(api, prompt, full_image_path, str(image_path), existing_metadata)
        finally:
            _write_metadata(metadata_file, existing_metadata)
def main():
    """Command-line entry point: read the API type and benchmark names, then generate images.

    Usage: ``<script> API_TYPE BENCHMARK [BENCHMARK ...]``
    """
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    cli.add_argument(
        "api_type",
        help="Type of API to use for image generation",
    )
    cli.add_argument(
        "benchmarks",
        nargs="+",
        help="List of benchmark types to run",
    )
    parsed = cli.parse_args()
    generate_images(parsed.api_type, parsed.benchmarks)


if __name__ == "__main__":
    main()