lsa / app.py
khalooei
update app
b2419d7
# Developed by Mohammad Khalooei
# More information and contact: https://github.com/khalooei/LSA
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg16, vgg19, googlenet, resnet18
import timm
import numpy as np
import matplotlib.pyplot as plt
from torchattacks import FGSM, PGD, APGD
import os
import time
from datetime import datetime
import gradio as gr
class LeNet(nn.Module):
    """Small LeNet-style CNN producing 10 logits.

    The flatten step reshapes to ``16 * 4 * 4`` features, which corresponds to
    1x28x28 inputs (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4).

    ``forward(x, return_all=True)`` returns the intermediate activations
    ``[pool(relu(conv1)), pool(relu(conv2)), relu(fc1), relu(fc2), logits]``
    so layer-wise measures can compare clean and perturbed inputs.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this fixed order; it determines both the
        # state_dict keys and the order in which random init consumes the RNG.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x, return_all=False):
        feat1 = self.pool(self.relu(self.conv1(x)))
        feat2 = self.pool(self.relu(self.conv2(feat1)))
        hidden1 = self.relu(self.fc1(feat2.view(-1, 16 * 4 * 4)))
        hidden2 = self.relu(self.fc2(hidden1))
        logits = self.fc3(hidden2)
        if not return_all:
            return logits
        return [feat1, feat2, hidden1, hidden2, logits]
def salt_pepper_noise(images, prob=0.01, device='cuda'):
    """Return a salt-and-pepper corrupted copy of *images*, clamped to [0, 1].

    Every tensor element is independently set to 1.0 ("salt") with probability
    ``prob / 2`` and to 0.0 ("pepper") with probability ``prob / 2``; where both
    masks hit the same element, pepper wins. The input tensor is not modified.

    ``device`` is accepted for call compatibility but unused: the masks come
    from ``torch.rand_like`` and therefore live on ``images``' own device.
    """
    half_prob = prob / 2
    # Salt mask is drawn before the pepper mask (RNG consumption order matters).
    salt_mask = torch.rand_like(images) < half_prob
    pepper_mask = torch.rand_like(images) < half_prob
    corrupted = images.masked_fill(salt_mask, 1.0).masked_fill(pepper_mask, 0.0)
    return corrupted.clamp(0, 1)
def pepper_statistical_noise(images, prob=0.01, device='cuda'):
    """Return a copy of *images* with elements zeroed with probability *prob*, clamped to [0, 1].

    The input tensor is not modified. ``device`` is accepted for call
    compatibility but unused (the mask is created with ``torch.rand_like``).
    """
    drop_mask = torch.rand_like(images) < prob
    return images.masked_fill(drop_mask, 0.0).clamp(0, 1)
def get_layer_outputs(model, input_tensor):
    """Capture the output of every Conv2d/Linear module during one forward pass.

    A forward hook is attached to each ``nn.Conv2d`` / ``nn.Linear`` found in
    ``model.modules()``. The model is switched to eval mode and run once under
    ``torch.no_grad()``. The hooks are always removed afterwards, even if the
    forward pass raises, so repeated calls never accumulate hooks on the model.

    Args:
        model: the network to probe (left in eval mode).
        input_tensor: input passed straight to ``model``.

    Returns:
        List of captured outputs, in the order the hooked modules fired.
    """
    captured = []

    def _record(module, inputs, output):
        captured.append(output)

    handles = [
        layer.register_forward_hook(_record)
        for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    model.eval()
    try:
        with torch.no_grad():
            model(input_tensor)
    finally:
        # Detach hooks unconditionally; a leaked hook would keep appending to a
        # stale list on every later forward pass of this model.
        for handle in handles:
            handle.remove()
    return captured
def compute_mvl(model, clean_images, adv_images, device='cuda'):
    """Return the per-layer relative representation change (MVL) for a batch.

    Layer outputs come from ``model(x, return_all=True)`` when the model's
    forward accepts that keyword; otherwise (``TypeError``) they come from
    ``get_layer_outputs``, i.e. every Conv2d/Linear output.

    For each layer, clean and perturbed activations are flattened per sample
    (all non-batch dimensions), and
    ``||clean - adv||_2 / (||clean||_2 + 1e-8)`` is averaged over the batch.
    For 2-D and 4-D outputs this equals the previous per-ndim computation;
    outputs of any other rank are now also measured per sample.

    Args:
        model: network under analysis (switched to eval mode).
        clean_images: unperturbed input batch.
        adv_images: perturbed input batch, same shape as ``clean_images``.
        device: unused; kept for backward compatibility.

    Returns:
        list[float]: one MVL value per captured layer.
    """
    model.eval()
    with torch.no_grad():
        try:
            clean_outputs = model(clean_images, return_all=True)
            adv_outputs = model(adv_images, return_all=True)
        except TypeError:
            # Model's forward has no ``return_all`` keyword: fall back to hooks.
            clean_outputs = get_layer_outputs(model, clean_images)
            adv_outputs = get_layer_outputs(model, adv_images)
    mvl_list = []
    for clean_out, adv_out in zip(clean_outputs, adv_outputs):
        clean_flat = clean_out.flatten(start_dim=1)
        adv_flat = adv_out.flatten(start_dim=1)
        diff = torch.norm(clean_flat - adv_flat, p=2, dim=1)
        clean_norm = torch.norm(clean_flat, p=2, dim=1)
        # Epsilon avoids division by zero for all-zero (e.g. dead ReLU) activations.
        mvl = diff / (clean_norm + 1e-8)
        mvl_list.append(mvl.mean().item())
    return mvl_list
def get_model_stats(model):
    """Return ``(trainable_parameter_count, conv_and_linear_module_count)`` for *model*.

    Only parameters with ``requires_grad`` are counted; the layer count covers
    every ``nn.Conv2d`` and ``nn.Linear`` in ``model.modules()``.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    counted_layers = sum(
        1 for module in model.modules() if isinstance(module, (nn.Conv2d, nn.Linear))
    )
    return trainable, counted_layers
def modify_model(model, model_name, num_classes=10):
    """Replace the classification head of a backbone with a *num_classes*-way layer.

    The new ``nn.Linear`` keeps the input width of the layer it replaces (read
    from ``in_features``) instead of a hard-coded size, so every branch follows
    the same rule. The model is modified in place and also returned.
    Unrecognised names are returned unchanged.

    Args:
        model: backbone whose head is replaced.
        model_name: one of the names from ``get_models_for_dataset``
            (any name starting with 'VGG' is handled as VGG).
        num_classes: output width of the new head (default 10).

    Returns:
        The same ``model`` object.
    """
    if model_name.startswith('VGG'):
        # VGG's final classifier Linear sits at index 6 of ``classifier``.
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif model_name in ('GoogLeNet', 'ResNet18', 'WideResNet'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name in ('DenseNet121', 'EfficientNet-B0'):
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif model_name == 'MobileNetV2':
        if isinstance(model.classifier, nn.Sequential):
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        else:
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    return model
def get_models_for_dataset(dataset_name):
    """Return the model names offered for *dataset_name* (a fresh list; empty if unknown)."""
    catalogue = {
        'MNIST': ['LeNet'],
        'CIFAR-10': [
            'VGG16', 'VGG19', 'GoogLeNet', 'ResNet18', 'WideResNet',
            'DenseNet121', 'MobileNetV2', 'EfficientNet-B0',
        ],
    }
    return catalogue.get(dataset_name, [])
def get_dataset_and_transform(dataset_name):
    """Build the evaluation transform and load the matching test split.

    'MNIST' gives the MNIST test set resized to 28x28, single channel, and
    normalized with mean 0.1307 / std 0.3081. Any other name falls back to the
    CIFAR-10 test set resized to 224x224 and normalized with the per-channel
    statistics below (the commonly used ImageNet mean/std). Data is stored
    under ./data and downloaded if missing.

    Returns:
        ``(dataset, transform)``
    """
    if dataset_name != 'MNIST':
        # Every non-MNIST request is served by CIFAR-10 at 224x224.
        pipeline = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ]
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        pipeline = [
            transforms.Resize((28, 28)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        dataset_cls = torchvision.datasets.MNIST
    transform = transforms.Compose(pipeline)
    dataset = dataset_cls(root='./data', train=False, download=True, transform=transform)
    return dataset, transform
def initialize_model(model_name, device):
    """Create the named model, adapt its head to 10 classes, and move it to *device*.

    'LeNet' is built with freshly initialised weights (no checkpoint is loaded
    here). Every other name loads ImageNet-pretrained weights (torchvision or
    timm) and has its classifier replaced via ``modify_model``.

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    # Zero-argument factories so only the requested model is ever constructed.
    builders = {
        'LeNet': lambda: LeNet(),
        'VGG16': lambda: vgg16(weights='IMAGENET1K_V1'),
        'VGG19': lambda: vgg19(weights='IMAGENET1K_V1'),
        'GoogLeNet': lambda: googlenet(weights='IMAGENET1K_V1'),
        'ResNet18': lambda: resnet18(weights='IMAGENET1K_V1'),
        'WideResNet': lambda: timm.create_model('wide_resnet50_2', pretrained=True),
        'DenseNet121': lambda: timm.create_model('densenet121', pretrained=True),
        'MobileNetV2': lambda: timm.create_model('mobilenetv2_100', pretrained=True),
        'EfficientNet-B0': lambda: timm.create_model('efficientnet_b0', pretrained=True),
    }
    if model_name not in builders:
        raise ValueError(f"Unknown model {model_name}")
    network = builders[model_name]()
    if model_name != 'LeNet':
        network = modify_model(network, model_name)
    return network.to(device)
class _PixelSpaceModel(nn.Module):
    """Adapter exposing a normalized-input model on the [0, 1] pixel scale.

    The dataset tensors from ``get_dataset_and_transform`` are normalized, but
    the torchattacks attacks and the noise functions in this file work on, and
    clamp to, the [0, 1] range. Attacks therefore run on this wrapper with
    de-normalized pixels; ``forward`` re-applies the normalization before
    delegating to the wrapped model.
    """

    def __init__(self, model, mean, std):
        super().__init__()
        self.model = model
        # Shaped (1, C, 1, 1) so the statistics broadcast over batch and spatial dims.
        self.register_buffer('mean', torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1))
        self.register_buffer('std', torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1))

    def to_pixels(self, images):
        """Invert the normalization and clamp to [0, 1]."""
        return (images * self.std + self.mean).clamp(0, 1)

    def from_pixels(self, pixels):
        """Apply the normalization the wrapped model expects."""
        return (pixels - self.mean) / self.std

    def forward(self, x):
        return self.model(self.from_pixels(x))


def _normalization_stats(transform):
    """Return ``(mean, std)`` of the first ``transforms.Normalize`` step in *transform*.

    Falls back to identity statistics ``((0.0,), (1.0,))`` when there is none.
    """
    for step in getattr(transform, 'transforms', [transform]):
        if isinstance(step, transforms.Normalize):
            return tuple(step.mean), tuple(step.std)
    return (0.0,), (1.0,)


def _layer_stats(mvl_rows):
    """Return per-layer ``(mean, std)`` arrays from a list of per-batch MVL lists."""
    table = np.array(mvl_rows)
    return np.mean(table, axis=0), np.std(table, axis=0)


def _save_cm_plot(attack_names, cm_means, cm_stds, title, path):
    """Save a bar chart of CM mean +/- std per attack to *path*."""
    fig, ax = plt.subplots(figsize=(8, 6))
    positions = np.arange(len(attack_names))
    ax.bar(positions, [cm_means[a] for a in attack_names],
           yerr=[cm_stds[a] for a in attack_names], capsize=5)
    ax.set_xticks(positions)
    ax.set_xticklabels(attack_names, rotation=45)
    ax.set_ylabel("CM (Relative Error)")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def _plot_mvl_band(ax, mean_vals, std_vals, color, label=None):
    """Draw one MVL curve with a shaded +/- std band on *ax*."""
    layers = [f"Layer {j+1}" for j in range(len(mean_vals))]
    ax.plot(layers, mean_vals, marker='o', color=color, label=label)
    ax.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=color, alpha=0.3)


def _finish_mvl_figure(fig, ax, title, path, legend=False):
    """Apply the shared MVL axis styling, save *fig* to *path*, and close it."""
    ax.set_title(title)
    ax.set_ylabel("MVL (Mean ± Std)")
    ax.tick_params(axis='x', labelrotation=45)
    if legend:
        ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
    """Run the layer sustainability analysis experiment and save its plots.

    For each of the first ``num_batches`` test batches (batch size 32), every
    selected attack/noise is applied. The per-layer MVL between clean and
    perturbed inputs (see ``compute_mvl``) is recorded, and the batch CM is the
    mean of those MVL values. A CM bar chart, one MVL plot per attack, and an
    integrated MVL plot are written to
    ``<output_dir_base>/<model_name>_<timestamp>/``.

    Attacks and noise run on de-normalized [0, 1] pixels through
    ``_PixelSpaceModel``, because they clamp to that range. Their outputs are
    re-normalized before MVL is measured on the original model.

    Args:
        dataset_name: dataset name; anything other than 'MNIST' loads CIFAR-10
            (see ``get_dataset_and_transform``).
        model_name: name accepted by ``initialize_model``.
        selected_attacks: iterable of attack names; unknown names are ignored.
        num_batches: number of test batches to process; coerced to an int >= 1.
        output_dir_base: parent directory for this run's output folder.

    Returns:
        A 10-element list: [error message or None, CM plot path, five MVL plot
        paths ordered FGSM, PGD, APGD, Salt & Pepper, Pepper Statistical (None
        for attacks that were not selected), integrated MVL plot path,
        markdown statistics table, newline-joined log text]. When no known
        attack is selected, the error message is set, every plot slot is None,
        and the statistics text is empty.
    """
    start_time = time.time()
    # Gradio sliders may deliver floats; the batch counter and stats want an int.
    num_batches = max(1, int(num_batches))
    selected_attacks = selected_attacks or ()
    logs = ["BSM:: experiment is being started ..."]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logs.append(f"Loading {dataset_name} dataset...")
    dataset, transform = get_dataset_and_transform(dataset_name)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.")

    logs.append(f"Initializing model {model_name} on {device}...")
    model = initialize_model(model_name, device)
    logs.append(f"Model {model_name} initialized.")
    param_count, layer_count = get_model_stats(model)
    logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}")

    mean, std = _normalization_stats(transform)
    pixel_model = _PixelSpaceModel(model, mean, std).to(device)
    # Key order is the canonical attack order used for output slots and colors.
    all_attacks = {
        'FGSM': FGSM(pixel_model, eps=0.03),
        'PGD': PGD(pixel_model, eps=0.03, alpha=0.01, steps=40, random_start=True),
        'APGD': APGD(pixel_model, eps=0.03, steps=100, loss='ce'),
        'Salt & Pepper': lambda x, y: salt_pepper_noise(x, prob=0.01, device=device),
        'Pepper Statistical': lambda x, y: pepper_statistical_noise(x, prob=0.01, device=device),
    }
    attacks = {name: attack for name, attack in all_attacks.items() if name in selected_attacks}
    if not attacks:
        logs.append("Error: No valid attacks selected")
        return ["No valid attacks selected", None] + [None] * 6 + ["", '\n'.join(logs)]
    logs.append(f"Selected attacks: {', '.join(attacks.keys())}")

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    logs.append(f"Output directory created: {output_dir}")

    results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
    for i, (images, labels) in enumerate(testloader):
        if i >= num_batches:
            logs.append(f"Reached batch limit: {num_batches}")
            break
        images, labels = images.to(device), labels.to(device)
        pixels = pixel_model.to_pixels(images)
        logs.append(f"Processing batch {i+1}/{num_batches}...")
        for atk_name, atk in attacks.items():
            logs.append(f" Running attack: {atk_name} on batch {i+1}")
            # Perturb in pixel space, then measure MVL on the model's own input scale.
            adv_images = pixel_model.from_pixels(atk(pixels, labels))
            mvl_vals = compute_mvl(model, images, adv_images, device)
            results[atk_name]['mvl'].append(mvl_vals)
            batch_cm = np.mean(mvl_vals)
            results[atk_name]['cm'].append(batch_cm)
            logs.append(f" Attack {atk_name}: batch CM={batch_cm:.6f}")

    logs.append("Finished processing batches, computing statistics...")
    attack_names = list(attacks)
    cm_means = {atk: np.mean(results[atk]['cm']) for atk in attack_names}
    cm_stds = {atk: np.std(results[atk]['cm']) for atk in attack_names}
    cm_plot_path = os.path.join(output_dir, "cm_plot.png")
    _save_cm_plot(attack_names, cm_means, cm_stds, f"CM for {model_name} ({dataset_name})", cm_plot_path)
    logs.append(f"Saved CM plot: {cm_plot_path}")

    colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple']
    canonical_order = list(all_attacks)
    # Color by canonical position so an attack keeps the same color across runs.
    color_of = {atk: colors[canonical_order.index(atk) % len(colors)] for atk in attack_names}
    layer_stats = {atk: _layer_stats(results[atk]['mvl']) for atk in attack_names}

    mvl_paths_by_attack = {}
    for atk in attack_names:
        mean_vals, std_vals = layer_stats[atk]
        fig, ax = plt.subplots(figsize=(8, 6))
        _plot_mvl_band(ax, mean_vals, std_vals, color_of[atk])
        path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png")
        _finish_mvl_figure(fig, ax, f"MVL per Layer - {atk}", path)
        mvl_paths_by_attack[atk] = path
        logs.append(f"Saved MVL plot for {atk}: {path}")

    fig, ax = plt.subplots(figsize=(10, 6))
    for atk in attack_names:
        mean_vals, std_vals = layer_stats[atk]
        _plot_mvl_band(ax, mean_vals, std_vals, color_of[atk], label=atk)
    integrated_mvl_plot_path = os.path.join(output_dir, "integrated_mvl.png")
    _finish_mvl_figure(fig, ax, f"Integrated MVL - {model_name}", integrated_mvl_plot_path, legend=True)
    logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")

    processing_time = time.time() - start_time
    logs.append(f"Processing completed in {processing_time:.2f} seconds")
    stats = {
        'Dataset': dataset_name,
        'Model': model_name,
        'Parameters': param_count,
        'Layers': layer_count,
        'Batches': num_batches,
        'Attacks': ', '.join(attack_names),
        'Time (s)': round(processing_time, 2),
    }
    stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n" + ''.join(
        f"| {k} | {v} |\n" for k, v in stats.items()
    )

    # One slot per attack in canonical order, so each plot lands in its own tab
    # even when only some attacks were selected.
    mvl_plot_paths = [mvl_paths_by_attack.get(name) for name in all_attacks]
    return [
        None,
        cm_plot_path,
        *mvl_plot_paths,
        integrated_mvl_plot_path,
        stats_text,
        '\n'.join(logs),
    ]
# HTML summary of the LSA paper (title, authors, abstract bullets, links).
# create_interface renders it near the top of the UI via gr.Markdown.
paper_info_html = """
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
<h3>Authors</h3>
<p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>
<h3>Abstract</h3>
<ul>
<li>The layer sustainability analysis (LSA) framework is introduced to evaluate the behavior of layer-level representations of DNNs in dealing with network input perturbations using Lipschitz theoretical concepts.</li>
<li>A layer-wise regularized adversarial training (AT-LR) approach significantly improves the generalization and robustness of different deep neural network architectures for significant perturbations while reducing layer-level vulnerabilities.</li>
<li>AT-LR loss landscapes for each LSA MVL proposal can interpret layer importance for different layers, which is an intriguing aspect.</li>
</ul>
<h3>Links</h3>
<ul>
<li><a href="https://arxiv.org/abs/2202.02626" target="_blank">ArXiv Paper</a></li>
<li><a href="https://github.com/khalooei/LSA" target="_blank">GitHub Repository</a></li>
<li><a href="https://www.sciencedirect.com/science/article/abs/pii/S0925231223002928" target="_blank">ScienceDirect Article</a></li>
</ul>
</div>
"""
def update_models(dataset_name):
    """Return UI updates for ``(model dropdown, fixed "Model" textbox)`` after a dataset change.

    MNIST only offers LeNet, so the dropdown is hidden and the read-only
    textbox showing "LeNet" is made visible. Previously only its value was
    returned, so the hidden textbox never appeared. For other datasets the
    dropdown is repopulated from ``get_models_for_dataset`` and shown, and the
    textbox is hidden.
    """
    if dataset_name == 'MNIST':
        return gr.update(visible=False), gr.update(value="LeNet", visible=True)
    models = get_models_for_dataset(dataset_name)
    # Guard against an unknown dataset with no models instead of raising IndexError.
    first_model = models[0] if models else None
    return gr.update(choices=models, value=first_model, visible=True), gr.update(visible=False)
def create_interface():
    """Build and return the Gradio Blocks UI for the layer sustainability analysis.

    The UI contains dataset/model/attack/batch inputs, a hidden error box, a
    CM image, one MVL tab per attack, an integrated MVL tab, a statistics tab,
    and a log tab. "Run Analysis" calls ``layer_sustainability_analysis``,
    always with LeNet for MNIST. The error box is revealed only when the
    analysis returns an error message; previously it stayed hidden, so the
    message was never visible.
    """
    datasets = ['MNIST', 'CIFAR-10']
    attacks = ['FGSM', 'PGD', 'APGD', 'Salt & Pepper', 'Pepper Statistical']
    with gr.Blocks() as interface:
        gr.Markdown("# Layer-wise Sustainability Analysis")
        gr.Markdown(paper_info_html)
        initial_input = "MNIST"
        initial_models = get_models_for_dataset(initial_input)
        starts_on_mnist = initial_input == 'MNIST'
        dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=initial_input)
        # MNIST only offers LeNet, so the read-only textbox is shown initially
        # instead of the single-choice dropdown.
        model_input = gr.Dropdown(initial_models, value=initial_models[0], label="Select Model",
                                  visible=not starts_on_mnist)
        model_text = gr.Textbox(value="LeNet", visible=starts_on_mnist, interactive=False, label="Model")
        attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
        batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Number of Batches")
        run_button = gr.Button("Run Analysis")
        error_output = gr.Textbox(label="Error", visible=False)
        cm_output = gr.Image(label="Comparative Measure (CM)")
        with gr.Tabs():
            # One tab per attack, in the same order as ``attacks``; the analysis
            # output slots are wired to these tabs positionally.
            mvl_outputs = []
            for attack in attacks:
                with gr.Tab(f"MVL: {attack}"):
                    mvl_outputs.append(gr.Image(label=f"MVL for {attack}"))
            with gr.Tab("Integrated MVL"):
                integrated_mvl_output = gr.Image(label="Integrated MVL for All Attacks")
            with gr.Tab("Model Statistics"):
                stats_output = gr.Markdown("## Model Statistics")
            with gr.Tab("Logs"):
                log_output = gr.Textbox(label="Processing Logs", lines=15, interactive=False)

        dataset_input.change(
            fn=update_models,
            inputs=dataset_input,
            outputs=[model_input, model_text],
        )

        def get_model_for_mnist_or_dropdown(dataset_name, model_name):
            return "LeNet" if dataset_name == 'MNIST' else model_name

        def run_analysis(dataset_name, model_name, selected_attacks, batches):
            real_model = get_model_for_mnist_or_dropdown(dataset_name, model_name)
            outputs = list(layer_sustainability_analysis(dataset_name, real_model, selected_attacks, batches))
            # Slot 0 is the error message (None on success): toggle the box's visibility with it.
            error_message = outputs[0]
            outputs[0] = gr.update(value=error_message or "", visible=bool(error_message))
            return outputs

        run_button.click(
            fn=run_analysis,
            inputs=[dataset_input, model_input, attack_input, batch_input],
            outputs=[error_output, cm_output] + mvl_outputs + [integrated_mvl_output, stats_output, log_output],
        )
    return interface
if __name__ == '__main__':
    # Script entry point: build the Gradio app and start its web server.
    create_interface().launch()