DerekLiu35 committed on
Commit
ed34aa3
·
1 Parent(s): fd31dbe

add dataset

Browse files
Files changed (1) hide show
  1. app.py +216 -80
app.py CHANGED
@@ -11,6 +11,10 @@ import time
11
  import json
12
  from fasteners import InterProcessLock
13
  import spaces
 
 
 
 
14
 
15
  AGG_FILE = Path(__file__).parent / "agg_stats.json"
16
  LOCK_FILE = AGG_FILE.with_suffix(".lock")
@@ -22,9 +26,9 @@ def _load_agg_stats() -> dict:
22
  return json.load(f)
23
  except json.JSONDecodeError:
24
  print(f"Warning: {AGG_FILE} is corrupted. Starting with empty stats.")
25
- return {"8-bit": {"attempts": 0, "correct": 0}, "4-bit": {"attempts": 0, "correct": 0}}
26
- return {"8-bit": {"attempts": 0, "correct": 0},
27
- "4-bit": {"attempts": 0, "correct": 0}}
28
 
29
  def _save_agg_stats(stats: dict) -> None:
30
  with InterProcessLock(str(LOCK_FILE)):
@@ -58,6 +62,7 @@ DEFAULT_GUIDANCE_SCALE = 3.5
58
  DEFAULT_NUM_INFERENCE_STEPS = 15
59
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
60
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
 
61
 
62
  CACHED_PIPES = {}
63
  def load_bf16_pipeline():
@@ -134,12 +139,12 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
134
  if not quantization_choice:
135
  return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
136
 
137
- if quantization_choice == "8-bit":
138
  quantized_load_func = load_bnb_8bit_pipeline
139
- quantized_label = "Quantized (8-bit)"
140
- elif quantization_choice == "4-bit":
141
  quantized_load_func = load_bnb_4bit_pipeline
142
- quantized_label = "Quantized (4-bit)"
143
  else:
144
  return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
145
 
@@ -195,7 +200,7 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
195
  correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
196
  print("Correct mapping (hidden):", correct_mapping)
197
 
198
- return shuffled_data_for_gallery, correct_mapping, "Generation complete! Make your guess.", None, gr.update(interactive=True), gr.update(interactive=True)
199
 
200
 
201
  def check_guess(user_guess, correct_mapping_state):
@@ -227,13 +232,13 @@ EXAMPLES = [
227
  "prompt": "A photorealistic portrait of an astronaut on Mars",
228
  "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
229
  "quantized_idx": 1,
230
- "quant_method": "bnb 8-bit",
231
  },
232
  {
233
  "prompt": "Water-color painting of a cat wearing sunglasses",
234
  "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
235
  "quantized_idx": 0,
236
- "quant_method": "bnb 8-bit",
237
  },
238
  # {
239
  # "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
@@ -246,7 +251,7 @@ def load_example(idx):
246
  ex = EXAMPLES[idx]
247
  imgs = [Image.open(EXAMPLE_DIR / f) for f in ex["files"]]
248
  gallery_items = [(img, f"Image {i+1}") for i, img in enumerate(imgs)]
249
- mapping = {i: (f"Quantized {ex['quant_method']}" if i == ex["quantized_idx"] else "Original")
250
  for i in range(2)}
251
  return gallery_items, mapping, f"{ex['prompt']}"
252
 
@@ -276,15 +281,40 @@ def update_leaderboards_data():
276
  ])
277
  quant_rows.sort(key=lambda r: r[1]/r[2] if r[2] != 0 else 1e9)
278
 
279
- user_stats = _load_user_stats()
280
- user_rows = []
281
- for user, st in user_stats.items():
282
- acc_str, acc_val = _accuracy_string(st["total_correct"], st["total_attempts"])
283
- user_rows.append([user, st["total_correct"], st["total_attempts"], acc_str])
284
- user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
285
- user_rows = _add_medals(user_rows)
286
-
287
- return quant_rows, user_rows
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  quant_df = gr.DataFrame(
290
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
@@ -300,7 +330,7 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
300
  with gr.Tabs():
301
  with gr.TabItem("Challenge"):
302
  gr.Markdown(
303
- "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
304
  "Enter a prompt, choose the quantization method, and generate two images. "
305
  "The images will be shuffled, can you spot which one was quantized?"
306
  )
@@ -315,9 +345,9 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
315
  with gr.Row():
316
  prompt_input = gr.Textbox(label="Enter Prompt", scale=3)
317
  quantization_choice_radio = gr.Radio(
318
- choices=["8-bit", "4-bit"],
319
  label="Select Quantization",
320
- value="8-bit",
321
  scale=1
322
  )
323
  generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
@@ -364,22 +394,26 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
364
 
365
  correct_mapping_state = gr.State({})
366
  session_stats_state = gr.State(
367
- {"8-bit": {"attempts": 0, "correct": 0},
368
- "4-bit": {"attempts": 0, "correct": 0}}
369
  )
370
  is_example_state = gr.State(False)
371
  has_added_score_state = gr.State(False)
 
 
 
372
 
373
- def _load_example(sel):
374
  idx = int(sel.split()[-1]) - 1
375
  gallery_items, mapping, prompt = load_example(idx)
376
- quant_data, user_data = update_leaderboards_data()
377
- return gallery_items, mapping, prompt, True, quant_data, user_data
378
 
379
  ex_selector.change(
380
- fn=_load_example,
381
  inputs=ex_selector,
382
- outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df, user_df],
 
383
  ).then(
384
  lambda: (gr.update(interactive=True), gr.update(interactive=True)),
385
  outputs=[image1_btn, image2_btn],
@@ -388,7 +422,8 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
388
  generate_button.click(
389
  fn=generate_images,
390
  inputs=[prompt_input, quantization_choice_radio],
391
- outputs=[output_gallery, correct_mapping_state]
 
392
  ).then(
393
  lambda: (False, # for is_example_state
394
  False, # for has_added_score_state
@@ -407,16 +442,27 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
407
  outputs=[image1_btn, image2_btn, feedback_box],
408
  )
409
 
410
- def choose(choice_string, mapping, session_stats, is_example, has_added_score_curr):
 
411
  feedback = check_guess(choice_string, mapping)
412
 
413
- quant_label = next(label for label in mapping.values() if "Quantized" in label)
414
- quant_key = "8-bit" if "8-bit" in quant_label else "4-bit"
 
 
 
 
 
 
 
 
415
 
416
  got_it_right = "Correct!" in feedback
417
 
418
  sess = session_stats.copy()
419
- if not is_example and not has_added_score_curr:
 
 
420
  sess[quant_key]["attempts"] += 1
421
  if got_it_right:
422
  sess[quant_key]["correct"] += 1
@@ -428,6 +474,79 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
428
  AGG_STATS[quant_key]["correct"] += 1
429
  _save_agg_stats(AGG_STATS)
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  def _fmt(d):
432
  a, c = d["attempts"], d["correct"]
433
  pct = 100 * c / a if a else 0
@@ -437,24 +556,20 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
437
  f"{k}: {_fmt(v)}" for k, v in sess.items()
438
  )
439
  current_agg_stats = _load_agg_stats()
440
- global_msg = ", ".join(
441
- f"{k}: {_fmt(v)}" for k, v in current_agg_stats.items()
442
- )
443
 
444
  username_input_update = gr.update(visible=False, interactive=True)
445
  add_score_button_update = gr.update(visible=False)
446
- # Keep existing feedback if score was already added and feedback is visible
447
  current_feedback_text = add_score_feedback.value if hasattr(add_score_feedback, 'value') and add_score_feedback.value else ""
448
  add_score_feedback_update = gr.update(visible=has_added_score_curr, value=current_feedback_text)
449
 
450
  session_total_attempts = sum(stats["attempts"] for stats in sess.values())
451
 
452
  if not is_example and not has_added_score_curr:
453
- if session_total_attempts >= 1 : # Show button if more than 1 attempt
454
  username_input_update = gr.update(visible=True, interactive=True)
455
  add_score_button_update = gr.update(visible=True, interactive=True)
456
  add_score_feedback_update = gr.update(visible=False, value="")
457
- else: # Less than 1 attempts, keep hidden
458
  username_input_update = gr.update(visible=False, value=username_input.value if hasattr(username_input, 'value') else "")
459
  add_score_button_update = gr.update(visible=False)
460
  add_score_feedback_update = gr.update(visible=False, value="")
@@ -463,31 +578,31 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
463
  add_score_button_update = gr.update(visible=True, interactive=False)
464
  add_score_feedback_update = gr.update(visible=True)
465
 
466
- # disable the buttons so the user can't vote twice
467
- quant_data, user_data = update_leaderboards_data() # Get updated leaderboard data
468
  return (feedback,
469
  gr.update(interactive=False),
470
  gr.update(interactive=False),
471
  session_msg,
472
  session_stats,
473
  quant_data,
474
- user_data,
475
  username_input_update,
476
  add_score_button_update,
477
  add_score_feedback_update)
478
 
479
-
480
  image1_btn.click(
481
- fn=lambda mapping, sess, is_ex, has_added: choose("Image 1", mapping, sess, is_ex, has_added),
482
- inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state],
 
483
  outputs=[feedback_box, image1_btn, image2_btn,
484
  session_score_box, session_stats_state,
485
  quant_df, user_df,
486
  username_input, add_score_button, add_score_feedback],
487
  )
488
  image2_btn.click(
489
- fn=lambda mapping, sess, is_ex, has_added: choose("Image 2", mapping, sess, is_ex, has_added),
490
- inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state],
 
491
  outputs=[feedback_box, image1_btn, image2_btn,
492
  session_score_box, session_stats_state,
493
  quant_df, user_df,
@@ -496,43 +611,48 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
496
 
497
  def handle_add_score_to_leaderboard(username_str, current_session_stats_dict):
498
  if not username_str or not username_str.strip():
499
- return ("Username is required.", # Feedback for add_score_feedback
500
- gr.update(interactive=True), # username_input
501
- gr.update(interactive=True), # add_score_button
502
- False, # has_added_score_state
503
- None, None) # quant_df, user_df
504
 
505
  user_stats = _load_user_stats()
506
  user_key = username_str.strip()
507
 
508
- session_total_correct = sum(stats["correct"] for stats in current_session_stats_dict.values())
509
- session_total_attempts = sum(stats["attempts"] for stats in current_session_stats_dict.values())
510
-
511
- if session_total_attempts == 0:
512
  return ("No attempts made in this session to add to leaderboard.",
513
  gr.update(interactive=True),
514
  gr.update(interactive=True),
515
  False, None, None)
516
 
517
- if user_key in user_stats:
518
- user_stats[user_key]["total_correct"] += session_total_correct
519
- user_stats[user_key]["total_attempts"] += session_total_attempts
520
- else:
521
- user_stats[user_key] = {
522
- "total_correct": session_total_correct,
523
- "total_attempts": session_total_attempts
524
- }
 
 
 
 
 
 
 
 
525
  _save_user_stats(user_stats)
526
 
527
- new_quant_data, new_user_data = update_leaderboards_data()
528
  feedback_msg = f"Score for '{user_key}' submitted to leaderboard!"
529
- return (feedback_msg, # To add_score_feedback
530
- gr.update(interactive=False), # username_input
531
- gr.update(interactive=False), # add_score_button
532
- True, # has_added_score_state (set to true)
533
- new_quant_data, # To quant_df
534
- new_user_data) # To user_df
535
-
536
  add_score_button.click(
537
  fn=handle_add_score_to_leaderboard,
538
  inputs=[username_input, session_stats_state],
@@ -540,16 +660,32 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
540
  )
541
  with gr.TabItem("Leaderboard"):
542
  gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
543
- quant_df = gr.DataFrame(
544
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
545
- interactive=False, col_count=(4, "fixed")
546
  )
547
- gr.Markdown("## User Leaderboard *(Higher % ⇒ better spotter)*")
548
- user_df = gr.DataFrame(
 
549
  headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
550
- interactive=False, col_count=(4, "fixed")
551
  )
552
- demo.load(update_leaderboards_data, outputs=[quant_df, user_df])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
 
554
  if __name__ == "__main__":
555
  demo.launch(share=True)
 
11
  import json
12
  from fasteners import InterProcessLock
13
  import spaces
14
+ from datasets import Dataset, Image as HFImage, load_dataset
15
+ from datasets import Features, Value
16
+ from datasets import concatenate_datasets
17
+ from datetime import datetime
18
 
19
  AGG_FILE = Path(__file__).parent / "agg_stats.json"
20
  LOCK_FILE = AGG_FILE.with_suffix(".lock")
 
26
  return json.load(f)
27
  except json.JSONDecodeError:
28
  print(f"Warning: {AGG_FILE} is corrupted. Starting with empty stats.")
29
+ return {"8-bit bnb": {"attempts": 0, "correct": 0}, "4-bit bnb": {"attempts": 0, "correct": 0}}
30
+ return {"8-bit bnb": {"attempts": 0, "correct": 0},
31
+ "4-bit bnb": {"attempts": 0, "correct": 0}}
32
 
33
  def _save_agg_stats(stats: dict) -> None:
34
  with InterProcessLock(str(LOCK_FILE)):
 
62
  DEFAULT_NUM_INFERENCE_STEPS = 15
63
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
64
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
65
+ HF_DATASET_REPO_ID = "derekl35/flux-quant-challenge-submissions"
66
 
67
  CACHED_PIPES = {}
68
  def load_bf16_pipeline():
 
139
  if not quantization_choice:
140
  return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
141
 
142
+ if quantization_choice == "8-bit bnb":
143
  quantized_load_func = load_bnb_8bit_pipeline
144
+ quantized_label = "Quantized (8-bit bnb)"
145
+ elif quantization_choice == "4-bit bnb":
146
  quantized_load_func = load_bnb_4bit_pipeline
147
+ quantized_label = "Quantized (4-bit bnb)"
148
  else:
149
  return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
150
 
 
200
  correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
201
  print("Correct mapping (hidden):", correct_mapping)
202
 
203
+ return shuffled_data_for_gallery, correct_mapping, prompt, seed, results, "Generation complete! Make your guess.", None, gr.update(interactive=True), gr.update(interactive=True)
204
 
205
 
206
  def check_guess(user_guess, correct_mapping_state):
 
232
  "prompt": "A photorealistic portrait of an astronaut on Mars",
233
  "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
234
  "quantized_idx": 1,
235
+ "quant_method": "8-bit bnb",
236
  },
237
  {
238
  "prompt": "Water-color painting of a cat wearing sunglasses",
239
  "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
240
  "quantized_idx": 0,
241
+ "quant_method": "8-bit bnb",
242
  },
243
  # {
244
  # "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
 
251
  ex = EXAMPLES[idx]
252
  imgs = [Image.open(EXAMPLE_DIR / f) for f in ex["files"]]
253
  gallery_items = [(img, f"Image {i+1}") for i, img in enumerate(imgs)]
254
+ mapping = {i: (f"Quantized ({ex['quant_method']})" if i == ex["quantized_idx"] else "Original")
255
  for i in range(2)}
256
  return gallery_items, mapping, f"{ex['prompt']}"
257
 
 
281
  ])
282
  quant_rows.sort(key=lambda r: r[1]/r[2] if r[2] != 0 else 1e9)
283
 
284
+ user_stats_all = _load_user_stats()
285
+
286
+ overall_user_rows = []
287
+ for user, per_method_stats_dict in user_stats_all.items():
288
+ user_total_correct = 0
289
+ user_total_attempts = 0
290
+ for method_stats in per_method_stats_dict.values():
291
+ user_total_correct += method_stats.get("correct", 0)
292
+ user_total_attempts += method_stats.get("attempts", 0)
293
+
294
+ if user_total_attempts >= 1:
295
+ acc_str, _ = _accuracy_string(user_total_correct, user_total_attempts)
296
+ overall_user_rows.append([user, user_total_correct, user_total_attempts, acc_str])
297
+
298
+ overall_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
299
+ overall_user_rows_medaled = _add_medals(overall_user_rows)
300
+
301
+ user_leaderboards_per_method = {}
302
+ quant_method_names = list(agg.keys())
303
+
304
+ for method_name in quant_method_names:
305
+ method_specific_user_rows = []
306
+ for user, per_user_method_stats_dict in user_stats_all.items():
307
+ if method_name in per_user_method_stats_dict:
308
+ st = per_user_method_stats_dict[method_name]
309
+ if st.get("attempts", 0) >= 1: # Only include users who have attempted this method
310
+ acc_str, _ = _accuracy_string(st["correct"], st["attempts"])
311
+ method_specific_user_rows.append([user, st["correct"], st["attempts"], acc_str])
312
+
313
+ method_specific_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
314
+ method_specific_user_rows_medaled = _add_medals(method_specific_user_rows)
315
+ user_leaderboards_per_method[method_name] = method_specific_user_rows_medaled
316
+
317
+ return quant_rows, overall_user_rows_medaled, user_leaderboards_per_method
318
 
319
  quant_df = gr.DataFrame(
320
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
 
330
  with gr.Tabs():
331
  with gr.TabItem("Challenge"):
332
  gr.Markdown(
333
+ "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit bnb). "
334
  "Enter a prompt, choose the quantization method, and generate two images. "
335
  "The images will be shuffled, can you spot which one was quantized?"
336
  )
 
345
  with gr.Row():
346
  prompt_input = gr.Textbox(label="Enter Prompt", scale=3)
347
  quantization_choice_radio = gr.Radio(
348
+ choices=["8-bit bnb", "4-bit bnb"],
349
  label="Select Quantization",
350
+ value="8-bit bnb",
351
  scale=1
352
  )
353
  generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
 
394
 
395
  correct_mapping_state = gr.State({})
396
  session_stats_state = gr.State(
397
+ {"8-bit bnb": {"attempts": 0, "correct": 0},
398
+ "4-bit bnb": {"attempts": 0, "correct": 0}}
399
  )
400
  is_example_state = gr.State(False)
401
  has_added_score_state = gr.State(False)
402
+ prompt_state = gr.State("")
403
+ seed_state = gr.State(None)
404
+ results_state = gr.State([])
405
 
406
def _load_example_and_update_dfs(sel):
    """Load the example chosen in the selector and refresh the leaderboards.

    ``sel`` is the selector's text value; its last whitespace-separated
    token is parsed as the 1-based example number.

    Returns a 9-tuple: gallery items, label mapping, prompt text, ``True``
    (marks the round as an example), quantization leaderboard rows, overall
    user leaderboard rows, and then ``""``, ``None`` and ``[]`` to reset the
    prompt / seed / results values left over from a generated round.
    """
    example_index = int(sel.split()[-1]) - 1
    items, label_map, prompt_text = load_example(example_index)
    # The per-method user tables (third element) are not shown here.
    method_rows, user_rows, _per_method = update_leaderboards_data()
    cleared_generation_state = ("", None, [])
    return (items, label_map, prompt_text, True, method_rows, user_rows,
            *cleared_generation_state)
411
 
412
  ex_selector.change(
413
+ fn=_load_example_and_update_dfs,
414
  inputs=ex_selector,
415
+ outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df, user_df,
416
+ prompt_state, seed_state, results_state],
417
  ).then(
418
  lambda: (gr.update(interactive=True), gr.update(interactive=True)),
419
  outputs=[image1_btn, image2_btn],
 
422
  generate_button.click(
423
  fn=generate_images,
424
  inputs=[prompt_input, quantization_choice_radio],
425
+ outputs=[output_gallery, correct_mapping_state, prompt_state, seed_state, results_state,
426
+ feedback_box] #, quantization_choice_radio, generate_button, prompt_input]
427
  ).then(
428
  lambda: (False, # for is_example_state
429
  False, # for has_added_score_state
 
442
  outputs=[image1_btn, image2_btn, feedback_box],
443
  )
444
 
445
+ def choose(choice_string, mapping, session_stats, is_example, has_added_score_curr,
446
+ prompt, seed, results, username):
447
  feedback = check_guess(choice_string, mapping)
448
 
449
+ if not mapping:
450
+ return feedback, gr.update(), gr.update(), "", session_stats, [], [], gr.update(), gr.update(), gr.update()
451
+
452
+ quant_label_from_mapping = next((label for label in mapping.values() if "Quantized" in label), None)
453
+ if not quant_label_from_mapping:
454
+ print("Error: Could not determine quantization label from mapping:", mapping)
455
+ return ("Internal Error: Could not process results.", gr.update(interactive=False), gr.update(interactive=False),
456
+ "", session_stats, [], [], gr.update(), gr.update(), gr.update())
457
+
458
+ quant_key = "8-bit bnb" if "8-bit bnb" in quant_label_from_mapping else "4-bit bnb"
459
 
460
  got_it_right = "Correct!" in feedback
461
 
462
  sess = session_stats.copy()
463
+ should_log_and_update_stats = not is_example and not has_added_score_curr
464
+
465
+ if should_log_and_update_stats:
466
  sess[quant_key]["attempts"] += 1
467
  if got_it_right:
468
  sess[quant_key]["correct"] += 1
 
474
  AGG_STATS[quant_key]["correct"] += 1
475
  _save_agg_stats(AGG_STATS)
476
 
477
+ if not HF_TOKEN:
478
+ print("Warning: HF_TOKEN not set. Skipping dataset logging.")
479
+ elif not results:
480
+ print("Warning: Results state is empty. Skipping dataset logging.")
481
+ else:
482
+ print(f"Logging guess to HF Dataset: {HF_DATASET_REPO_ID}")
483
+ original_image = None
484
+ quantized_image = None
485
+ quantized_image_pos = -1
486
+
487
+ for shuffled_idx, original_label in mapping.items():
488
+ if "Quantized" in original_label:
489
+ quantized_image_pos = shuffled_idx
490
+ break
491
+
492
+ original_image = next((res["image"] for res in results if "Original" in res["label"]), None)
493
+ quantized_image = next((res["image"] for res in results if "Quantized" in res["label"]), None)
494
+
495
+ if original_image and quantized_image:
496
+ expected_features = Features({
497
+ "timestamp": Value("string"),
498
+ "prompt": Value("string"),
499
+ "quantization_method": Value("string"),
500
+ "seed": Value("string"),
501
+ "image_original": HFImage(),
502
+ "image_quantized": HFImage(),
503
+ "quantized_image_displayed_position": Value("string"),
504
+ "user_guess_displayed_position": Value("string"),
505
+ "correct_guess": Value("bool"),
506
+ "username": Value("string"), # Handles None
507
+ })
508
+
509
+ new_data_dict_of_lists = {
510
+ "timestamp": [datetime.now().isoformat()],
511
+ "prompt": [prompt],
512
+ "quantization_method": [quant_key],
513
+ "seed": [str(seed)],
514
+ "image_original": [original_image],
515
+ "image_quantized": [quantized_image],
516
+ "quantized_image_displayed_position": [f"Image {quantized_image_pos + 1}"],
517
+ "user_guess_displayed_position": [choice_string],
518
+ "correct_guess": [got_it_right],
519
+ "username": [username.strip() if username else None],
520
+ }
521
+
522
+ try:
523
+ # Attempt to load existing dataset
524
+ existing_ds = load_dataset(
525
+ HF_DATASET_REPO_ID,
526
+ split="train",
527
+ token=HF_TOKEN,
528
+ features=expected_features,
529
+ # verification_mode="no_checks" # Consider removing or using default
530
+ # download_mode="force_redownload" # For debugging cache issues
531
+ )
532
+ # Create a new dataset from the new item, casting to the expected features
533
+ new_row_ds = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
534
+ # Concatenate
535
+ combined_ds = concatenate_datasets([existing_ds, new_row_ds])
536
+ # Push the combined dataset
537
+ combined_ds.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
538
+ print(f"Successfully appended guess to {HF_DATASET_REPO_ID} (train split)")
539
+
540
+ except Exception as e:
541
+ print(f"Could not load or append to existing dataset/split. Creating 'train' split with the new item. Error: {e}")
542
+ # Create dataset from only the new item, with explicit features
543
+ ds_new = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
544
+ # Push this new dataset as the 'train' split
545
+ ds_new.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
546
+ print(f"Successfully created and logged new 'train' split to {HF_DATASET_REPO_ID}")
547
+ else:
548
+ print("Error: Could not find original or quantized image in results state for logging.")
549
+
550
  def _fmt(d):
551
  a, c = d["attempts"], d["correct"]
552
  pct = 100 * c / a if a else 0
 
556
  f"{k}: {_fmt(v)}" for k, v in sess.items()
557
  )
558
  current_agg_stats = _load_agg_stats()
 
 
 
559
 
560
  username_input_update = gr.update(visible=False, interactive=True)
561
  add_score_button_update = gr.update(visible=False)
 
562
  current_feedback_text = add_score_feedback.value if hasattr(add_score_feedback, 'value') and add_score_feedback.value else ""
563
  add_score_feedback_update = gr.update(visible=has_added_score_curr, value=current_feedback_text)
564
 
565
  session_total_attempts = sum(stats["attempts"] for stats in sess.values())
566
 
567
  if not is_example and not has_added_score_curr:
568
+ if session_total_attempts >= 1 :
569
  username_input_update = gr.update(visible=True, interactive=True)
570
  add_score_button_update = gr.update(visible=True, interactive=True)
571
  add_score_feedback_update = gr.update(visible=False, value="")
572
+ else:
573
  username_input_update = gr.update(visible=False, value=username_input.value if hasattr(username_input, 'value') else "")
574
  add_score_button_update = gr.update(visible=False)
575
  add_score_feedback_update = gr.update(visible=False, value="")
 
578
  add_score_button_update = gr.update(visible=True, interactive=False)
579
  add_score_feedback_update = gr.update(visible=True)
580
 
581
+ quant_data, overall_user_data, _ = update_leaderboards_data()
 
582
  return (feedback,
583
  gr.update(interactive=False),
584
  gr.update(interactive=False),
585
  session_msg,
586
  session_stats,
587
  quant_data,
588
+ overall_user_data,
589
  username_input_update,
590
  add_score_button_update,
591
  add_score_feedback_update)
592
 
 
593
  image1_btn.click(
594
+ fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 1", mapping, sess, is_ex, has_added, p, s, r, uname),
595
+ inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
596
+ prompt_state, seed_state, results_state, username_input],
597
  outputs=[feedback_box, image1_btn, image2_btn,
598
  session_score_box, session_stats_state,
599
  quant_df, user_df,
600
  username_input, add_score_button, add_score_feedback],
601
  )
602
  image2_btn.click(
603
+ fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 2", mapping, sess, is_ex, has_added, p, s, r, uname),
604
+ inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
605
+ prompt_state, seed_state, results_state, username_input],
606
  outputs=[feedback_box, image1_btn, image2_btn,
607
  session_score_box, session_stats_state,
608
  quant_df, user_df,
 
611
 
612
  def handle_add_score_to_leaderboard(username_str, current_session_stats_dict):
613
  if not username_str or not username_str.strip():
614
+ return ("Username is required.",
615
+ gr.update(interactive=True),
616
+ gr.update(interactive=True),
617
+ False,
618
+ None, None)
619
 
620
  user_stats = _load_user_stats()
621
  user_key = username_str.strip()
622
 
623
+ session_total_session_attempts = sum(stats["attempts"] for stats in current_session_stats_dict.values())
624
+ if session_total_session_attempts == 0:
 
 
625
  return ("No attempts made in this session to add to leaderboard.",
626
  gr.update(interactive=True),
627
  gr.update(interactive=True),
628
  False, None, None)
629
 
630
+ if user_key not in user_stats:
631
+ user_stats[user_key] = {}
632
+
633
+ for method, stats in current_session_stats_dict.items():
634
+ session_method_correct = stats["correct"]
635
+ session_method_attempts = stats["attempts"]
636
+
637
+ if session_method_attempts == 0:
638
+ continue
639
+
640
+ if method not in user_stats[user_key]:
641
+ user_stats[user_key][method] = {"correct": 0, "attempts": 0}
642
+
643
+ user_stats[user_key][method]["correct"] += session_method_correct
644
+ user_stats[user_key][method]["attempts"] += session_method_attempts
645
+
646
  _save_user_stats(user_stats)
647
 
648
+ new_quant_data, new_overall_user_data, _ = update_leaderboards_data()
649
  feedback_msg = f"Score for '{user_key}' submitted to leaderboard!"
650
+ return (feedback_msg,
651
+ gr.update(interactive=False),
652
+ gr.update(interactive=False),
653
+ True,
654
+ new_quant_data,
655
+ new_overall_user_data)
 
656
  add_score_button.click(
657
  fn=handle_add_score_to_leaderboard,
658
  inputs=[username_input, session_stats_state],
 
660
  )
661
  with gr.TabItem("Leaderboard"):
662
  gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
663
+ leaderboard_tab_quant_df = gr.DataFrame(
664
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
665
+ interactive=False, col_count=(4, "fixed"), label="Quantization Method Leaderboard"
666
  )
667
+ gr.Markdown("---")
668
+
669
+ leaderboard_tab_user_df_8bit = gr.DataFrame(
670
  headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
671
+ interactive=False, col_count=(4, "fixed"), label="8-bit bnb User Leaderboard"
672
  )
673
+ leaderboard_tab_user_df_4bit = gr.DataFrame(
674
+ headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
675
+ interactive=False, col_count=(4, "fixed"), label="4-bit bnb User Leaderboard"
676
+ )
677
+
678
def update_all_leaderboards_for_tab():
    """Collect the rows for the three tables on the Leaderboard tab.

    Returns a 3-tuple: quantization-method rows, then the user rows for
    "8-bit bnb" and for "4-bit bnb". A method with no entry in the
    per-method mapping yields an empty list.
    """
    method_rows, _overall, rows_by_method = update_leaderboards_data()
    per_method_tables = [
        rows_by_method.get(method_key, [])
        for method_key in ("8-bit bnb", "4-bit bnb")
    ]
    return (method_rows, *per_method_tables)
683
+
684
+ demo.load(update_all_leaderboards_for_tab, outputs=[
685
+ leaderboard_tab_quant_df,
686
+ leaderboard_tab_user_df_8bit,
687
+ leaderboard_tab_user_df_4bit
688
+ ])
689
 
690
  if __name__ == "__main__":
691
  demo.launch(share=True)