Spaces:
Running
Running
File size: 7,339 Bytes
c9dac35 dc6215b c9dac35 dc6215b ca96bd8 c9dac35 ca96bd8 c9dac35 dc6215b c9dac35 ca96bd8 c9dac35 ca96bd8 c9dac35 dc6215b ca96bd8 dc6215b ca96bd8 dc6215b c9dac35 ca96bd8 c9dac35 ca96bd8 c9dac35 ca96bd8 c9dac35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import os
import argparse
import shutil
import sys
from pathlib import Path
from PIL import Image
from caption import caption_images
def is_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    The match is case-insensitive. Supported extensions are
    .png, .jpg, .jpeg and .webp.
    """
    supported_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
    lowered_name = filename.lower()
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return lowered_name.endswith(supported_suffixes)
def is_unsupported_image(filename):
    """Return True when *filename* looks like an image of a rejected format.

    The match is case-insensitive. Rejected extensions are
    .bmp, .gif, .tiff, .tif, .ico and .svg.
    """
    rejected_suffixes = ('.bmp', '.gif', '.tiff', '.tif', '.ico', '.svg')
    lowered_name = filename.lower()
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return lowered_name.endswith(rejected_suffixes)
def is_text_file(filename):
    """Return True when *filename* has a .txt extension (case-insensitive)."""
    lowered_name = filename.lower()
    return lowered_name.endswith('.txt')
def validate_input_directory(input_dir):
    """Check the files directly inside *input_dir* before processing.

    If any file has an unsupported image extension, the offending names are
    printed and the process exits with status 1. If the directory contains
    .txt files, a warning listing them is printed, because caption files
    written later may overwrite them. Subdirectories are not inspected.
    """
    # Scan the directory once; only regular files are relevant.
    file_names = [entry.name for entry in Path(input_dir).iterdir() if entry.is_file()]

    rejected_names = [name for name in file_names if is_unsupported_image(name)]
    if rejected_names:
        print("Error: Unsupported image formats detected.")
        print("Only .png, .jpg, .jpeg, and .webp files are allowed.")
        print("The following files are not supported:")
        for name in rejected_names:
            print(f" - {name}")
        sys.exit(1)

    # Reaching this point means no file matched an unsupported extension,
    # so every .txt file is reported.
    existing_text_names = [name for name in file_names if is_text_file(name)]
    if existing_text_names:
        print("Warning: Text files detected in the input directory.")
        print("The following text files will be overwritten:")
        for name in existing_text_names:
            print(f" - {name}")
def collect_images_by_category(input_path):
    """Load every supported image in *input_path* and group it by category.

    The category of a file is its stem with the text after the last
    underscore removed (``"hero_01.png"`` -> ``"hero"``); a stem without an
    underscore is its own category.

    Args:
        input_path: ``pathlib.Path`` of the directory to scan (not recursive).

    Returns:
        A ``(images_by_category, image_paths_by_category)`` tuple of dicts
        keyed by category. The first maps to lists of RGB ``PIL.Image``
        objects, the second to the matching lists of file paths, in the
        same order.

    Files that fail to load are reported on stdout and skipped.
    """
    images_by_category = {}
    image_paths_by_category = {}
    for file_path in input_path.iterdir():
        if not (file_path.is_file() and is_image_file(file_path.name)):
            continue
        try:
            # Open inside a context manager so the underlying file handle is
            # closed right away; convert() returns an independent in-memory
            # image that stays valid after the source is closed.
            with Image.open(file_path) as source_image:
                image = source_image.convert("RGB")
            # Determine the category from the filename
            category = file_path.stem.rsplit('_', 1)[0]
            images_by_category.setdefault(category, []).append(image)
            image_paths_by_category.setdefault(category, []).append(file_path)
        except Exception as e:
            print(f"Error loading {file_path.name}: {e}")
    return images_by_category, image_paths_by_category
def process_by_category(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption the images one category at a time.

    For each category, ``caption_images`` is called with ``batch_mode=True``
    and the category name, and the resulting captions are passed to
    ``write_captions``. A failure in one category is printed and does not
    stop the remaining categories.

    Returns:
        The number of images in the categories that completed without an
        exception reaching this function.
    """
    captioned_total = 0
    for category in images_by_category:
        category_images = images_by_category[category]
        category_paths = image_paths_by_category[category]
        try:
            # One request per category, using batch mode.
            category_captions = caption_images(
                category_images, category=category, batch_mode=True
            )
            write_captions(category_paths, category_captions, input_path, output_path)
            captioned_total += len(category_images)
        except Exception as e:
            print(f"Error generating captions for category '{category}': {e}")
    return captioned_total
def process_all_at_once(images_by_category, image_paths_by_category, input_path, output_path):
    """Caption every image in a single ``caption_images`` call.

    The per-category lists are flattened (in dict order) into one list of
    images and one parallel list of paths, ``caption_images`` is called with
    ``batch_mode=False``, and the captions are passed to ``write_captions``.

    Returns:
        The total number of images if no exception reached this function,
        otherwise 0 (the error is printed).
    """
    # Flatten both mappings in the same order so images and paths stay aligned.
    all_images = []
    for group in images_by_category.values():
        all_images.extend(group)
    all_image_paths = []
    for group_paths in image_paths_by_category.values():
        all_image_paths.extend(group_paths)

    captioned_total = 0
    try:
        all_captions = caption_images(all_images, batch_mode=False)
        write_captions(all_image_paths, all_captions, input_path, output_path)
        captioned_total += len(all_images)
    except Exception as e:
        print(f"Error generating captions: {e}")
    return captioned_total
def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
    """Validate *input_dir*, caption its images and report the result.

    Args:
        input_dir: Directory containing the images.
        output_dir: Directory for the results; a falsy value means the
            input directory is used.
        fix_outfit: Accepted for the CLI flag; not used by this function.
        batch_images: When true, images are captioned per category;
            otherwise all images are captioned in one call.

    The function returns None. Validation may terminate the process via
    ``sys.exit`` (see ``validate_input_directory``).
    """
    input_path = Path(input_dir)
    if output_dir:
        output_path = Path(output_dir)
    else:
        output_path = input_path

    # Validation runs before any directory is created or image is loaded.
    validate_input_directory(input_dir)
    os.makedirs(output_path, exist_ok=True)

    images_by_category, image_paths_by_category = collect_images_by_category(input_path)

    total_images = 0
    for category_images in images_by_category.values():
        total_images += len(category_images)
    print(f"Found {total_images} images to process.")
    if total_images == 0:
        print("No valid images found to process.")
        return

    # Pick the captioning strategy requested by the caller.
    strategy = process_by_category if batch_images else process_all_at_once
    processed_count = strategy(images_by_category, image_paths_by_category, input_path, output_path)
    print(f"\nProcessing complete. {processed_count} images were captioned.")
def write_captions(image_paths, captions, input_path, output_path):
    """Write one caption file per image and optionally copy results.

    For every ``(image path, caption)`` pair, a ``<stem>.txt`` file is
    written into *input_path*. When *output_path* refers to a different
    directory, both the image and its caption file are copied there.

    Args:
        image_paths: Paths of the captioned images.
        captions: Caption strings, parallel to *image_paths*; pairing stops
            at the shorter of the two sequences.
        input_path: ``pathlib.Path`` of the directory receiving caption files.
        output_path: ``pathlib.Path`` of the destination directory.

    Errors for an individual file are printed and do not stop the loop.
    """
    # Compare resolved paths so that two spellings of the same directory
    # (e.g. "dir" and "dir/sub/..") are not treated as distinct; copying a
    # file onto itself would raise shutil.SameFileError and be misreported
    # as a processing failure. Computed once because it is loop-invariant.
    copy_to_output = Path(output_path).resolve() != Path(input_path).resolve()

    for file_path, caption in zip(image_paths, captions):
        try:
            # Create caption file path (same name but with .txt extension)
            caption_filename = file_path.stem + ".txt"
            caption_path = input_path / caption_filename
            with open(caption_path, 'w', encoding='utf-8') as f:
                f.write(caption)
            if copy_to_output:
                shutil.copy2(file_path, output_path / file_path.name)
                shutil.copy2(caption_path, output_path / caption_filename)
            print(f"Processed {file_path.name} → {caption_filename}")
        except Exception as e:
            print(f"Error processing {file_path.name}: {e}")
def main():
    """Parse command-line arguments and run the captioning pipeline.

    Options: ``--input`` (required directory of images), ``--output``
    (optional destination directory), ``--fix_outfit`` and
    ``--batch_images`` (boolean flags forwarded to ``process_images``).

    Exits with status 1 when the input directory does not exist, so that
    shells and scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
    parser.add_argument('--input', type=str, required=True, help='Directory containing images')
    parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
    parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
    parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
    args = parser.parse_args()

    # A missing input directory is a usage error: report it and exit
    # non-zero, matching how unsupported files are handled during validation.
    if not os.path.isdir(args.input):
        print(f"Error: Input directory '{args.input}' does not exist.")
        sys.exit(1)

    process_images(args.input, args.output, args.fix_outfit, args.batch_images)
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()