khalooei commited on
Commit
b2419d7
·
1 Parent(s): ccadb41

update app

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -178,15 +178,17 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
178
  logs = ["BSM:: experiment is being started ..."]
179
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
180
 
 
181
  dataset, _ = get_dataset_and_transform(dataset_name)
182
  testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
183
- logs.append(f"{dataset_name} dataset loaded")
184
 
 
185
  model = initialize_model(model_name, device)
186
- logs.append(f"Model {model_name} loaded on {device}")
187
 
188
  param_count, layer_count = get_model_stats(model)
189
- logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
190
 
191
  all_attacks = {
192
  'FGSM': FGSM(model, eps=0.03),
@@ -204,21 +206,27 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
204
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
205
  output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
206
  os.makedirs(output_dir, exist_ok=True)
207
- logs.append(f"Output directory: {output_dir}")
208
 
209
  results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
210
 
211
  for i, (images, labels) in enumerate(testloader):
212
  if i >= num_batches:
 
213
  break
214
  images, labels = images.to(device), labels.to(device)
215
  logs.append(f"Processing batch {i+1}/{num_batches}...")
216
 
217
  for atk_name, atk in attacks.items():
 
218
  adv_images = atk(images, labels)
219
  mvl_vals = compute_mvl(model, images, adv_images, device)
220
  results[atk_name]['mvl'].append(mvl_vals)
221
- results[atk_name]['cm'].append(np.mean(mvl_vals))
 
 
 
 
222
 
223
  cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
224
  cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
@@ -279,6 +287,7 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
279
  logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")
280
 
281
  processing_time = time.time() - start_time
 
282
 
283
  stats = {
284
  'Dataset': dataset_name,
@@ -309,7 +318,7 @@ paper_info_html = """
309
  <div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
310
  <h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
311
  <h3>Authors</h3>
312
- <p>Mohammad Khalooei, Mohammad Mehdi Homaypour, Maryam Amirmazlaghani</p>
313
 
314
  <h3>Abstract</h3>
315
  <ul>
@@ -348,7 +357,7 @@ def create_interface():
348
  model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
349
 
350
  attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
351
- batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Batches")
352
  run_button = gr.Button("Run Analysis")
353
 
354
  error_output = gr.Textbox(label="Error", visible=False)
 
178
  logs = ["BSM:: experiment is being started ..."]
179
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
180
 
181
+ logs.append(f"Loading {dataset_name} dataset...")
182
  dataset, _ = get_dataset_and_transform(dataset_name)
183
  testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
184
+ logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.")
185
 
186
+ logs.append(f"Initializing model {model_name} on {device}...")
187
  model = initialize_model(model_name, device)
188
+ logs.append(f"Model {model_name} initialized.")
189
 
190
  param_count, layer_count = get_model_stats(model)
191
+ logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}")
192
 
193
  all_attacks = {
194
  'FGSM': FGSM(model, eps=0.03),
 
206
  timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
207
  output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
208
  os.makedirs(output_dir, exist_ok=True)
209
+ logs.append(f"Output directory created: {output_dir}")
210
 
211
  results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
212
 
213
  for i, (images, labels) in enumerate(testloader):
214
  if i >= num_batches:
215
+ logs.append(f"Reached batch limit: {num_batches}")
216
  break
217
  images, labels = images.to(device), labels.to(device)
218
  logs.append(f"Processing batch {i+1}/{num_batches}...")
219
 
220
  for atk_name, atk in attacks.items():
221
+ logs.append(f" Running attack: {atk_name} on batch {i+1}")
222
  adv_images = atk(images, labels)
223
  mvl_vals = compute_mvl(model, images, adv_images, device)
224
  results[atk_name]['mvl'].append(mvl_vals)
225
+ batch_cm = np.mean(mvl_vals)
226
+ results[atk_name]['cm'].append(batch_cm)
227
+ logs.append(f" Attack {atk_name}: batch CM={batch_cm:.6f}")
228
+
229
+ logs.append("Finished processing batches, computing statistics...")
230
 
231
  cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
232
  cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
 
287
  logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")
288
 
289
  processing_time = time.time() - start_time
290
+ logs.append(f"Processing completed in {processing_time:.2f} seconds")
291
 
292
  stats = {
293
  'Dataset': dataset_name,
 
318
  <div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
319
  <h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
320
  <h3>Authors</h3>
321
+ <p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>
322
 
323
  <h3>Abstract</h3>
324
  <ul>
 
357
  model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
358
 
359
  attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
360
+ batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Number of Batches")
361
  run_button = gr.Button("Run Analysis")
362
 
363
  error_output = gr.Textbox(label="Error", visible=False)