# File size: 4,254 Bytes
# 2c50826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import json
from pathlib import Path
from typing import List

from tqdm import tqdm

from api import create_api
from benchmark import create_benchmark


def _load_metadata(metadata_file: Path) -> dict:
    """Load an existing ``metadata.jsonl`` file, if there is one.

    Args:
        metadata_file: Path to a JSON-Lines file in which every non-blank line is
            a JSON object carrying at least a ``"filepath"`` key.

    Returns:
        A dict mapping each entry's ``"filepath"`` to the entry itself. The dict
        is empty when the file does not exist yet.
    """
    entries = {}
    if not metadata_file.exists():
        return entries
    with open(metadata_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. an extra trailing newline) instead of
            # letting json.loads fail on them.
            if not line.strip():
                continue
            entry = json.loads(line)
            entries[entry["filepath"]] = entry
    return entries


def _save_metadata(metadata_file: Path, entries: dict) -> None:
    """Overwrite ``metadata_file`` with ``entries`` as JSON Lines, one per line."""
    with open(metadata_file, "w", encoding="utf-8") as f:
        for entry in entries.values():
            f.write(json.dumps(entry) + "\n")


def generate_images(api_type: str, benchmarks: List[str]):
    """Generate images for each benchmark with the given API and record metadata.

    Images are written under ``images/<api_type>/<benchmark_type>/``. An image
    that already exists on disk is skipped, so a previous run can be resumed.
    For each newly generated image, the prompt and the value returned by
    ``api.generate_image`` (stored as ``inference_time``) are merged into
    ``metadata.jsonl`` in the benchmark directory.

    For the ``"geneval"`` benchmark, iterating the benchmark yields
    ``(metadata, folder_name)`` pairs whose ``metadata`` has a ``"prompt"`` key.
    Each image goes to ``<folder_name>/samples/0000.png``. Every other benchmark
    yields ``(prompt, image_path)`` pairs, with ``image_path`` relative to the
    benchmark directory.

    Args:
        api_type: Name passed to ``create_api``; also used as the output
            sub-directory.
        benchmarks: Benchmark names passed to ``create_benchmark``.
    """
    images_dir = Path("images")
    api = create_api(api_type)

    api_dir = images_dir / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")

        benchmark = create_benchmark(benchmark_type)
        is_geneval = benchmark_type == "geneval"

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)

        metadata_file = benchmark_dir / "metadata.jsonl"
        existing_metadata = _load_metadata(metadata_file)

        try:
            for item in tqdm(
                benchmark,
                desc=f"Generating images for {benchmark_type}",
                leave=False,
            ):
                if is_geneval:
                    metadata, folder_name = item
                    prompt = metadata["prompt"]
                    samples_path = benchmark_dir / folder_name / "samples"
                    samples_path.mkdir(parents=True, exist_ok=True)
                    output_path = samples_path / "0000.png"
                    # geneval entries are keyed by the full output path
                    # (images/<api_type>/geneval/...).
                    metadata_key = str(output_path)
                else:
                    prompt, image_path = item
                    output_path = benchmark_dir / image_path
                    # Other benchmarks key entries by the benchmark-relative path.
                    metadata_key = str(image_path)

                if output_path.exists():
                    continue

                # image_path may contain sub-directories that do not exist yet.
                output_path.parent.mkdir(parents=True, exist_ok=True)

                try:
                    inference_time = api.generate_image(prompt, output_path)
                except Exception as e:
                    # Best-effort: report the failure and go on to the next prompt.
                    print(f"\nError generating image for prompt: {prompt}")
                    print(f"Error: {str(e)}")
                    continue

                existing_metadata[metadata_key] = {
                    "filepath": metadata_key,
                    "prompt": prompt,
                    "inference_time": inference_time,
                }
        finally:
            # Persist for every benchmark (geneval included), and also when the
            # loop is interrupted. Otherwise entries for images already written
            # to disk would be lost, and those images are skipped on the next run.
            _save_metadata(metadata_file, existing_metadata)


def main():
    """Command-line entry point.

    Reads the API type followed by one or more benchmark names from the
    command line, then hands them to ``generate_images``.
    """
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    cli.add_argument(
        "api_type",
        help="Type of API to use for image generation",
    )
    cli.add_argument(
        "benchmarks",
        nargs="+",
        help="List of benchmark types to run",
    )

    parsed = cli.parse_args()
    generate_images(parsed.api_type, parsed.benchmarks)


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()