khalooei commited on
Commit
4bd1b68
·
1 Parent(s): 7d45691

update app

Browse files
Files changed (1) hide show
  1. app.py +96 -21
app.py CHANGED
@@ -12,6 +12,7 @@ import time
12
  from datetime import datetime
13
  import gradio as gr
14
 
 
15
  class LeNet(nn.Module):
16
  def __init__(self):
17
  super(LeNet, self).__init__()
@@ -41,6 +42,7 @@ class LeNet(nn.Module):
41
  else:
42
  return x5
43
 
 
44
  def salt_pepper_noise(images, prob=0.01, device='cuda'):
45
  batch_smap = torch.rand_like(images) < prob / 2
46
  pepper = torch.rand_like(images) < prob / 2
@@ -55,6 +57,7 @@ def pepper_statistical_noise(images, prob=0.01, device='cuda'):
55
  noisy[pepper] = 0.0
56
  return torch.clamp(noisy, 0, 1)
57
 
 
58
  def get_layer_outputs(model, input_tensor):
59
  outputs = []
60
  def hook(module, input, output):
@@ -128,6 +131,7 @@ def get_models_for_dataset(dataset_name):
128
  else:
129
  return []
130
 
 
131
  def get_dataset_and_transform(dataset_name):
132
  if dataset_name == 'MNIST':
133
  transform = transforms.Compose([
@@ -175,16 +179,20 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
175
  logs = []
176
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
177
 
 
178
  dataset, transform = get_dataset_and_transform(dataset_name)
179
  testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
180
  logs.append(f"{dataset_name} dataset loaded")
181
 
 
182
  model = initialize_model(model_name, device)
183
  logs.append(f"Model {model_name} loaded on {device}")
184
 
 
185
  param_count, layer_count = get_model_stats(model)
186
  logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
187
 
 
188
  all_attacks = {
189
  'FGSM': FGSM(model, eps=0.03),
190
  'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
@@ -198,41 +206,109 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
198
  return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)]
199
  logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
200
 
201
- results = {attack_name: {'cm': [], 'mvl': []} for attack_name in attacks}
 
 
 
 
 
 
 
202
 
 
203
  for i, (images, labels) in enumerate(testloader):
204
  if i >= num_batches:
205
  break
206
  images, labels = images.to(device), labels.to(device)
207
  logs.append(f"Processing batch {i+1}/{num_batches}...")
208
 
209
- for attack_name, attack in attacks.items():
210
- adv_images = attack(images, labels)
211
- mvl_list = compute_mvl(model, images, adv_images, device)
212
- results[attack_name]['mvl'].append(mvl_list)
213
- cm = np.mean(mvl_list)
214
- results[attack_name]['cm'].append(cm)
215
-
216
- # Placeholders for plots (add your plot generation here)
217
- cm_plot_path = None
218
- mvl_plot_paths = [None]*5
219
- integrated_mvl_plot_path = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  processing_time = time.time() - start_time
222
 
223
  stats = {
224
  'Dataset': dataset_name,
225
  'Model': model_name,
226
- 'Parameter Count': param_count,
227
- 'Layer Count': layer_count,
228
- 'Processing Time (s)': round(processing_time, 2),
229
- 'Number of Batches': num_batches,
230
- 'Attacks Used': ', '.join(attacks.keys())
231
  }
232
- stats_text = "## Model Statistics\n\n| Metric | Value |\n|--------|-------|\n"
233
  for k,v in stats.items():
234
  stats_text += f"| {k} | {v} |\n"
235
 
 
 
 
 
236
  return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
237
 
238
  paper_info_html = """
@@ -260,7 +336,7 @@ paper_info_html = """
260
  def update_models(dataset_name):
261
  models = get_models_for_dataset(dataset_name)
262
  default_value = models[0] if models else None
263
- return models, default_value # Return choices and default value as a tuple
264
 
265
  def create_interface():
266
  datasets = ['MNIST', 'CIFAR-10']
@@ -292,11 +368,10 @@ def create_interface():
292
  with gr.Tab("Logs"):
293
  log_output = gr.Textbox(label="Processing Logs")
294
 
295
- # Return choices and value separately for older gradio versions
296
  dataset_input.change(
297
  fn=update_models,
298
  inputs=dataset_input,
299
- outputs=[model_input, model_input]
300
  )
301
 
302
  run_button.click(
 
12
  from datetime import datetime
13
  import gradio as gr
14
 
15
+ # LeNet for MNIST
16
  class LeNet(nn.Module):
17
  def __init__(self):
18
  super(LeNet, self).__init__()
 
42
  else:
43
  return x5
44
 
45
+ # Noise functions
46
  def salt_pepper_noise(images, prob=0.01, device='cuda'):
47
  batch_smap = torch.rand_like(images) < prob / 2
48
  pepper = torch.rand_like(images) < prob / 2
 
57
  noisy[pepper] = 0.0
58
  return torch.clamp(noisy, 0, 1)
59
 
60
+ # MVL calculation with hooks fallback
61
  def get_layer_outputs(model, input_tensor):
62
  outputs = []
63
  def hook(module, input, output):
 
131
  else:
132
  return []
133
 
134
+
135
  def get_dataset_and_transform(dataset_name):
136
  if dataset_name == 'MNIST':
137
  transform = transforms.Compose([
 
179
  logs = []
180
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
181
 
182
+ # Prepare dataset & loader
183
  dataset, transform = get_dataset_and_transform(dataset_name)
184
  testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
185
  logs.append(f"{dataset_name} dataset loaded")
186
 
187
+ # Init model
188
  model = initialize_model(model_name, device)
189
  logs.append(f"Model {model_name} loaded on {device}")
190
 
191
+ # Model stats
192
  param_count, layer_count = get_model_stats(model)
193
  logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
194
 
195
+ # Setup attacks
196
  all_attacks = {
197
  'FGSM': FGSM(model, eps=0.03),
198
  'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
 
206
  return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)]
207
  logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
208
 
209
+ # Prepare output dir for plots
210
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
211
+ output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
212
+ os.makedirs(output_dir, exist_ok=True)
213
+ logs.append(f"Output directory: {output_dir}")
214
+
215
+ # Collect results
216
+ results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
217
 
218
+ # Process batches
219
  for i, (images, labels) in enumerate(testloader):
220
  if i >= num_batches:
221
  break
222
  images, labels = images.to(device), labels.to(device)
223
  logs.append(f"Processing batch {i+1}/{num_batches}...")
224
 
225
+ for atk_name, atk in attacks.items():
226
+ adv_images = atk(images, labels)
227
+ mvl_vals = compute_mvl(model, images, adv_images, device)
228
+ results[atk_name]['mvl'].append(mvl_vals)
229
+ results[atk_name]['cm'].append(np.mean(mvl_vals))
230
+
231
+ # Compute mean/std CM per attack
232
+ cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
233
+ cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
234
+
235
+ # Plot CM bar
236
+ plt.figure(figsize=(8,6))
237
+ names = list(attacks.keys())
238
+ means = [cm_means[n] for n in names]
239
+ stds = [cm_stds[n] for n in names]
240
+ x = np.arange(len(names))
241
+ plt.bar(x, means, yerr=stds, capsize=5)
242
+ plt.xticks(x, names, rotation=45)
243
+ plt.ylabel("CM (Relative Error)")
244
+ plt.title(f"CM for {model_name} ({dataset_name})")
245
+ plt.tight_layout()
246
+ cm_plot_path = os.path.join(output_dir, "cm_plot.png")
247
+ plt.savefig(cm_plot_path)
248
+ plt.close()
249
+ logs.append(f"Saved CM plot to {cm_plot_path}")
250
+
251
+ # Plot MVL per attack
252
+ mvl_plot_paths = []
253
+ colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple']
254
+ for idx, atk in enumerate(names):
255
+ mvl_arr = np.array(results[atk]['mvl'])
256
+ mean_vals = np.mean(mvl_arr, axis=0)
257
+ std_vals = np.std(mvl_arr, axis=0)
258
+ layers = [f"Layer {i+1}" for i in range(len(mean_vals))]
259
+ plt.figure(figsize=(8,6))
260
+ plt.plot(layers, mean_vals, marker='o', color=colors[idx % len(colors)], label=atk)
261
+ plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[idx % len(colors)], alpha=0.3)
262
+ plt.title(f"MVL per Layer - {atk}")
263
+ plt.ylabel("MVL (Mean ± Std)")
264
+ plt.xticks(rotation=45)
265
+ plt.grid(True)
266
+ plt.tight_layout()
267
+ path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png")
268
+ plt.savefig(path)
269
+ plt.close()
270
+ mvl_plot_paths.append(path)
271
+ logs.append(f"Saved MVL plot for {atk} to {path}")
272
+
273
+ # Integrated MVL plot
274
+ plt.figure(figsize=(10,6))
275
+ for idx, atk in enumerate(names):
276
+ mvl_arr = np.array(results[atk]['mvl'])
277
+ mean_vals = np.mean(mvl_arr, axis=0)
278
+ std_vals = np.std(mvl_arr, axis=0)
279
+ layers = [f"Layer {i+1}" for i in range(len(mean_vals))]
280
+ plt.plot(layers, mean_vals, marker='o', color=colors[idx % len(colors)], label=atk)
281
+ plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[idx % len(colors)], alpha=0.3)
282
+ plt.title(f"Integrated MVL - {model_name}")
283
+ plt.ylabel("MVL (Mean ± Std)")
284
+ plt.xticks(rotation=45)
285
+ plt.legend()
286
+ plt.grid(True)
287
+ plt.tight_layout()
288
+ integrated_mvl_plot_path = os.path.join(output_dir, "integrated_mvl.png")
289
+ plt.savefig(integrated_mvl_plot_path)
290
+ plt.close()
291
+ logs.append(f"Saved integrated MVL plot to {integrated_mvl_plot_path}")
292
 
293
  processing_time = time.time() - start_time
294
 
295
  stats = {
296
  'Dataset': dataset_name,
297
  'Model': model_name,
298
+ 'Parameters': param_count,
299
+ 'Layers': layer_count,
300
+ 'Batches': num_batches,
301
+ 'Attacks': ', '.join(names),
302
+ 'Time (s)': round(processing_time, 2)
303
  }
304
+ stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n"
305
  for k,v in stats.items():
306
  stats_text += f"| {k} | {v} |\n"
307
 
308
+ # Pad MVL plot paths to length 5 for UI consistency
309
+ while len(mvl_plot_paths) < 5:
310
+ mvl_plot_paths.append(None)
311
+
312
  return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
313
 
314
  paper_info_html = """
 
336
  def update_models(dataset_name):
337
  models = get_models_for_dataset(dataset_name)
338
  default_value = models[0] if models else None
339
+ return models, default_value # Return choices and default value for older gradio versions
340
 
341
  def create_interface():
342
  datasets = ['MNIST', 'CIFAR-10']
 
368
  with gr.Tab("Logs"):
369
  log_output = gr.Textbox(label="Processing Logs")
370
 
 
371
  dataset_input.change(
372
  fn=update_models,
373
  inputs=dataset_input,
374
+ outputs=[model_input, model_input] # updates choices and default value
375
  )
376
 
377
  run_button.click(