# SONAR-Image-Classifier / classify_image_and_explain.py
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img, load_img
from lime.lime_image import LimeImageExplainer
from lime.wrappers.scikit_image import SegmentationAlgorithm
import matplotlib.pyplot as plt
from PIL import Image
import argparse
import shap
import cv2
import pickle
image_counter = 0
temp_folder = "temp_data"
output_folder = "explanations"
# Load the model and extract relevant details
def load_model_details(model_path):
    if model_path.endswith('.keras'):
        print("Loading .keras format model...")
        model = tf.keras.models.load_model(model_path, compile=False)
    elif model_path.endswith('.h5'):
        print("Loading .h5 format model...")
        model = tf.keras.models.load_model(model_path, compile=False)
    else:
        print("Loading SavedModel using TFSMLayer...")
        model = tf.keras.Sequential([
            tf.keras.layers.TFSMLayer(model_path, call_endpoint='serving_default')
        ])
    input_shape = model.input_shape[1:3]
    last_conv_layer_name = None
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            last_conv_layer_name = layer.name
            break
    print(f"Model loaded with input shape: {input_shape} and last conv layer: {last_conv_layer_name}")
    return model, last_conv_layer_name, input_shape
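# Example usage (a minimal sketch; "models/sonar_cnn.h5" is a hypothetical path):
#   model, last_conv, input_shape = load_model_details("models/sonar_cnn.h5")
#   # input_shape is the (height, width) the model expects, e.g. (224, 224);
#   # last_conv feeds the Grad-CAM step below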
# Load the label encoder based on the training directory
def load_label_encoder(train_directory):
    labels = sorted(os.listdir(train_directory))
    label_encoder = {i: label for i, label in enumerate(labels)}
    print(f"Label encoder created: {label_encoder}")
    return label_encoder
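# Example usage (assumes a hypothetical data/train/ directory with one
# subfolder per class, e.g. data/train/mine/ and data/train/rock/):
#   label_encoder = load_label_encoder("data/train")
#   # -> {0: 'mine', 1: 'rock'}; keys match the model's output indices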
def load_and_preprocess_image(filename, image_size):
    # Load and preprocess the image for model input
    print(f"Loading and preprocessing image from: {filename}")
    image = tf.io.read_file(filename)
    image = tf.image.decode_image(image, channels=3)
    if not tf.executing_eagerly():
        image.set_shape([None, None, 3])
    image = tf.image.resize(image, [image_size[0], image_size[1]])
    image = image / 255.0
    image.set_shape([image_size[0], image_size[1], 3])
    return image
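# Example usage (hypothetical file path):
#   img = load_and_preprocess_image("data/train/rock/img_001.png", (224, 224))
#   # img is a float32 tensor of shape (224, 224, 3) scaled to [0, 1]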
# Create a dataset from the training directory
def create_dataset(data_dir, labels, image_size, batch_size):
    print(f"Creating dataset from directory: {data_dir}")
    image_files = []
    image_labels = []
    for label in labels:
        label_dir = os.path.join(data_dir, label)
        for image_file in os.listdir(label_dir):
            image_files.append(os.path.join(label_dir, image_file))
            image_labels.append(label)
    label_map = {label: idx for idx, label in enumerate(labels)}
    image_labels = [label_map[label] for label in image_labels]
    dataset = tf.data.Dataset.from_tensor_slices((image_files, image_labels))
    dataset = dataset.map(lambda x, y: (load_and_preprocess_image(x, image_size), y))
    dataset = dataset.shuffle(buffer_size=len(image_files))
    dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
    print("Dataset created and batched")
    return dataset
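# Example usage (hypothetical directory and class names):
#   ds = create_dataset("data/train", ["mine", "rock"], (224, 224), batch_size=32)
#   for images, y in ds.take(1):
#       print(images.shape)  # (32, 224, 224, 3)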
# Save preprocessed data (images and labels) to a file
def save_preprocessed_data(X_train, y_train, file_path):
    print(f"Saving preprocessed data to: {file_path}")
    with open(file_path, 'wb') as file:
        pickle.dump((X_train, y_train), file)
def load_preprocessed_data(file_path):
    print(f"Loading preprocessed data from: {file_path}")
    with open(file_path, 'rb') as file:
        return pickle.load(file)
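# Example round trip (X_train/y_train are hypothetical arrays; the path reuses
# the temp_data folder defined above):
#   save_preprocessed_data(X_train, y_train, os.path.join(temp_folder, "train.pkl"))
#   X_train, y_train = load_preprocessed_data(os.path.join(temp_folder, "train.pkl"))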
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    # Generate a Grad-CAM heatmap for the given image and model
    grad_model = tf.keras.models.Model(
        inputs=model.inputs, outputs=[model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        preds = tf.convert_to_tensor(preds)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])  # Default to the class with the highest probability
        class_channel = preds[:, pred_index]
    grads = tape.gradient(class_channel, last_conv_layer_output)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()
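# Example usage (assumes `model`, a normalized batch `img_array` of shape
# (1, H, W, 3), and the conv layer name from load_model_details):
#   heatmap = make_gradcam_heatmap(img_array, model, last_conv)
#   # heatmap is a 2-D array in [0, 1] at the conv layer's spatial resolution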
def save_and_display_gradcam(array, heatmap, alpha=0.8):
    # Save and display the Grad-CAM heatmap overlaid on the original image
    print("Saving and displaying Grad-CAM result...")
    heatmap = np.uint8(255 * heatmap)
    jet = plt.cm.jet
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]
    jet_heatmap = array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((array.shape[1], array.shape[0]))
    jet_heatmap = img_to_array(jet_heatmap)
    superimposed_img = jet_heatmap * alpha + array
    superimposed_img = array_to_img(superimposed_img)
    return superimposed_img
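# Example usage (continues the Grad-CAM sketch above; `array` is the
# unnormalized image array from img_to_array):
#   overlay = save_and_display_gradcam(array, heatmap, alpha=0.8)
#   overlay.save("explanations/gradcam_demo.png")  # hypothetical output path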
def generate_splime_mask_top_n(img_array, model, explainer, top_n=1, num_features=100, num_samples=300,
                               segmentation_alg='quickshift', kernel_size=4, max_dist=200, ratio=0.2):
    # Generate a SP-LIME mask for the given image and model,
    # using superpixel segmentation (the keyword arguments suit quickshift)
    segmentation_fn = SegmentationAlgorithm(segmentation_alg, kernel_size=kernel_size, max_dist=max_dist, ratio=ratio)
    explanation_instance = explainer.explain_instance(
        img_array, model.predict, top_labels=top_n, hide_color=0,
        num_samples=num_samples, num_features=num_features, segmentation_fn=segmentation_fn
    )
    explanation_mask = explanation_instance.get_image_and_mask(
        explanation_instance.top_labels[0], positive_only=False,
        num_features=num_features, hide_rest=True
    )[1]
    # Build a mask with the same shape as the input image
    mask = np.zeros_like(img_array)
    mask[explanation_mask == 1] = img_array[explanation_mask == 1]  # Keep highlighted regions
    # Set non-highlighted areas to white
    mask = np.where(explanation_mask[:, :, np.newaxis] == 1, mask, 1.0)
    return mask, explanation_instance
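# Example usage (assumes `model` and a single normalized image img_array[0]):
#   explainer = LimeImageExplainer()
#   mask, _ = generate_splime_mask_top_n(img_array[0], model, explainer,
#                                        num_features=100, num_samples=300)
#   plt.imsave("explanations/splime_demo.png", np.clip(mask, 0, 1))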
def explain_image_shap(img, model, class_names, top_prediction, max_evals=1000, batch_size=50):
    # Generate SHAP explanations for the given image and model
    masker = shap.maskers.Image("inpaint_telea", img[0].shape)
    # Wrap the model so SHAP can query class probabilities
    def f(X):
        return model.predict(X)
    # Create the SHAP explainer and compute values for the top-ranked output
    explainer = shap.Explainer(f, masker, output_names=class_names)
    shap_values = explainer(img, max_evals=max_evals, batch_size=batch_size, outputs=shap.Explanation.argsort.flip[:1])
    return shap_values
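# Example usage (assumes `img` of shape (1, H, W, 3) in [0, 1] and the class
# names from load_label_encoder):
#   shap_values = explain_image_shap(img, model, labels, top_prediction,
#                                    max_evals=400, batch_size=50)
#   shap.image_plot(shap_values[0], img[0], show=False)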
def classify_image_and_explain(image_path, model_path, train_directory, num_samples, num_features,
                               segmentation_alg, kernel_size, max_dist, ratio, max_evals, batch_size,
                               explainer_types, output_folder):
    # Main entry point: classify the image and generate the requested explanations
    global image_counter
    if output_folder is None:
        output_folder = "explanations"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    model, last_conv_layer_name, input_shape = load_model_details(model_path)
    label_encoder = load_label_encoder(train_directory)
    labels = list(label_encoder.values())
    # Load the image
    image = load_img(image_path, target_size=input_shape)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    array = img_to_array(image)
    img_array = array / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    # Predict the class of the image
    predictions = model.predict(img_array)
    top_prediction = np.argmax(predictions[0])
    top_label = label_encoder[top_prediction]
    print(f"Prediction: {top_label} with probability {predictions[0][top_prediction]:.4f}")
    # Generate explanations based on the user-specified types
    if 'gradcam' in explainer_types:
        model.layers[-1].activation = None  # Grad-CAM uses pre-softmax scores
        heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)
        gradcam_image = save_and_display_gradcam(img_to_array(image), heatmap)
        gradcam_image.save(os.path.join(output_folder, f"gradcam_{image_counter}.png"))
    if 'lime' in explainer_types:
        # SP-LIME explanation
        explainer = LimeImageExplainer()
        splime_mask, explanation_instance = generate_splime_mask_top_n(
            img_array[0], model, explainer, top_n=1, num_features=num_features, num_samples=num_samples,
            segmentation_alg=segmentation_alg, kernel_size=kernel_size, max_dist=max_dist, ratio=ratio
        )
        # Ensure splime_mask is in [0, 1] range before saving
        splime_mask = np.clip(splime_mask, 0, 1)
        plt.imsave(os.path.join(output_folder, f"splime_{image_counter}.png"), splime_mask)
    if 'shap' in explainer_types:
        custom_image = img_to_array(image) / 255.0  # Preprocess image for SHAP
        shap_values = explain_image_shap(custom_image.reshape(1, *custom_image.shape), model, labels,
                                         top_prediction, max_evals=max_evals, batch_size=batch_size)
        shap.image_plot(shap_values[0], custom_image, labels=[top_label], show=False)
        plt.savefig(os.path.join(output_folder, f"shap_{image_counter}.png"))
        plt.close()
    print("Image classification and explanation process completed.")
    image_counter += 1
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image classification and explanation script")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the input image")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
    parser.add_argument("--train_directory", type=str, required=True, help="Directory containing training images")
    parser.add_argument("--num_samples", type=int, default=300, help="Number of samples for LIME")
    parser.add_argument("--num_features", type=int, default=100, help="Number of features for LIME")
    parser.add_argument("--segmentation_alg", type=str, default='quickshift', help="Segmentation algorithm for LIME (options: quickshift, slic)")
    parser.add_argument("--kernel_size", type=int, default=4, help="Kernel size for the segmentation algorithm")
    parser.add_argument("--max_dist", type=int, default=200, help="Max distance for the segmentation algorithm")
    parser.add_argument("--ratio", type=float, default=0.2, help="Ratio for the segmentation algorithm")
    parser.add_argument("--max_evals", type=int, default=400, help="Maximum evaluations for SHAP")
    parser.add_argument("--batch_size", type=int, default=50, help="Batch size for SHAP")
    parser.add_argument("--explainer_types", type=str, default='all', help="Comma-separated list of explainers to use (options: lime, shap, gradcam). Use 'all' to include all three.")
    parser.add_argument("--output_folder", type=str, default=None, help="Output folder for explanations")
    args = parser.parse_args()
    explainer_types = args.explainer_types.split(',') if args.explainer_types != 'all' else ['lime', 'shap', 'gradcam']
    classify_image_and_explain(
        args.image_path, args.model_path, args.train_directory, args.num_samples,
        args.num_features, args.segmentation_alg, args.kernel_size, args.max_dist,
        args.ratio, args.max_evals, args.batch_size, explainer_types, args.output_folder
    )
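# Example invocation (hypothetical paths; pick explainers as needed):
#   python classify_image_and_explain.py --image_path sample.png \
#       --model_path models/sonar_cnn.h5 --train_directory data/train \
#       --explainer_types lime,gradcam --output_folder explanations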