File size: 4,408 Bytes
2c50826
 
 
 
 
 
 
 
 
 
 
 
 
 
199a7d9
2c50826
 
199a7d9
2c50826
 
199a7d9
2c50826
199a7d9
2c50826
 
 
199a7d9
2c50826
 
 
 
 
 
 
199a7d9
9fa4df6
ac90c02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c50826
 
 
199a7d9
2c50826
 
 
 
 
 
 
199a7d9
9fa4df6
ac90c02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c50826
 
 
199a7d9
 
 
2c50826
 
199a7d9
2c50826
199a7d9
2c50826
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import json
from pathlib import Path
from typing import List

from tqdm import tqdm

from api import create_api
from benchmark import create_benchmark


def _load_existing_metadata(metadata_file: Path) -> dict:
    """Return the records of ``metadata_file`` (JSON Lines) keyed by ``"filepath"``.

    Returns an empty dict when the file does not exist yet. Blank lines are
    ignored, and a line that is not valid JSON (e.g. a record truncated by an
    interrupted run) is reported and skipped instead of aborting the run.
    """
    existing = {}
    if not metadata_file.exists():
        return existing
    with open(metadata_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                print(f"\nSkipping malformed line {line_no} in {metadata_file}")
                continue
            existing[entry["filepath"]] = entry
    return existing


def _append_metadata(f, filepath: str, prompt: str, inference_time) -> None:
    """Write one ``{"filepath", "prompt", "inference_time"}`` JSON line to ``f``.

    The file is flushed after every record: on a re-run, images that already
    exist on disk are skipped, so a record still sitting in the write buffer
    when the process dies would otherwise never be written.
    """
    metadata_entry = {
        "filepath": filepath,
        "prompt": prompt,
        "inference_time": inference_time,
    }
    f.write(json.dumps(metadata_entry) + "\n")
    f.flush()


def _report_failure(prompt: str, error: Exception) -> None:
    """Print a generation failure for ``prompt`` without stopping the run."""
    print(f"\nError generating image for prompt: {prompt}")
    print(f"Error: {str(error)}")


def _generate_geneval(api, benchmark, benchmark_dir: Path, desc: str) -> None:
    """Generate GenEval images as ``<benchmark_dir>/<folder_name>/samples/0000.png``.

    ``benchmark`` is iterated as ``(metadata, folder_name)`` pairs, with the
    prompt read from ``metadata["prompt"]``. Prompts whose image already exists
    are skipped. Records store the full output path (``str(image_path)``).
    """
    metadata_file = benchmark_dir / "metadata.jsonl"
    with open(metadata_file, "a", encoding="utf-8") as f:
        for metadata, folder_name in tqdm(benchmark, desc=desc, leave=False):
            samples_path = benchmark_dir / folder_name / "samples"
            samples_path.mkdir(parents=True, exist_ok=True)
            image_path = samples_path / "0000.png"
            if image_path.exists():
                continue

            prompt = metadata["prompt"]
            try:
                inference_time = api.generate_image(prompt, image_path)
            except Exception as e:  # best-effort: report and move on
                _report_failure(prompt, e)
            else:
                _append_metadata(f, str(image_path), prompt, inference_time)


def _generate_default(api, benchmark, benchmark_dir: Path, desc: str) -> None:
    """Generate an image at ``benchmark_dir / image_path`` for each ``(prompt, image_path)``.

    A prompt is skipped when ``image_path`` is already recorded in
    ``metadata.jsonl`` or the image file already exists. Records store
    ``str(image_path)`` as yielded by the benchmark (not joined onto
    ``benchmark_dir``).
    """
    metadata_file = benchmark_dir / "metadata.jsonl"
    existing_metadata = _load_existing_metadata(metadata_file)
    with open(metadata_file, "a", encoding="utf-8") as f:
        for prompt, image_path in tqdm(benchmark, desc=desc, leave=False):
            # Records are keyed by str(image_path); normalise the lookup key so a
            # Path-valued image_path matches too (Path("x") != "x").
            if str(image_path) in existing_metadata:
                continue

            full_image_path = benchmark_dir / image_path
            if full_image_path.exists():
                continue
            # image_path may contain sub-directories; make sure they exist.
            full_image_path.parent.mkdir(parents=True, exist_ok=True)

            try:
                inference_time = api.generate_image(prompt, full_image_path)
            except Exception as e:  # best-effort: report and move on
                _report_failure(prompt, e)
            else:
                _append_metadata(f, str(image_path), prompt, inference_time)


def generate_images(api_type: str, benchmarks: List[str]) -> None:
    """Generate images for each benchmark with the given API, resuming earlier runs.

    Images go under ``images/<api_type>/<benchmark_type>/``. For every newly
    generated image, one JSON Lines record (filepath, prompt, inference_time)
    is appended to ``metadata.jsonl`` in that directory. Already-generated
    images are skipped, so the function can be re-run to fill in gaps. A
    failure for a single prompt is printed and skipped; it does not stop the run.

    Args:
        api_type: Name passed to ``create_api``; also the output sub-directory.
        benchmarks: Names passed to ``create_benchmark``. ``"geneval"`` uses its
            own folder layout (``<folder_name>/samples/0000.png``); every other
            benchmark yields ``(prompt, image_path)`` pairs.
    """
    api = create_api(api_type)

    api_dir = Path("images") / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")

        benchmark = create_benchmark(benchmark_type)

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)
        desc = f"Generating images for {benchmark_type}"

        if benchmark_type == "geneval":
            _generate_geneval(api, benchmark, benchmark_dir, desc)
        else:
            _generate_default(api, benchmark, benchmark_dir, desc)


def main():
    """Command-line entry point: ``<script> API_TYPE BENCHMARK [BENCHMARK ...]``.

    Parses the API type and one or more benchmark names, then hands them to
    ``generate_images``.
    """
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )

    # Positional arguments, registered in order (order defines the CLI layout).
    positional_specs = (
        ("api_type", {"help": "Type of API to use for image generation"}),
        ("benchmarks", {"nargs": "+", "help": "List of benchmark types to run"}),
    )
    for dest, spec in positional_specs:
        cli.add_argument(dest, **spec)

    ns = cli.parse_args()
    generate_images(ns.api_type, ns.benchmarks)


if __name__ == "__main__":
    main()