InferBench / sample.py
nifleisch
feat: add core logic for project
2c50826
raw
history blame
4.25 kB
import argparse
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
from api import create_api
from benchmark import create_benchmark
def _load_metadata(metadata_file: Path) -> dict:
    """Load previously written metadata records, keyed by their ``filepath`` field.

    Returns an empty dict when the file does not exist. Blank lines are
    skipped so that a stray trailing newline does not abort the whole run.
    """
    entries = {}
    if not metadata_file.exists():
        return entries
    # json.dumps() emits ASCII-only text by default, so files written by earlier
    # runs decode identically as UTF-8.
    with open(metadata_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            entries[entry["filepath"]] = entry
    return entries


def _write_metadata(metadata_file: Path, entries: dict) -> None:
    """Rewrite the metadata file as JSON Lines, one record per line."""
    with open(metadata_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in entries.values())


def _generate_one(api, prompt, output_path: Path, record_key: str, entries: dict) -> None:
    """Generate a single image unless it already exists, and record its metadata.

    Args:
        api: Object exposing ``generate_image(prompt, path)``; its return value
            is stored as ``inference_time``.
        prompt: Text prompt passed to the API.
        output_path: Where the image is expected on disk.
        record_key: Value stored as the record's ``filepath`` and used as the
            key in ``entries``.
        entries: Metadata dict updated in place on success.
    """
    # Make sure the target directory exists before calling the API.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Existing images are never regenerated, which makes runs resumable.
    if output_path.exists():
        return

    try:
        inference_time = api.generate_image(prompt, output_path)
    except Exception as e:
        # Best effort: report the failure and move on to the next prompt.
        print(f"\nError generating image for prompt: {prompt}")
        print(f"Error: {str(e)}")
        return

    entries[record_key] = {
        "filepath": record_key,
        "prompt": prompt,
        "inference_time": inference_time,
    }


def generate_images(api_type: str, benchmarks: List[str]):
    """Generate images for every prompt of the given benchmarks with one API.

    Images are written under ``images/<api_type>/<benchmark_type>/`` and a
    ``metadata.jsonl`` file in the same directory records, for each newly
    generated image, its ``filepath``, ``prompt`` and ``inference_time``.
    Images that already exist on disk are skipped.

    Args:
        api_type: Identifier passed to ``create_api`` and used as the output
            sub-directory name.
        benchmarks: Benchmark identifiers passed to ``create_benchmark``.

    Notes:
        The ``"geneval"`` benchmark is iterated as ``(metadata, folder_name)``
        pairs and stores each image at ``<folder_name>/samples/0000.png``; its
        records use the full image path as ``filepath``. Every other benchmark
        is iterated as ``(prompt, image_path)`` pairs and records the
        benchmark-relative ``image_path``. This difference is preserved as-is.
        NOTE(review): confirm downstream readers rely on both formats.
    """
    images_dir = Path("images")
    api = create_api(api_type)

    api_dir = images_dir / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")

        benchmark = create_benchmark(benchmark_type)

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)

        metadata_file = benchmark_dir / "metadata.jsonl"
        existing_metadata = _load_metadata(metadata_file)

        try:
            items = tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False)
            if benchmark_type == "geneval":
                for metadata, folder_name in items:
                    image_path = benchmark_dir / folder_name / "samples" / "0000.png"
                    _generate_one(
                        api, metadata["prompt"], image_path, str(image_path), existing_metadata
                    )
            else:
                for prompt, image_path in items:
                    _generate_one(
                        api, prompt, benchmark_dir / image_path, str(image_path), existing_metadata
                    )
        finally:
            # Persist even when the run is interrupted: finished images are
            # skipped on the next run, so their records could not be rebuilt.
            _write_metadata(metadata_file, existing_metadata)
def main():
    """Command-line entry point: read the API type and benchmark names, then generate images."""
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    # One required API identifier followed by one or more benchmark identifiers.
    cli.add_argument("api_type", help="Type of API to use for image generation")
    cli.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")

    parsed = cli.parse_args()
    generate_images(api_type=parsed.api_type, benchmarks=parsed.benchmarks)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()