import os
import time
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchattacks import FGSM, PGD, APGD
from torchvision.models import vgg16, vgg19, googlenet, resnet18

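# LeNet-style CNN for MNIST. forward(..., return_all=True) also returns every
# intermediate activation, which the layer-wise analysis below relies on.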
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x, return_all=False):
        outputs = []
        x1 = self.pool(self.relu(self.conv1(x)))
        outputs.append(x1)
        x2 = self.pool(self.relu(self.conv2(x1)))
        outputs.append(x2)
        x2_flat = x2.view(-1, 16 * 4 * 4)
        x3 = self.relu(self.fc1(x2_flat))
        outputs.append(x3)
        x4 = self.relu(self.fc2(x3))
        outputs.append(x4)
        x5 = self.fc3(x4)
        outputs.append(x5)
        if return_all:
            return outputs
        return x5

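# Non-gradient perturbation baselines. Both functions corrupt random pixels and
# clamp to [0, 1]; note that the test transforms below normalize the images, so
# the clamp implicitly assumes inputs in the [0, 1] range.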
def salt_pepper_noise(images, prob=0.01, device='cuda'):
    salt_mask = torch.rand_like(images) < prob / 2
    pepper_mask = torch.rand_like(images) < prob / 2
    noisy = images.clone()
    noisy[salt_mask] = 1.0
    noisy[pepper_mask] = 0.0
    return torch.clamp(noisy, 0, 1)

def pepper_statistical_noise(images, prob=0.01, device='cuda'):
    pepper_mask = torch.rand_like(images) < prob
    noisy = images.clone()
    noisy[pepper_mask] = 0.0
    return torch.clamp(noisy, 0, 1)

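# Fallback for architectures (torchvision/timm models) whose forward() does not
# accept return_all: collect the output of every Conv2d/Linear module with
# forward hooks.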
def get_layer_outputs(model, input_tensor):
    outputs = []

    def hook(module, input, output):
        outputs.append(output)

    hooks = []
    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            hooks.append(layer.register_forward_hook(hook))
    model.eval()
    with torch.no_grad():
        model(input_tensor)
    for h in hooks:
        h.remove()
    return outputs

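# MVL (measure of vulnerability of a layer): the batch-averaged relative L2
# shift ||f_l(x_adv) - f_l(x)|| / ||f_l(x)|| of layer l's output. Higher values
# indicate layers whose representations are more affected by the perturbation.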
def compute_mvl(model, clean_images, adv_images, device='cuda'):
    model.eval()
    with torch.no_grad():
        try:
            clean_outputs = model(clean_images, return_all=True)
            adv_outputs = model(adv_images, return_all=True)
        except TypeError:
            clean_outputs = get_layer_outputs(model, clean_images)
            adv_outputs = get_layer_outputs(model, adv_images)

    mvl_list = []
    for clean_out, adv_out in zip(clean_outputs, adv_outputs):
        if clean_out.ndim == 4:
            diff = torch.norm(clean_out - adv_out, p=2, dim=(1, 2, 3))
            clean_norm = torch.norm(clean_out, p=2, dim=(1, 2, 3))
        else:
            diff = torch.norm(clean_out - adv_out, p=2, dim=1)
            clean_norm = torch.norm(clean_out, p=2, dim=1)
        mvl = diff / (clean_norm + 1e-8)
        mvl_list.append(mvl.mean().item())
    return mvl_list

def get_model_stats(model):
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    layer_count = len([m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))])
    return param_count, layer_count

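# Replace the ImageNet classification head of each pretrained backbone with a
# 10-class layer so the architectures can be analyzed on CIFAR-10.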
def modify_model(model, model_name):
    if model_name.startswith('VGG'):
        model.classifier[6] = nn.Linear(4096, 10)
    elif model_name == 'GoogLeNet':
        model.fc = nn.Linear(1024, 10)
    elif model_name == 'ResNet18':
        model.fc = nn.Linear(512, 10)
    elif model_name == 'WideResNet':
        model.fc = nn.Linear(2048, 10)
    elif model_name == 'DenseNet121':
        model.classifier = nn.Linear(model.classifier.in_features, 10)
    elif model_name == 'MobileNetV2':
        if isinstance(model.classifier, nn.Sequential):
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)
        else:
            model.classifier = nn.Linear(model.classifier.in_features, 10)
    elif model_name == 'EfficientNet-B0':
        model.classifier = nn.Linear(model.classifier.in_features, 10)
    return model

def get_models_for_dataset(dataset_name):
    if dataset_name == 'MNIST':
        return ['LeNet']
    elif dataset_name == 'CIFAR-10':
        return [
            'VGG16', 'VGG19', 'GoogLeNet', 'ResNet18', 'WideResNet',
            'DenseNet121', 'MobileNetV2', 'EfficientNet-B0'
        ]
    else:
        return []

def get_dataset_and_transform(dataset_name):
    if dataset_name == 'MNIST':
        transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    else:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))
        ])
        dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    return dataset, transform

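# Note: torchattacks' gradient-based attacks (FGSM/PGD/APGD) clamp adversarial
# examples to [0, 1], while the transforms above normalize inputs outside that
# range. Depending on the installed torchattacks version,
# atk.set_normalization_used(mean, std) can be used to reconcile the two; it is
# not wired up here.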
def initialize_model(model_name, device):
    if model_name == 'LeNet':
        model = LeNet()
    elif model_name == 'VGG16':
        model = modify_model(vgg16(weights='IMAGENET1K_V1'), model_name)
    elif model_name == 'VGG19':
        model = modify_model(vgg19(weights='IMAGENET1K_V1'), model_name)
    elif model_name == 'GoogLeNet':
        model = modify_model(googlenet(weights='IMAGENET1K_V1'), model_name)
    elif model_name == 'ResNet18':
        model = modify_model(resnet18(weights='IMAGENET1K_V1'), model_name)
    elif model_name == 'WideResNet':
        model = modify_model(timm.create_model('wide_resnet50_2', pretrained=True), model_name)
    elif model_name == 'DenseNet121':
        model = modify_model(timm.create_model('densenet121', pretrained=True), model_name)
    elif model_name == 'MobileNetV2':
        model = modify_model(timm.create_model('mobilenetv2_100', pretrained=True), model_name)
    elif model_name == 'EfficientNet-B0':
        model = modify_model(timm.create_model('efficientnet_b0', pretrained=True), model_name)
    else:
        raise ValueError(f"Unknown model {model_name}")
    return model.to(device)

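# Main LSA experiment driver: runs each selected attack over a limited number of
# test batches, records the per-layer MVL values, and reports the per-attack CM
# (the mean MVL across layers, averaged over batches) alongside the plots.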
def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
    start_time = time.time()
    logs = ["LSA:: experiment started ..."]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logs.append(f"Loading {dataset_name} dataset...")
    dataset, _ = get_dataset_and_transform(dataset_name)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.")

    logs.append(f"Initializing model {model_name} on {device}...")
    model = initialize_model(model_name, device)
    logs.append(f"Model {model_name} initialized.")

    param_count, layer_count = get_model_stats(model)
    logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}")

    all_attacks = {
        'FGSM': FGSM(model, eps=0.03),
        'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
        'APGD': APGD(model, eps=0.03, steps=100, loss='ce'),
        'Salt & Pepper': lambda x, y: salt_pepper_noise(x, prob=0.01, device=device),
        'Pepper Statistical': lambda x, y: pepper_statistical_noise(x, prob=0.01, device=device)
    }
    attacks = {name: attack for name, attack in all_attacks.items() if name in selected_attacks}
    if not attacks:
        logs.append("Error: No valid attacks selected")
        return ["No valid attacks selected", None] + [None] * 6 + ["", '\n'.join(logs)]
    logs.append(f"Selected attacks: {', '.join(attacks.keys())}")

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    logs.append(f"Output directory created: {output_dir}")

    results = {atk: {'cm': [], 'mvl': []} for atk in attacks}

    for i, (images, labels) in enumerate(testloader):
        if i >= num_batches:
            logs.append(f"Reached batch limit: {num_batches}")
            break
        images, labels = images.to(device), labels.to(device)
        logs.append(f"Processing batch {i+1}/{num_batches}...")

        for atk_name, atk in attacks.items():
            logs.append(f"  Running attack: {atk_name} on batch {i+1}")
            adv_images = atk(images, labels)
            mvl_vals = compute_mvl(model, images, adv_images, device)
            results[atk_name]['mvl'].append(mvl_vals)
            batch_cm = np.mean(mvl_vals)
            results[atk_name]['cm'].append(batch_cm)
            logs.append(f"  Attack {atk_name}: batch CM={batch_cm:.6f}")

    logs.append("Finished processing batches, computing statistics...")

    cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
    cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}

    plt.figure(figsize=(8, 6))
    attack_names = list(attacks.keys())
    means = [cm_means[a] for a in attack_names]
    stds = [cm_stds[a] for a in attack_names]
    x = np.arange(len(attack_names))
    plt.bar(x, means, yerr=stds, capsize=5)
    plt.xticks(x, attack_names, rotation=45)
    plt.ylabel("CM (Relative Error)")
    plt.title(f"CM for {model_name} ({dataset_name})")
    plt.tight_layout()
    cm_plot_path = os.path.join(output_dir, "cm_plot.png")
    plt.savefig(cm_plot_path)
    plt.close()
    logs.append(f"Saved CM plot: {cm_plot_path}")

    mvl_plot_paths = []
    colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple']
    for i, atk in enumerate(attack_names):
        mvl_arr = np.array(results[atk]['mvl'])
        mean_vals = np.mean(mvl_arr, axis=0)
        std_vals = np.std(mvl_arr, axis=0)
        layers = [f"Layer {j+1}" for j in range(len(mean_vals))]
        plt.figure(figsize=(8, 6))
        plt.plot(layers, mean_vals, marker='o', color=colors[i % len(colors)])
        plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[i % len(colors)], alpha=0.3)
        plt.title(f"MVL per Layer - {atk}")
        plt.ylabel("MVL (Mean ± Std)")
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.tight_layout()
        path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png")
        plt.savefig(path)
        plt.close()
        mvl_plot_paths.append(path)
        logs.append(f"Saved MVL plot for {atk}: {path}")

    plt.figure(figsize=(10, 6))
    for i, atk in enumerate(attack_names):
        mvl_arr = np.array(results[atk]['mvl'])
        mean_vals = np.mean(mvl_arr, axis=0)
        std_vals = np.std(mvl_arr, axis=0)
        layers = [f"Layer {j+1}" for j in range(len(mean_vals))]
        plt.plot(layers, mean_vals, marker='o', color=colors[i % len(colors)], label=atk)
        plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[i % len(colors)], alpha=0.3)
    plt.title(f"Integrated MVL - {model_name}")
    plt.ylabel("MVL (Mean ± Std)")
    plt.xticks(rotation=45)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    integrated_mvl_plot_path = os.path.join(output_dir, "integrated_mvl.png")
    plt.savefig(integrated_mvl_plot_path)
    plt.close()
    logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")

    processing_time = time.time() - start_time
    logs.append(f"Processing completed in {processing_time:.2f} seconds")

    stats = {
        'Dataset': dataset_name,
        'Model': model_name,
        'Parameters': param_count,
        'Layers': layer_count,
        'Batches': num_batches,
        'Attacks': ', '.join(attack_names),
        'Time (s)': round(processing_time, 2)
    }
    stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n"
    for k, v in stats.items():
        stats_text += f"| {k} | {v} |\n"

    # Pad so the five MVL image slots in the UI always receive a value.
    while len(mvl_plot_paths) < 5:
        mvl_plot_paths.append(None)

    return [
        None,
        cm_plot_path,
        *mvl_plot_paths[:5],
        integrated_mvl_plot_path,
        stats_text,
        '\n'.join(logs)
    ]

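# Paper reference card rendered at the top of the Gradio interface.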
paper_info_html = """
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
<h3>Authors</h3>
<p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>

<h3>Abstract</h3>
<ul>
<li>The layer sustainability analysis (LSA) framework is introduced to evaluate, using Lipschitz theoretical concepts, how the layer-level representations of DNNs behave under input perturbations.</li>
<li>A layer-wise regularized adversarial training (AT-LR) approach significantly improves the generalization and robustness of different deep neural network architectures under significant perturbations while reducing layer-level vulnerabilities.</li>
<li>Intriguingly, the AT-LR loss landscape obtained for each LSA MVL proposal helps interpret the relative importance of different layers.</li>
</ul>

<h3>Links</h3>
<ul>
<li><a href="https://arxiv.org/abs/2202.02626" target="_blank">ArXiv Paper</a></li>
<li><a href="https://github.com/khalooei/LSA" target="_blank">GitHub Repository</a></li>
<li><a href="https://www.sciencedirect.com/science/article/abs/pii/S0925231223002928" target="_blank">ScienceDirect Article</a></li>
</ul>
</div>
"""

def update_models(dataset_name):
    if dataset_name == 'MNIST':
        # MNIST only supports LeNet: hide the dropdown and show the fixed label.
        return gr.update(visible=False), gr.update(value="LeNet", visible=True)
    models = get_models_for_dataset(dataset_name)
    return gr.update(choices=models, value=models[0], visible=True), gr.update(visible=False)

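# Gradio Blocks UI: dataset/model/attack selectors, a batch-count slider, and
# tabbed panels for the CM plot, per-attack MVL plots, the integrated MVL plot,
# model statistics, and processing logs.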
def create_interface():
    datasets = ['MNIST', 'CIFAR-10']
    attacks = ['FGSM', 'PGD', 'APGD', 'Salt & Pepper', 'Pepper Statistical']

    with gr.Blocks() as interface:
        gr.Markdown("# Layer-wise Sustainability Analysis")
        gr.Markdown(paper_info_html)

        initial_input = "MNIST"
        dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=initial_input)
        model_input = gr.Dropdown(get_models_for_dataset(initial_input), value=get_models_for_dataset(initial_input)[0], label="Select Model")
        model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")

        attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
        batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Number of Batches")
        run_button = gr.Button("Run Analysis")

        error_output = gr.Textbox(label="Error", visible=False)
        cm_output = gr.Image(label="Comparative Measure (CM)")

        with gr.Tabs():
            mvl_outputs = []
            for attack in attacks:
                with gr.Tab(f"MVL: {attack}"):
                    mvl_output = gr.Image(label=f"MVL for {attack}")
                    mvl_outputs.append(mvl_output)
            with gr.Tab("Integrated MVL"):
                integrated_mvl_output = gr.Image(label="Integrated MVL for All Attacks")
            with gr.Tab("Model Statistics"):
                stats_output = gr.Markdown("## Model Statistics")
            with gr.Tab("Logs"):
                log_output = gr.Textbox(label="Processing Logs", lines=15, interactive=False)

        dataset_input.change(
            fn=update_models,
            inputs=dataset_input,
            outputs=[model_input, model_text]
        )

        def get_model_for_mnist_or_dropdown(dataset_name, model_name):
            return "LeNet" if dataset_name == 'MNIST' else model_name

        def run_analysis(dataset_name, model_name, attacks, batches):
            real_model = get_model_for_mnist_or_dropdown(dataset_name, model_name)
            return layer_sustainability_analysis(dataset_name, real_model, attacks, batches)

        run_button.click(
            fn=run_analysis,
            inputs=[dataset_input, model_input, attack_input, batch_input],
            outputs=[error_output, cm_output] + mvl_outputs + [integrated_mvl_output, stats_output, log_output]
        )

    return interface

if __name__ == '__main__':
    interface = create_interface()
    interface.launch()