khalooei
commited on
Commit
·
b2419d7
1
Parent(s):
ccadb41
update app
Browse files
app.py
CHANGED
@@ -178,15 +178,17 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
178 |
logs = ["BSM:: experiment is being started ..."]
|
179 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
180 |
|
|
|
181 |
dataset, _ = get_dataset_and_transform(dataset_name)
|
182 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
183 |
-
logs.append(f"{dataset_name} dataset loaded")
|
184 |
|
|
|
185 |
model = initialize_model(model_name, device)
|
186 |
-
logs.append(f"Model {model_name}
|
187 |
|
188 |
param_count, layer_count = get_model_stats(model)
|
189 |
-
logs.append(f"Model stats: {param_count}
|
190 |
|
191 |
all_attacks = {
|
192 |
'FGSM': FGSM(model, eps=0.03),
|
@@ -204,21 +206,27 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
204 |
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
205 |
output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
|
206 |
os.makedirs(output_dir, exist_ok=True)
|
207 |
-
logs.append(f"Output directory: {output_dir}")
|
208 |
|
209 |
results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
|
210 |
|
211 |
for i, (images, labels) in enumerate(testloader):
|
212 |
if i >= num_batches:
|
|
|
213 |
break
|
214 |
images, labels = images.to(device), labels.to(device)
|
215 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
216 |
|
217 |
for atk_name, atk in attacks.items():
|
|
|
218 |
adv_images = atk(images, labels)
|
219 |
mvl_vals = compute_mvl(model, images, adv_images, device)
|
220 |
results[atk_name]['mvl'].append(mvl_vals)
|
221 |
-
|
|
|
|
|
|
|
|
|
222 |
|
223 |
cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
|
224 |
cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
|
@@ -279,6 +287,7 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
279 |
logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")
|
280 |
|
281 |
processing_time = time.time() - start_time
|
|
|
282 |
|
283 |
stats = {
|
284 |
'Dataset': dataset_name,
|
@@ -309,7 +318,7 @@ paper_info_html = """
|
|
309 |
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
|
310 |
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
|
311 |
<h3>Authors</h3>
|
312 |
-
<p>Mohammad Khalooei, Mohammad Mehdi
|
313 |
|
314 |
<h3>Abstract</h3>
|
315 |
<ul>
|
@@ -348,7 +357,7 @@ def create_interface():
|
|
348 |
model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
|
349 |
|
350 |
attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
|
351 |
-
batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=
|
352 |
run_button = gr.Button("Run Analysis")
|
353 |
|
354 |
error_output = gr.Textbox(label="Error", visible=False)
|
|
|
178 |
logs = ["BSM:: experiment is being started ..."]
|
179 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
180 |
|
181 |
+
logs.append(f"Loading {dataset_name} dataset...")
|
182 |
dataset, _ = get_dataset_and_transform(dataset_name)
|
183 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
184 |
+
logs.append(f"{dataset_name} dataset loaded with {len(testloader)} batches.")
|
185 |
|
186 |
+
logs.append(f"Initializing model {model_name} on {device}...")
|
187 |
model = initialize_model(model_name, device)
|
188 |
+
logs.append(f"Model {model_name} initialized.")
|
189 |
|
190 |
param_count, layer_count = get_model_stats(model)
|
191 |
+
logs.append(f"Model stats: Parameters = {param_count}, Layers = {layer_count}")
|
192 |
|
193 |
all_attacks = {
|
194 |
'FGSM': FGSM(model, eps=0.03),
|
|
|
206 |
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
207 |
output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
|
208 |
os.makedirs(output_dir, exist_ok=True)
|
209 |
+
logs.append(f"Output directory created: {output_dir}")
|
210 |
|
211 |
results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
|
212 |
|
213 |
for i, (images, labels) in enumerate(testloader):
|
214 |
if i >= num_batches:
|
215 |
+
logs.append(f"Reached batch limit: {num_batches}")
|
216 |
break
|
217 |
images, labels = images.to(device), labels.to(device)
|
218 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
219 |
|
220 |
for atk_name, atk in attacks.items():
|
221 |
+
logs.append(f" Running attack: {atk_name} on batch {i+1}")
|
222 |
adv_images = atk(images, labels)
|
223 |
mvl_vals = compute_mvl(model, images, adv_images, device)
|
224 |
results[atk_name]['mvl'].append(mvl_vals)
|
225 |
+
batch_cm = np.mean(mvl_vals)
|
226 |
+
results[atk_name]['cm'].append(batch_cm)
|
227 |
+
logs.append(f" Attack {atk_name}: batch CM={batch_cm:.6f}")
|
228 |
+
|
229 |
+
logs.append("Finished processing batches, computing statistics...")
|
230 |
|
231 |
cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
|
232 |
cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
|
|
|
287 |
logs.append(f"Saved integrated MVL plot: {integrated_mvl_plot_path}")
|
288 |
|
289 |
processing_time = time.time() - start_time
|
290 |
+
logs.append(f"Processing completed in {processing_time:.2f} seconds")
|
291 |
|
292 |
stats = {
|
293 |
'Dataset': dataset_name,
|
|
|
318 |
<div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
|
319 |
<h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
|
320 |
<h3>Authors</h3>
|
321 |
+
<p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>
|
322 |
|
323 |
<h3>Abstract</h3>
|
324 |
<ul>
|
|
|
357 |
model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
|
358 |
|
359 |
attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
|
360 |
+
batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Number of Batches")
|
361 |
run_button = gr.Button("Run Analysis")
|
362 |
|
363 |
error_output = gr.Textbox(label="Error", visible=False)
|