import argparse
import os
import random
import uuid

import tensorflow as tf
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm

# Parses command-line arguments.
def parse_arguments():
    parser = argparse.ArgumentParser(description='Image Data Loader with Augmentation and Splits')
    parser.add_argument('--path', type=str, required=True, help='Path to the folder containing images')
    parser.add_argument('--dim', type=int, default=224, help='Required image dimension')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--target_folder', type=str, required=True, help='Folder to store the train, test, and val splits')
    parser.add_argument('--augment_data', action='store_true', help='Apply data augmentation')
    parser.add_argument('--balance', action='store_true', help='Balance the dataset')
    parser.add_argument('--split_type', type=str, choices=['random', 'stratified'], default='random',
                        help='Type of data split (random or stratified)')
    return parser.parse_args()
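
# Illustrative invocation (the script name "data_loader.py" and the paths are
# assumptions, not part of the original file):
#   python data_loader.py --path ./raw_images --target_folder ./splits \
#       --dim 224 --batch_size 32 --split_type stratified --balance
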
# Loads an image file and returns a float32 tensor in [0, 1] of shape (image_size, image_size, 3).
def process_image(file_path, image_size):
    image = tf.io.read_file(file_path)
    # expand_animations=False keeps the output 3-D even for animated GIFs,
    # so the resize below always sees a single frame.
    image = tf.image.decode_image(image, channels=3, dtype=tf.float32, expand_animations=False)
    image = tf.image.resize(image, [image_size, image_size])
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image
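
# Example (the path is illustrative, not from the original script):
#   img = process_image("raw_images/cats/0001.jpg", 224)  # -> float32, (224, 224, 3)
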
# Balances the images of a specific class by down-sampling or augmenting.
def balance_class_images(image_paths, labels, target_count, image_size, label, label_to_index, output_folder):
    print(f"Balancing class '{label}'...")
    label_idx = label_to_index.get(label)
    if label_idx is None:
        print(f"Label '{label}' not found in label_to_index.")
        return [], []
    image_paths = [img for img, lbl in zip(image_paths, labels) if lbl == label_idx]
    num_images = len(image_paths)
    print(f"Class '{label}' has {num_images} images before balancing.")
    balanced_images = []
    balanced_labels = []
    original_count = num_images
    synthetic_count = 0
    if num_images > target_count:
        # Down-sample: keep a random subset of the existing images.
        balanced_images.extend(random.sample(image_paths, target_count))
        balanced_labels.extend([label_idx] * target_count)
        print(f"Removed {num_images - target_count} images from class '{label}'.")
    elif num_images < target_count:
        # Up-sample: keep everything, then add augmented copies until the target is met.
        balanced_images.extend(image_paths)
        balanced_labels.extend([label_idx] * num_images)
        num_to_add = target_count - num_images
        print(f"Class '{label}' needs {num_to_add} additional images for balancing.")
        while num_to_add > 0:
            img_path = random.choice(image_paths)
            image = process_image(img_path, image_size)
            for _ in range(min(num_to_add, 5)):  # Use up to 5 augmentations per image
                augmented_image = augment_image(image)
                balanced_images.append(augmented_image)
                balanced_labels.append(label_idx)
                num_to_add -= 1
                synthetic_count += 1
        print(f"Added {synthetic_count} augmented images to class '{label}'.")
    else:
        # Already at the target count: keep all images as-is (the original
        # fell through here and silently dropped the class).
        balanced_images.extend(image_paths)
        balanced_labels.extend([label_idx] * num_images)
    print(f"Class '{label}' has {len(balanced_images)} images after balancing.")
    class_folder = os.path.join(output_folder, str(label_idx))
    os.makedirs(class_folder, exist_ok=True)
    for img in balanced_images:
        file_name = f"{uuid.uuid4()}.png"
        file_path = os.path.join(class_folder, file_name)
        save_image(img, file_path, image_size)
    print(f"Saved {len(balanced_images)} images for class '{label}' (Original: {original_count}, Synthetic: {synthetic_count}).")
    return balanced_images, balanced_labels
# Saves an image (a tensor in [0, 1] or a file path) as a PNG file.
def save_image(image, file_path, image_size=224):
    if isinstance(image, str):
        # A path was given: load and preprocess the image first.
        # (The original referenced an undefined `image_size` here; it is now a parameter.)
        image = process_image(image, image_size)
    if isinstance(image, tf.Tensor):
        image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
        image = tf.io.encode_png(image)
    else:
        raise ValueError("Expected image to be a TensorFlow tensor, but got a different type.")
    tf.io.write_file(file_path, image)
# Augments an image with random flips and color jitter.
def augment_image(image):
    # Apply random augmentations using TensorFlow functions
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
    image = tf.image.random_hue(image, max_delta=0.1)
    return image
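
# Note: random_brightness and random_contrast can push values slightly outside
# [0, 1]. Callers that need a strictly bounded range can clip, for example:
#   augmented = tf.clip_by_value(augment_image(image), 0.0, 1.0)
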
# Creates a list of data augmentation functions.
def create_datagens():
    return [augment_image]
# Balances the entire dataset by balancing each class.
def balance_data(images, labels, target_count, image_size, unique_labels, label_to_index, output_folder):
    print(f"Balancing data: Target count per class = {target_count}")
    all_balanced_images = []
    all_balanced_labels = []
    for label in tqdm(unique_labels, desc="Balancing classes"):
        balanced_images, balanced_labels = balance_class_images(
            images, labels, target_count, image_size, label, label_to_index, output_folder
        )
        all_balanced_images.extend(balanced_images)
        all_balanced_labels.extend(balanced_labels)
    # Entries that are paths are original images; tensors are synthetic (augmented) ones.
    total_original_images = sum(1 for img in all_balanced_images if isinstance(img, str))
    total_synthetic_images = len(all_balanced_images) - total_original_images
    print(f"\nTotal saved images: {len(all_balanced_images)} (Original: {total_original_images}, Synthetic: {total_synthetic_images})")
    return all_balanced_images, all_balanced_labels
# Decodes, resizes, normalizes, and augments an image; runs eagerly inside tf.py_function.
# (The original read a module-global `image_size` that was never defined; it is now a parameter.)
def tf_augment_image(file_path, label, image_size):
    image = tf.image.resize(tf.image.decode_jpeg(tf.io.read_file(file_path), channels=3),
                            [image_size, image_size])
    image = tf.cast(image, tf.float32) / 255.0
    augmented_image = augment_image(image)
    return augmented_image, label

# Maps a (file_path, label) pair to an augmented (image, label) pair with static shapes.
def map_fn(file_path, label, image_size):
    image, label = tf.py_function(tf_augment_image, [file_path, label, image_size],
                                  [tf.float32, tf.int32])
    image.set_shape([image_size, image_size, 3])
    label.set_shape([])
    return image, label
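
# A minimal sketch (not part of the original script) of how map_fn could feed a
# tf.data input pipeline; `train_files` and `train_labels` are assumed to come
# from load_and_save_splits below:
#   dataset = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
#   dataset = dataset.map(lambda f, l: map_fn(f, l, 224),
#                         num_parallel_calls=tf.data.AUTOTUNE)
#   dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)
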
# Loads images, splits them into train, validation, and test sets, and saves the splits.
# Note: batch_size and datagens are currently unused here; they are kept for the CLI surface.
def load_and_save_splits(path, image_size, batch_size, balance, datagens, target_folder, split_type):
    all_images = []
    labels = []
    for class_folder in os.listdir(path):
        class_path = os.path.join(path, class_folder)
        if os.path.isdir(class_path):
            for img_file in os.listdir(class_path):
                img_path = os.path.join(class_path, img_file)
                all_images.append(img_path)
                labels.append(class_folder)  # Use the folder name as the label
    print(f"Loaded {len(all_images)} images across {len(set(labels))} classes.")
    print(f"Labels found: {set(labels)}")
    unique_labels = sorted(set(labels))  # Sorted for a reproducible label order
    label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
    encoded_labels = [label_to_index[label] for label in labels]
    print(f"Label to index mapping: {label_to_index}")
    if split_type == 'stratified':
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_indices, test_indices = next(sss.split(all_images, encoded_labels))
    else:  # random split
        total_images = len(all_images)
        indices = list(range(total_images))
        random.shuffle(indices)
        train_indices = indices[:int(0.8 * total_images)]
        test_indices = indices[int(0.8 * total_images):]
    train_files = [all_images[i] for i in train_indices]
    train_labels = [encoded_labels[i] for i in train_indices]
    test_files = [all_images[i] for i in test_indices]
    test_labels = [encoded_labels[i] for i in test_indices]
    # Split the held-out 20% in half to create validation and test sets
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_indices, test_indices = next(sss_val.split(test_files, test_labels))
    val_files = [test_files[i] for i in val_indices]
    val_labels = [test_labels[i] for i in val_indices]
    test_files = [test_files[i] for i in test_indices]
    test_labels = [test_labels[i] for i in test_indices]
    # Optionally balance the training split (the original accepted --balance but
    # never acted on it). balance_data saves its images itself, so the generic
    # save loop below skips "train" in that case.
    if balance:
        train_split_folder = os.path.join(target_folder, "train")
        os.makedirs(train_split_folder, exist_ok=True)
        target_count = max(train_labels.count(label_to_index[l]) for l in unique_labels)
        train_files, train_labels = balance_data(
            train_files, train_labels, target_count, image_size,
            unique_labels, label_to_index, train_split_folder)
    # Save splits
    for split_name, file_list, labels_list in [("train", train_files, train_labels),
                                               ("val", val_files, val_labels),
                                               ("test", test_files, test_labels)]:
        if balance and split_name == "train":
            continue  # Already written by balance_data above
        split_folder = os.path.join(target_folder, split_name)
        os.makedirs(split_folder, exist_ok=True)
        with open(os.path.join(split_folder, f"{split_name}_list.txt"), 'w') as file_list_file:
            for img_path, label in zip(file_list, labels_list):
                label_folder = os.path.join(split_folder, str(label))
                os.makedirs(label_folder, exist_ok=True)
                file_list_file.write(f"{img_path}\n")
                save_image(img_path, os.path.join(label_folder, f"{uuid.uuid4()}.png"), image_size)
    print(f"Saved splits: train: {len(train_files)}, val: {len(val_files)}, test: {len(test_files)}.")
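
# Once saved, each split follows the class-per-subfolder layout Keras expects,
# so a split can be reloaded later with, e.g. (illustrative, not part of the
# original script; the path mirrors the --target_folder example above):
#   train_ds = tf.keras.utils.image_dataset_from_directory(
#       "splits/train", image_size=(224, 224), batch_size=32)
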
# Main function to run the data loader.
def main():
    args = parse_arguments()
    load_and_save_splits(args.path, args.dim, args.batch_size, args.balance,
                         create_datagens(), args.target_folder, args.split_type)

if __name__ == "__main__":
    main()