# Developed by Mohammad Khalooei # More information and contact: https://github.com/khalooei/LSA import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms from torchvision.models import vgg16, vgg19, googlenet, resnet18 import timm import numpy as np import matplotlib.pyplot as plt from torchattacks import FGSM, PGD, APGD import os import time from datetime import datetime import gradio as gr class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(1, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 4 * 4, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) self.relu = nn.ReLU() self.pool = nn.MaxPool2d(2, 2) def forward(self, x, return_all=False): outputs = [] x1 = self.pool(self.relu(self.conv1(x))) outputs.append(x1) x2 = self.pool(self.relu(self.conv2(x1))) outputs.append(x2) x2_flat = x2.view(-1, 16 * 4 * 4) x3 = self.relu(self.fc1(x2_flat)) outputs.append(x3) x4 = self.relu(self.fc2(x3)) outputs.append(x4) x5 = self.fc3(x4) outputs.append(x5) if return_all: return outputs else: return x5 def salt_pepper_noise(images, prob=0.01, device='cuda'): batch_smap = torch.rand_like(images) < prob / 2 pepper = torch.rand_like(images) < prob / 2 noisy = images.clone() noisy[batch_smap] = 1.0 noisy[pepper] = 0.0 return torch.clamp(noisy, 0, 1) def pepper_statistical_noise(images, prob=0.01, device='cuda'): pepper = torch.rand_like(images) < prob noisy = images.clone() noisy[pepper] = 0.0 return torch.clamp(noisy, 0, 1) def get_layer_outputs(model, input_tensor): outputs = [] def hook(module, input, output): outputs.append(output) hooks = [] for layer in model.modules(): if isinstance(layer, (nn.Conv2d, nn.Linear)): hooks.append(layer.register_forward_hook(hook)) model.eval() with torch.no_grad(): model(input_tensor) for hook in hooks: hook.remove() return outputs def compute_mvl(model, clean_images, adv_images, device='cuda'): model.eval() with torch.no_grad(): try: clean_outputs = model(clean_images, return_all=True) adv_outputs = model(adv_images, return_all=True) except TypeError: clean_outputs = get_layer_outputs(model, clean_images) adv_outputs = get_layer_outputs(model, adv_images) mvl_list = [] for clean_out, adv_out in zip(clean_outputs, adv_outputs): if clean_out.ndim == 4: diff = torch.norm(clean_out - adv_out, p=2, dim=(1,2,3)) clean_norm = torch.norm(clean_out, p=2, dim=(1,2,3)) else: diff = torch.norm(clean_out - adv_out, p=2, dim=1) clean_norm = torch.norm(clean_out, p=2, dim=1) mvl = diff / (clean_norm + 1e-8) mvl_list.append(mvl.mean().item()) return mvl_list def get_model_stats(model): param_count = sum(p.numel() for p in model.parameters() if p.requires_grad) layer_count = len([m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]) return param_count, layer_count def modify_model(model, model_name): if model_name.startswith('VGG'): model.classifier[6] = nn.Linear(4096, 10) elif model_name == 'GoogLeNet': model.fc = nn.Linear(1024, 10) elif model_name == 'ResNet18': model.fc = nn.Linear(512, 10) elif model_name == 'WideResNet': model.fc = nn.Linear(2048, 10) elif model_name == 'DenseNet121': model.classifier = nn.Linear(model.classifier.in_features, 10) elif model_name == 'MobileNetV2': if isinstance(model.classifier, nn.Sequential): model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10) else: model.classifier = nn.Linear(model.classifier.in_features, 10) elif model_name == 'EfficientNet-B0': model.classifier = nn.Linear(model.classifier.in_features, 10) return model def get_models_for_dataset(dataset_name): if dataset_name == 'MNIST': return ['LeNet'] elif dataset_name == 'CIFAR-10': return [ 'VGG16', 'VGG19', 'GoogLeNet', 'ResNet18', 'WideResNet', 'DenseNet121', 'MobileNetV2', 'EfficientNet-B0' ] else: return [] def get_dataset_and_transform(dataset_name): if dataset_name == 'MNIST': transform = transforms.Compose([ transforms.Resize((28, 28)), transforms.Grayscale(num_output_channels=1), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) else: transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) return dataset, transform def initialize_model(model_name, device): if model_name == 'LeNet': model = LeNet() elif model_name == 'VGG16': model = modify_model(vgg16(weights='IMAGENET1K_V1'), model_name) elif model_name == 'VGG19': model = modify_model(vgg19(weights='IMAGENET1K_V1'), model_name) elif model_name == 'GoogLeNet': model = modify_model(googlenet(weights='IMAGENET1K_V1'), model_name) elif model_name == 'ResNet18': model = modify_model(resnet18(weights='IMAGENET1K_V1'), model_name) elif model_name == 'WideResNet': model = modify_model(timm.create_model('wide_resnet50_2', pretrained=True), model_name) elif model_name == 'DenseNet121': model = modify_model(timm.create_model('densenet121', pretrained=True), model_name) elif model_name == 'MobileNetV2': model = modify_model(timm.create_model('mobilenetv2_100', pretrained=True), model_name) elif model_name == 'EfficientNet-B0': model = modify_model(timm.create_model('efficientnet_b0', pretrained=True), model_name) else: raise ValueError(f"Unknown model {model_name}") return model.to(device) def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'): start_time = time.time() logs = ["BSM:: experiment is being started ..."] device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logs.append(f"Loading {dataset_name} dataset...") dataset, _ = get_dataset_and_transform(dataset_name) testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False) logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.") logs.append(f"Initializing model {model_name} on {device}...") model = initialize_model(model_name, device) logs.append(f"Model {model_name} initialized.") param_count, layer_count = get_model_stats(model) logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}") all_attacks = { 'FGSM': FGSM(model, eps=0.03), 'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True), 'APGD': APGD(model, eps=0.03, steps=100, loss='ce'), 'Salt & Pepper': lambda x, y: salt_pepper_noise(x, prob=0.01, device=device), 'Pepper Statistical': lambda x, y: pepper_statistical_noise(x, prob=0.01, device=device) } attacks = {name: attack for name, attack in all_attacks.items() if name in selected_attacks} if not attacks: logs.append("Error: No valid attacks selected") return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)] logs.append(f"Selected attacks: {', '.join(attacks.keys())}") timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}") os.makedirs(output_dir, exist_ok=True) logs.append(f"Output directory created: {output_dir}") results = {atk: {'cm': [], 'mvl': []} for atk in attacks} for i, (images, labels) in enumerate(testloader): if i >= num_batches: logs.append(f"Reached batch limit: {num_batches}") break images, labels = images.to(device), labels.to(device) logs.append(f"Processing batch {i+1}/{num_batches}...") for atk_name, atk in attacks.items(): logs.append(f" Running attack: {atk_name} on batch {i+1}") adv_images = atk(images, labels) mvl_vals = compute_mvl(model, images, adv_images, device) results[atk_name]['mvl'].append(mvl_vals) batch_cm = np.mean(mvl_vals) results[atk_name]['cm'].append(batch_cm) logs.append(f" Attack {atk_name}: batch CM={batch_cm:.6f}") logs.append("Finished processing batches, computing statistics...") cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks} cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks} plt.figure(figsize=(8,6)) attack_names = list(attacks.keys()) means = [cm_means[a] for a in attack_names] stds = [cm_stds[a] for a in attack_names] x = np.arange(len(attack_names)) plt.bar(x, means, yerr=stds, capsize=5) plt.xticks(x, attack_names, rotation=45) plt.ylabel("CM (Relative Error)") plt.title(f"CM for {model_name} ({dataset_name})") plt.tight_layout() cm_plot_path = os.path.join(output_dir, "cm_plot.png") plt.savefig(cm_plot_path) plt.close() logs.append(f"Saved CM plot: {cm_plot_path}") mvl_plot_paths = [] colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple'] for i, atk in enumerate(attack_names): mvl_arr = np.array(results[atk]['mvl']) mean_vals = np.mean(mvl_arr, axis=0) std_vals = np.std(mvl_arr, axis=0) layers = [f"Layer {j+1}" for j in range(len(mean_vals))] plt.figure(figsize=(8,6)) plt.plot(layers, mean_vals, marker='o', color=colors[i % len(colors)]) plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[i % len(colors)], alpha=0.3) plt.title(f"MVL per Layer - {atk}") plt.ylabel("MVL (Mean ± Std)") plt.xticks(rotation=45) plt.grid(True) plt.tight_layout() path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png") plt.savefig(path) plt.close() mvl_plot_paths.append(path) logs.append(f"Saved MVL plot for {atk}: {path}") plt.figure(figsize=(10,6)) for i, atk in enumerate(attack_names): mvl_arr = np.array(results[atk]['mvl']) mean_vals = np.mean(mvl_arr, axis=0) std_vals = np.std(mvl_arr, axis=0) layers = [f"Layer {j+1}" for j in range(len(mean_vals))] plt.plot(layers, mean_vals, marker='o', color=colors[i % len(colors)], label=atk) plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[i % len(colors)], alpha=0.3) plt.title(f"Integrated MVL - {model_name}") plt.ylabel("MVL (Mean ± Std)") plt.xticks(rotation=45) plt.legend() plt.grid(True) plt.tight_layout() integrated_mvl_plot_path = os.path.join(output_dir, "integrated_mvl.png") plt.savefig(integrated_mvl_plot_path) plt.close() logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}") processing_time = time.time() - start_time logs.append(f"Processing completed in {processing_time:.2f} seconds") stats = { 'Dataset': dataset_name, 'Model': model_name, 'Parameters': param_count, 'Layers': layer_count, 'Batches': num_batches, 'Attacks': ', '.join(attack_names), 'Time (s)': round(processing_time, 2) } stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n" for k,v in stats.items(): stats_text += f"| {k} | {v} |\n" while len(mvl_plot_paths) < 5: mvl_plot_paths.append(None) return [ None, cm_plot_path, *mvl_plot_paths[:5], integrated_mvl_plot_path, stats_text, '\n'.join(logs) ] paper_info_html = """
Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani