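"""Split a folder of images into stratified train/val/test sets and write the
resized (and optionally augmented) images out to a target folder.

Expects the input folder to contain one subfolder per class label.
"""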
import tensorflow as tf
import os
import argparse
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm # For progress display
import sys
import uuid # Import uuid for unique filename generation
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def parse_arguments():
    parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
    parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
    parser.add_argument('--dim', type=int, default=224, help='Required image dimension')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of parallel calls for data loading')
    parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
    parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
    return parser.parse_args()
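
# Example invocation (illustrative; the script name is whatever this file is
# saved as):
#
#   python prepare_dataset.py --path ./raw_images --dim 224 --batch_size 32 \
#       --target_folder ./dataset_splits --augment_data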

def create_datagens():
    # ImageDataGenerator objects, one per augmentation. Note that
    # random_transform() does not apply `rescale`, and images are already
    # normalized to [0, 1] in process_image, so a rescaling-only generator
    # would just save an unmodified duplicate and is omitted here.
    return [
        ImageDataGenerator(rotation_range=20),
        ImageDataGenerator(width_shift_range=0.2),
        ImageDataGenerator(height_shift_range=0.2),
        ImageDataGenerator(horizontal_flip=True),
    ]
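
# Each generator's random_transform() applies one random instance of its
# configured transform to a single HxWxC array, e.g.:
#
#   augmented = ImageDataGenerator(rotation_range=20).random_transform(image_array)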

def process_image(file_path, image_size):
    # Read, decode, resize, and normalize an image (runs eagerly via tf.py_function)
    file_path = file_path.numpy().decode('utf-8')
    image = tf.io.read_file(file_path)
    # expand_animations=False guarantees a 3-D tensor even for GIF inputs
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32, expand_animations=False)
    image = tf.image.resize(image, [image_size, image_size])
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image

def save_image(image, file_path):
    # Convert the [0, 1] float image to uint8, encode as JPEG, and write to disk
    image = tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
    image = tf.io.encode_jpeg(image)
    tf.io.write_file(file_path, image)

def load_data(path, image_size, batch_size, num_parallel_calls=tf.data.AUTOTUNE):
    all_images = []
    labels = []
    # Collect image paths and labels; each class is expected to live in its
    # own subfolder of `path`
    for subdir, _, files in os.walk(path):
        label = os.path.basename(subdir)
        for fname in files:
            if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                all_images.append(os.path.join(subdir, fname))
                labels.append(label)

    unique_labels = set(labels)
    print(f"Found {len(all_images)} images in {path}\n")
    print(f"Labels found ({len(unique_labels)}): {unique_labels}\n")

    # Raise an error if no images are found
    if len(all_images) == 0:
        raise ValueError(f"No images found in the specified path: {path}")

    # Stratified split: 80% train, then the remaining 20% split evenly into
    # validation and test
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_indices, test_indices = next(sss.split(all_images, labels))
    train_files = [all_images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    test_files = [all_images[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]

    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(test_files, test_labels))
    val_files = [test_files[i] for i in val_indices]
    val_labels = [test_labels[i] for i in val_indices]
    test_files = [test_files[i] for i in test_indices]
    test_labels = [test_labels[i] for i in test_indices]

    print(f"Data split into {len(train_files)} train, {len(val_files)} validation, and {len(test_files)} test images.\n")

    # Wrap the eager image loader so it can run inside a tf.data pipeline
    def tf_load_image(file_path, label):
        image = tf.py_function(func=lambda x: process_image(x, image_size), inp=[file_path], Tout=tf.float32)
        image.set_shape([image_size, image_size, 3])
        return image, label

    # Create datasets from the file lists and labels
    train_dataset = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    val_dataset = tf.data.Dataset.from_tensor_slices((val_files, val_labels))
    test_dataset = tf.data.Dataset.from_tensor_slices((test_files, test_labels))

    train_dataset = train_dataset.map(tf_load_image, num_parallel_calls=num_parallel_calls)
    val_dataset = val_dataset.map(tf_load_image, num_parallel_calls=num_parallel_calls)
    test_dataset = test_dataset.map(tf_load_image, num_parallel_calls=num_parallel_calls)

    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_dataset, val_dataset, test_dataset
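
# Illustrative sketch only, never called by this script: the datasets yield
# (image, string_label) pairs, so labels must be mapped to integer ids before
# they can feed a classifier. The tiny model below is an assumption made for
# demonstration, not part of this pipeline.
def example_consume_datasets(train_dataset, val_dataset, image_size, class_names):
    lookup = tf.keras.layers.StringLookup(vocabulary=sorted(class_names), num_oov_indices=0)

    def encode(image, label):
        # Map each string label to its integer id
        return image, lookup(label)

    train = train_dataset.map(encode)
    val = val_dataset.map(encode)
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(image_size, image_size, 3)),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(class_names), activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(train, validation_data=val, epochs=1)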

def save_datasets_to_folders(dataset, folder_path, datagens=None):
    # Save the dataset to class subfolders, with optional augmented copies
    os.makedirs(folder_path, exist_ok=True)
    count = 0
    for batch_images, batch_labels in tqdm(dataset, desc=f"Saving to {folder_path}"):
        for i in range(batch_images.shape[0]):
            image = batch_images[i].numpy()
            label = batch_labels[i].numpy().decode('utf-8')
            label_folder = os.path.join(folder_path, label)
            os.makedirs(label_folder, exist_ok=True)
            # Save the original image under a collision-free name
            file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
            save_image(image, file_path)
            count += 1
            # Save one augmented copy per generator, if any were provided
            if datagens:
                for datagen in datagens:
                    aug_image = datagen.random_transform(image)
                    file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
                    save_image(aug_image, file_path)
                    count += 1
    print(f"Saved {count} images to {folder_path}\n")
    return count
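
# The saved layout (class subfolders of JPEGs) is exactly what
# tf.keras.utils.image_dataset_from_directory expects, so a split can be
# reloaded later like this (paths are illustrative):
#
#   train_ds = tf.keras.utils.image_dataset_from_directory(
#       'dataset_splits/train', image_size=(224, 224), batch_size=32)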

def main():
    # Parse arguments, load and split the data, and write out the splits
    args = parse_arguments()
    os.makedirs(args.target_folder, exist_ok=True)

    train_folder = os.path.join(args.target_folder, 'train')
    val_folder = os.path.join(args.target_folder, 'val')
    test_folder = os.path.join(args.target_folder, 'test')

    datagens = create_datagens() if args.augment_data else None
    train_dataset, val_dataset, test_dataset = load_data(
        args.path,
        args.dim,
        args.batch_size,
        num_parallel_calls=args.num_workers
    )

    # Save datasets to their folders; only the training split is augmented
    train_count = save_datasets_to_folders(train_dataset, train_folder, datagens)
    val_count = save_datasets_to_folders(val_dataset, val_folder)
    test_count = save_datasets_to_folders(test_dataset, test_folder)

    print(f"Train dataset saved to: {train_folder}\n")
    print(f"Validation dataset saved to: {val_folder}\n")
    print(f"Test dataset saved to: {test_folder}\n")
    print('-' * 20)
    print(f"Number of images in training set: {train_count}\n")
    print(f"Number of images in validation set: {val_count}\n")
    print(f"Number of images in test set: {test_count}\n")

if __name__ == "__main__":
    # Force UTF-8 output to avoid encoding issues on some consoles
    sys.stdout.reconfigure(encoding='utf-8', line_buffering=True)
    sys.stderr.reconfigure(encoding='utf-8', line_buffering=True)
    main()