import argparse
import os
import sys
import uuid  # For unique output filenames

import tensorflow as tf
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm  # For progress display
# Note: ImageDataGenerator is deprecated in newer TensorFlow releases but still available
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def parse_arguments():
    parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
    parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
    parser.add_argument('--dim', type=int, default=224, help='Required image dimension')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of parallel calls used when mapping the datasets')
    parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
    parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
    return parser.parse_args()
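
# Example invocation (hypothetical script name and paths; adjust to your setup):
#   python prepare_splits.py --path ./raw_images --dim 224 --batch_size 32 \
#       --num_workers 4 --target_folder ./splits --augment_data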

def create_datagens():
    # One ImageDataGenerator per augmentation. A rescaling generator is
    # deliberately omitted: random_transform() ignores the rescale argument,
    # and images are already normalized to [0, 1] in process_image, so it
    # would only produce unmodified duplicates.
    return [
        ImageDataGenerator(rotation_range=20),
        ImageDataGenerator(width_shift_range=0.2),
        ImageDataGenerator(height_shift_range=0.2),
        ImageDataGenerator(horizontal_flip=True)
    ]

def process_image(file_path, image_size):
    # Runs eagerly inside tf.py_function, so the tensor argument can be
    # converted to a Python string here.
    file_path = file_path.numpy().decode('utf-8')
    image = tf.io.read_file(file_path)
    # decode_image with dtype=tf.float32 also scales pixel values to [0, 1]
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32)
    image = tf.image.resize(image, [image_size, image_size])
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image

def save_image(image, file_path):
    # Expects a float image in [0, 1]; convert to uint8, encode as JPEG, and write.
    # saturate=True guards against slight out-of-range values from augmentation.
    image = tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
    image = tf.io.encode_jpeg(image)
    tf.io.write_file(file_path, image)

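# Expected input layout (illustrative; class folder names are hypothetical):
#   <path>/
#       cats/  img001.jpg  img002.png ...
#       dogs/  img101.jpg ...
# Each immediate subdirectory name becomes the class label.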
def load_data(path, image_size, batch_size, num_workers):
    # Collect image file paths and their labels (the name of the
    # containing directory) from the source folder.
    all_images = []
    labels = []

    for subdir, _, files in os.walk(path):
        label = os.path.basename(subdir)
        for fname in files:
            if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                all_images.append(os.path.join(subdir, fname))
                labels.append(label)
    
    unique_labels = set(labels)
    print(f"Found {len(all_images)} images in {path}\n")
    print(f"Labels found ({len(unique_labels)}): {unique_labels}\n")
    
    # Raise an error if no images are found
    if len(all_images) == 0:
        raise ValueError(f"No images found in the specified path: {path}")

    # Stratified 80/20 split into train and a 20% holdout set
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_indices, test_indices = next(sss.split(all_images, labels))
    
    train_files = [all_images[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    test_files = [all_images[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]
    
    # Split the holdout 50/50 into validation and test sets (10%/10% overall)
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(test_files, test_labels))
    
    val_files = [test_files[i] for i in val_indices]
    val_labels = [test_labels[i] for i in val_indices]
    test_files = [test_files[i] for i in test_indices]
    test_labels = [test_labels[i] for i in test_indices]

    print(f"Data split into {len(train_files)} train, {len(val_files)} validation, and {len(test_files)} test images.\n")

    # Wrap the eager loader for use in a tf.data pipeline. Augmentation
    # happens later, at save time, so this only loads and resizes.
    def tf_load_image(file_path, label):
        image = tf.py_function(func=lambda x: process_image(x, image_size), inp=[file_path], Tout=tf.float32)
        image.set_shape([image_size, image_size, 3])
        return image, label

    # Create datasets from the collected files and labels
    train_dataset = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    val_dataset = tf.data.Dataset.from_tensor_slices((val_files, val_labels))
    test_dataset = tf.data.Dataset.from_tensor_slices((test_files, test_labels))

    train_dataset = train_dataset.map(tf_load_image, num_parallel_calls=num_workers)
    val_dataset = val_dataset.map(tf_load_image, num_parallel_calls=num_workers)
    test_dataset = test_dataset.map(tf_load_image, num_parallel_calls=num_workers)

    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_dataset = test_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_dataset, val_dataset, test_dataset
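
# Usage sketch for load_data (illustrative values, not executed):
#   train_ds, val_ds, test_ds = load_data('./raw_images', 224, 32, 4)
#   images, labels = next(iter(train_ds))
#   images.shape -> (batch, 224, 224, 3); labels are byte-string class names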

def save_datasets_to_folders(dataset, folder_path, datagens=None):
    # Write each image to <folder_path>/<label>/<uuid>.jpg, plus one extra
    # augmented copy per generator when datagens are provided.
    os.makedirs(folder_path, exist_ok=True)
    
    count = 0
    for batch_images, batch_labels in tqdm(dataset, desc=f"Saving to {folder_path}"):
        for i in range(batch_images.shape[0]):
            image = batch_images[i].numpy()
            label = batch_labels[i].numpy().decode('utf-8')
            label_folder = os.path.join(folder_path, label)
            os.makedirs(label_folder, exist_ok=True)

            # Save the original image
            file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
            save_image(image, file_path)
            count += 1

            # Apply augmentations if datagens are provided
            if datagens:
                for datagen in datagens:
                    aug_image = datagen.random_transform(image)
                    file_path = os.path.join(label_folder, f"{uuid.uuid4().hex}.jpg")
                    save_image(aug_image, file_path)
                    count += 1
    
    print(f"Saved {count} images to {folder_path}\n")
    return count

def main():
    # Main function to parse arguments, load data, and save datasets
    args = parse_arguments()
    
    os.makedirs(args.target_folder, exist_ok=True)
        
    train_folder = os.path.join(args.target_folder, 'train')
    val_folder = os.path.join(args.target_folder, 'val')
    test_folder = os.path.join(args.target_folder, 'test')
    
    datagens = create_datagens() if args.augment_data else None

    train_dataset, val_dataset, test_dataset = load_data(
        args.path,
        args.dim,
        args.batch_size,
        args.num_workers
    )
    
    # Save each split; augmentation (if enabled) is applied to the training split only
    train_count = save_datasets_to_folders(train_dataset, train_folder, datagens)
    val_count = save_datasets_to_folders(val_dataset, val_folder)
    test_count = save_datasets_to_folders(test_dataset, test_folder)
    
    print(f"Train dataset saved to: {train_folder}\n")
    print(f"Validation dataset saved to: {val_folder}\n")
    print(f"Test dataset saved to: {test_folder}\n")
    
    print('-'*20)

    print(f"Number of images in training set: {train_count}\n")
    print(f"Number of images in validation set: {val_count}\n")
    print(f"Number of images in test set: {test_count}\n")

if __name__ == "__main__":
    # Force UTF-8 console output to avoid encoding issues on some platforms
    sys.stdout.reconfigure(encoding='utf-8')
    sys.stderr.reconfigure(encoding='utf-8')
    main()
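
# Example output layout after running the invocation above with --augment_data
# (hypothetical classes "cats" and "dogs"; a sketch, not guaranteed output):
#   <target_folder>/
#       train/cats/<uuid>.jpg ...   (originals plus 4 augmented copies each)
#       train/dogs/...
#       val/cats/ ...   val/dogs/ ...    (originals only)
#       test/cats/ ...  test/dogs/ ...   (originals only)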