khalooei
commited on
Commit
·
4bd1b68
1
Parent(s):
7d45691
update app
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import time
|
|
12 |
from datetime import datetime
|
13 |
import gradio as gr
|
14 |
|
|
|
15 |
class LeNet(nn.Module):
|
16 |
def __init__(self):
|
17 |
super(LeNet, self).__init__()
|
@@ -41,6 +42,7 @@ class LeNet(nn.Module):
|
|
41 |
else:
|
42 |
return x5
|
43 |
|
|
|
44 |
def salt_pepper_noise(images, prob=0.01, device='cuda'):
|
45 |
batch_smap = torch.rand_like(images) < prob / 2
|
46 |
pepper = torch.rand_like(images) < prob / 2
|
@@ -55,6 +57,7 @@ def pepper_statistical_noise(images, prob=0.01, device='cuda'):
|
|
55 |
noisy[pepper] = 0.0
|
56 |
return torch.clamp(noisy, 0, 1)
|
57 |
|
|
|
58 |
def get_layer_outputs(model, input_tensor):
|
59 |
outputs = []
|
60 |
def hook(module, input, output):
|
@@ -128,6 +131,7 @@ def get_models_for_dataset(dataset_name):
|
|
128 |
else:
|
129 |
return []
|
130 |
|
|
|
131 |
def get_dataset_and_transform(dataset_name):
|
132 |
if dataset_name == 'MNIST':
|
133 |
transform = transforms.Compose([
|
@@ -175,16 +179,20 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
175 |
logs = []
|
176 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
177 |
|
|
|
178 |
dataset, transform = get_dataset_and_transform(dataset_name)
|
179 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
180 |
logs.append(f"{dataset_name} dataset loaded")
|
181 |
|
|
|
182 |
model = initialize_model(model_name, device)
|
183 |
logs.append(f"Model {model_name} loaded on {device}")
|
184 |
|
|
|
185 |
param_count, layer_count = get_model_stats(model)
|
186 |
logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
|
187 |
|
|
|
188 |
all_attacks = {
|
189 |
'FGSM': FGSM(model, eps=0.03),
|
190 |
'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
|
@@ -198,41 +206,109 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
198 |
return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)]
|
199 |
logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
|
200 |
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
|
|
203 |
for i, (images, labels) in enumerate(testloader):
|
204 |
if i >= num_batches:
|
205 |
break
|
206 |
images, labels = images.to(device), labels.to(device)
|
207 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
208 |
|
209 |
-
for
|
210 |
-
adv_images =
|
211 |
-
|
212 |
-
results[
|
213 |
-
cm
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
processing_time = time.time() - start_time
|
222 |
|
223 |
stats = {
|
224 |
'Dataset': dataset_name,
|
225 |
'Model': model_name,
|
226 |
-
'
|
227 |
-
'
|
228 |
-
'
|
229 |
-
'
|
230 |
-
'
|
231 |
}
|
232 |
-
stats_text = "## Model Statistics\n\n| Metric | Value |\n
|
233 |
for k,v in stats.items():
|
234 |
stats_text += f"| {k} | {v} |\n"
|
235 |
|
|
|
|
|
|
|
|
|
236 |
return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
|
237 |
|
238 |
paper_info_html = """
|
@@ -260,7 +336,7 @@ paper_info_html = """
|
|
260 |
def update_models(dataset_name):
|
261 |
models = get_models_for_dataset(dataset_name)
|
262 |
default_value = models[0] if models else None
|
263 |
-
return models, default_value # Return choices and default value
|
264 |
|
265 |
def create_interface():
|
266 |
datasets = ['MNIST', 'CIFAR-10']
|
@@ -292,11 +368,10 @@ def create_interface():
|
|
292 |
with gr.Tab("Logs"):
|
293 |
log_output = gr.Textbox(label="Processing Logs")
|
294 |
|
295 |
-
# Return choices and value separately for older gradio versions
|
296 |
dataset_input.change(
|
297 |
fn=update_models,
|
298 |
inputs=dataset_input,
|
299 |
-
outputs=[model_input, model_input]
|
300 |
)
|
301 |
|
302 |
run_button.click(
|
|
|
12 |
from datetime import datetime
|
13 |
import gradio as gr
|
14 |
|
15 |
+
# LeNet for MNIST
|
16 |
class LeNet(nn.Module):
|
17 |
def __init__(self):
|
18 |
super(LeNet, self).__init__()
|
|
|
42 |
else:
|
43 |
return x5
|
44 |
|
45 |
+
# Noise functions
|
46 |
def salt_pepper_noise(images, prob=0.01, device='cuda'):
|
47 |
batch_smap = torch.rand_like(images) < prob / 2
|
48 |
pepper = torch.rand_like(images) < prob / 2
|
|
|
57 |
noisy[pepper] = 0.0
|
58 |
return torch.clamp(noisy, 0, 1)
|
59 |
|
60 |
+
# MVL calculation with hooks fallback
|
61 |
def get_layer_outputs(model, input_tensor):
|
62 |
outputs = []
|
63 |
def hook(module, input, output):
|
|
|
131 |
else:
|
132 |
return []
|
133 |
|
134 |
+
|
135 |
def get_dataset_and_transform(dataset_name):
|
136 |
if dataset_name == 'MNIST':
|
137 |
transform = transforms.Compose([
|
|
|
179 |
logs = []
|
180 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
181 |
|
182 |
+
# Prepare dataset & loader
|
183 |
dataset, transform = get_dataset_and_transform(dataset_name)
|
184 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
185 |
logs.append(f"{dataset_name} dataset loaded")
|
186 |
|
187 |
+
# Init model
|
188 |
model = initialize_model(model_name, device)
|
189 |
logs.append(f"Model {model_name} loaded on {device}")
|
190 |
|
191 |
+
# Model stats
|
192 |
param_count, layer_count = get_model_stats(model)
|
193 |
logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
|
194 |
|
195 |
+
# Setup attacks
|
196 |
all_attacks = {
|
197 |
'FGSM': FGSM(model, eps=0.03),
|
198 |
'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
|
|
|
206 |
return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)]
|
207 |
logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
|
208 |
|
209 |
+
# Prepare output dir for plots
|
210 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
211 |
+
output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
|
212 |
+
os.makedirs(output_dir, exist_ok=True)
|
213 |
+
logs.append(f"Output directory: {output_dir}")
|
214 |
+
|
215 |
+
# Collect results
|
216 |
+
results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
|
217 |
|
218 |
+
# Process batches
|
219 |
for i, (images, labels) in enumerate(testloader):
|
220 |
if i >= num_batches:
|
221 |
break
|
222 |
images, labels = images.to(device), labels.to(device)
|
223 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
224 |
|
225 |
+
for atk_name, atk in attacks.items():
|
226 |
+
adv_images = atk(images, labels)
|
227 |
+
mvl_vals = compute_mvl(model, images, adv_images, device)
|
228 |
+
results[atk_name]['mvl'].append(mvl_vals)
|
229 |
+
results[atk_name]['cm'].append(np.mean(mvl_vals))
|
230 |
+
|
231 |
+
# Compute mean/std CM per attack
|
232 |
+
cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
|
233 |
+
cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
|
234 |
+
|
235 |
+
# Plot CM bar
|
236 |
+
plt.figure(figsize=(8,6))
|
237 |
+
names = list(attacks.keys())
|
238 |
+
means = [cm_means[n] for n in names]
|
239 |
+
stds = [cm_stds[n] for n in names]
|
240 |
+
x = np.arange(len(names))
|
241 |
+
plt.bar(x, means, yerr=stds, capsize=5)
|
242 |
+
plt.xticks(x, names, rotation=45)
|
243 |
+
plt.ylabel("CM (Relative Error)")
|
244 |
+
plt.title(f"CM for {model_name} ({dataset_name})")
|
245 |
+
plt.tight_layout()
|
246 |
+
cm_plot_path = os.path.join(output_dir, "cm_plot.png")
|
247 |
+
plt.savefig(cm_plot_path)
|
248 |
+
plt.close()
|
249 |
+
logs.append(f"Saved CM plot to {cm_plot_path}")
|
250 |
+
|
251 |
+
# Plot MVL per attack
|
252 |
+
mvl_plot_paths = []
|
253 |
+
colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple']
|
254 |
+
for idx, atk in enumerate(names):
|
255 |
+
mvl_arr = np.array(results[atk]['mvl'])
|
256 |
+
mean_vals = np.mean(mvl_arr, axis=0)
|
257 |
+
std_vals = np.std(mvl_arr, axis=0)
|
258 |
+
layers = [f"Layer {i+1}" for i in range(len(mean_vals))]
|
259 |
+
plt.figure(figsize=(8,6))
|
260 |
+
plt.plot(layers, mean_vals, marker='o', color=colors[idx % len(colors)], label=atk)
|
261 |
+
plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[idx % len(colors)], alpha=0.3)
|
262 |
+
plt.title(f"MVL per Layer - {atk}")
|
263 |
+
plt.ylabel("MVL (Mean ± Std)")
|
264 |
+
plt.xticks(rotation=45)
|
265 |
+
plt.grid(True)
|
266 |
+
plt.tight_layout()
|
267 |
+
path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png")
|
268 |
+
plt.savefig(path)
|
269 |
+
plt.close()
|
270 |
+
mvl_plot_paths.append(path)
|
271 |
+
logs.append(f"Saved MVL plot for {atk} to {path}")
|
272 |
+
|
273 |
+
# Integrated MVL plot
|
274 |
+
plt.figure(figsize=(10,6))
|
275 |
+
for idx, atk in enumerate(names):
|
276 |
+
mvl_arr = np.array(results[atk]['mvl'])
|
277 |
+
mean_vals = np.mean(mvl_arr, axis=0)
|
278 |
+
std_vals = np.std(mvl_arr, axis=0)
|
279 |
+
layers = [f"Layer {i+1}" for i in range(len(mean_vals))]
|
280 |
+
plt.plot(layers, mean_vals, marker='o', color=colors[idx % len(colors)], label=atk)
|
281 |
+
plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[idx % len(colors)], alpha=0.3)
|
282 |
+
plt.title(f"Integrated MVL - {model_name}")
|
283 |
+
plt.ylabel("MVL (Mean ± Std)")
|
284 |
+
plt.xticks(rotation=45)
|
285 |
+
plt.legend()
|
286 |
+
plt.grid(True)
|
287 |
+
plt.tight_layout()
|
288 |
+
integrated_mvl_plot_path = os.path.join(output_dir, "integrated_mvl.png")
|
289 |
+
plt.savefig(integrated_mvl_plot_path)
|
290 |
+
plt.close()
|
291 |
+
logs.append(f"Saved integrated MVL plot to {integrated_mvl_plot_path}")
|
292 |
|
293 |
processing_time = time.time() - start_time
|
294 |
|
295 |
stats = {
|
296 |
'Dataset': dataset_name,
|
297 |
'Model': model_name,
|
298 |
+
'Parameters': param_count,
|
299 |
+
'Layers': layer_count,
|
300 |
+
'Batches': num_batches,
|
301 |
+
'Attacks': ', '.join(names),
|
302 |
+
'Time (s)': round(processing_time, 2)
|
303 |
}
|
304 |
+
stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n"
|
305 |
for k,v in stats.items():
|
306 |
stats_text += f"| {k} | {v} |\n"
|
307 |
|
308 |
+
# Pad MVL plot paths to length 5 for UI consistency
|
309 |
+
while len(mvl_plot_paths) < 5:
|
310 |
+
mvl_plot_paths.append(None)
|
311 |
+
|
312 |
return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
|
313 |
|
314 |
paper_info_html = """
|
|
|
336 |
def update_models(dataset_name):
|
337 |
models = get_models_for_dataset(dataset_name)
|
338 |
default_value = models[0] if models else None
|
339 |
+
return models, default_value # Return choices and default value for older gradio versions
|
340 |
|
341 |
def create_interface():
|
342 |
datasets = ['MNIST', 'CIFAR-10']
|
|
|
368 |
with gr.Tab("Logs"):
|
369 |
log_output = gr.Textbox(label="Processing Logs")
|
370 |
|
|
|
371 |
dataset_input.change(
|
372 |
fn=update_models,
|
373 |
inputs=dataset_input,
|
374 |
+
outputs=[model_input, model_input] # updates choices and default value
|
375 |
)
|
376 |
|
377 |
run_button.click(
|