derekl35 HF Staff committed on
Commit
ff15687
·
verified ·
1 Parent(s): db4468d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -245
app.py CHANGED
@@ -35,24 +35,6 @@ def _save_agg_stats(stats: dict) -> None:
35
  with open(AGG_FILE, "w") as f:
36
  json.dump(stats, f, indent=2)
37
 
38
- USER_STATS_FILE = Path(__file__).parent / "user_stats.json"
39
- USER_STATS_LOCK_FILE = USER_STATS_FILE.with_suffix(".lock")
40
-
41
- def _load_user_stats() -> dict:
42
- if USER_STATS_FILE.exists():
43
- with open(USER_STATS_FILE, "r") as f:
44
- try:
45
- return json.load(f)
46
- except json.JSONDecodeError:
47
- print(f"Warning: {USER_STATS_FILE} is corrupted. Starting with empty user stats.")
48
- return {}
49
- return {}
50
-
51
- def _save_user_stats(stats: dict) -> None:
52
- with InterProcessLock(str(USER_STATS_LOCK_FILE)):
53
- with open(USER_STATS_FILE, "w") as f:
54
- json.dump(stats, f, indent=2)
55
-
56
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
  print(f"Using device: {DEVICE}")
58
 
@@ -62,7 +44,7 @@ DEFAULT_GUIDANCE_SCALE = 3.5
62
  DEFAULT_NUM_INFERENCE_STEPS = 15
63
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
64
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
65
- HF_DATASET_REPO_ID = "derekl35/flux-quant-challenge-submissions"
66
 
67
  CACHED_PIPES = {}
68
  def load_bf16_pipeline():
@@ -77,7 +59,8 @@ def load_bf16_pipeline():
77
  torch_dtype=torch.bfloat16,
78
  token=HF_TOKEN
79
  )
80
- pipe.to(DEVICE)
 
81
  end_time = time.time()
82
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
83
  print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
@@ -98,8 +81,8 @@ def load_bnb_8bit_pipeline():
98
  MODEL_ID,
99
  torch_dtype=torch.bfloat16
100
  )
101
- pipe.to(DEVICE)
102
- # pipe.enable_model_cpu_offload()
103
  end_time = time.time()
104
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
105
  print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
@@ -120,8 +103,8 @@ def load_bnb_4bit_pipeline():
120
  MODEL_ID,
121
  torch_dtype=torch.bfloat16
122
  )
123
- pipe.to(DEVICE)
124
- # pipe.enable_model_cpu_offload()
125
  end_time = time.time()
126
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
127
  print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
@@ -134,10 +117,10 @@ def load_bnb_4bit_pipeline():
134
  @spaces.GPU(duration=240)
135
  def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
136
  if not prompt:
137
- return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
138
 
139
  if not quantization_choice:
140
- return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
141
 
142
  if quantization_choice == "8-bit bnb":
143
  quantized_load_func = load_bnb_8bit_pipeline
@@ -146,7 +129,7 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
146
  quantized_load_func = load_bnb_4bit_pipeline
147
  quantized_label = "Quantized (4-bit bnb)"
148
  else:
149
- return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
150
 
151
  model_configs = [
152
  ("Original", load_bf16_pipeline),
@@ -188,17 +171,17 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
188
 
189
  except Exception as e:
190
  print(f"Error during {label} model processing: {e}")
191
- return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
192
 
193
 
194
  if len(results) != len(model_configs):
195
- return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
196
 
197
  shuffled_results = results.copy()
198
  random.shuffle(shuffled_results)
199
  shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
200
  correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
201
- # print("Correct mapping (hidden):", correct_mapping)
202
 
203
  return shuffled_data_for_gallery, correct_mapping, prompt, seed, results, "Generation complete! Make your guess.", None, gr.update(interactive=True), gr.update(interactive=True)
204
 
@@ -233,12 +216,14 @@ EXAMPLES = [
233
  "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
234
  "quantized_idx": 1,
235
  "quant_method": "8-bit bnb",
 
236
  },
237
  {
238
  "prompt": "Water-color painting of a cat wearing sunglasses",
239
  "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
240
  "quantized_idx": 0,
241
  "quant_method": "8-bit bnb",
 
242
  },
243
  # {
244
  # "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
@@ -261,13 +246,6 @@ def _accuracy_string(correct: int, attempts: int) -> tuple[str, float]:
261
  return f"{pct:.1f}%", pct
262
  return "N/A", -1.0
263
 
264
- def _add_medals(user_rows):
265
- MEDALS = {0: "🥇 ", 1: "🥈 ", 2: "🥉 "}
266
- return [
267
- [MEDALS.get(i, "") + row[0], *row[1:]]
268
- for i, row in enumerate(user_rows)
269
- ]
270
-
271
  def update_leaderboards_data():
272
  agg = _load_agg_stats()
273
  quant_rows = []
@@ -280,50 +258,12 @@ def update_leaderboards_data():
280
  acc_str
281
  ])
282
  quant_rows.sort(key=lambda r: r[1]/r[2] if r[2] != 0 else 1e9)
283
-
284
- user_stats_all = _load_user_stats()
285
-
286
- overall_user_rows = []
287
- for user, per_method_stats_dict in user_stats_all.items():
288
- user_total_correct = 0
289
- user_total_attempts = 0
290
- for method_stats in per_method_stats_dict.values():
291
- user_total_correct += method_stats.get("correct", 0)
292
- user_total_attempts += method_stats.get("attempts", 0)
293
-
294
- if user_total_attempts >= 1:
295
- acc_str, _ = _accuracy_string(user_total_correct, user_total_attempts)
296
- overall_user_rows.append([user, user_total_correct, user_total_attempts, acc_str])
297
-
298
- overall_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
299
- overall_user_rows_medaled = _add_medals(overall_user_rows)
300
-
301
- user_leaderboards_per_method = {}
302
- quant_method_names = list(agg.keys())
303
-
304
- for method_name in quant_method_names:
305
- method_specific_user_rows = []
306
- for user, per_user_method_stats_dict in user_stats_all.items():
307
- if method_name in per_user_method_stats_dict:
308
- st = per_user_method_stats_dict[method_name]
309
- if st.get("attempts", 0) >= 1: # Only include users who have attempted this method
310
- acc_str, _ = _accuracy_string(st["correct"], st["attempts"])
311
- method_specific_user_rows.append([user, st["correct"], st["attempts"], acc_str])
312
-
313
- method_specific_user_rows.sort(key=lambda r: (-float(r[3].rstrip('%')) if r[3] != "N/A" else float('-inf'), -r[2]))
314
- method_specific_user_rows_medaled = _add_medals(method_specific_user_rows)
315
- user_leaderboards_per_method[method_name] = method_specific_user_rows_medaled
316
-
317
- return quant_rows, overall_user_rows_medaled, user_leaderboards_per_method
318
 
319
  quant_df = gr.DataFrame(
320
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
321
  interactive=False, col_count=(4, "fixed")
322
  )
323
- user_df = gr.DataFrame(
324
- headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
325
- interactive=False, col_count=(4, "fixed")
326
- )
327
 
328
  with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
329
  gr.Markdown("# FLUX Model Quantization Challenge")
@@ -337,7 +277,7 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
337
 
338
  gr.Markdown("### Examples")
339
  ex_selector = gr.Radio(
340
- choices=[f"Example {i+1}" for i in range(len(EXAMPLES))],
341
  label="Choose an example prompt",
342
  interactive=True,
343
  )
@@ -370,26 +310,16 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
370
 
371
  with gr.Row():
372
  session_score_box = gr.Textbox(label="Your accuracy this session", interactive=False)
373
-
374
- with gr.Row(equal_height=False):
375
- username_input = gr.Textbox(
376
- label="Enter Your Name for Leaderboard",
377
- placeholder="YourName",
378
- visible=False,
379
- interactive=True,
380
- scale=2
381
- )
382
- add_score_button = gr.Button(
383
- "Add My Score to Leaderboard",
384
- visible=False,
385
- variant="secondary",
386
- scale=1
387
- )
388
- add_score_feedback = gr.Textbox(
389
- label="Leaderboard Update",
390
- visible=False,
391
- interactive=False,
392
- lines=1
393
  )
394
 
395
  correct_mapping_state = gr.State({})
@@ -398,22 +328,26 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
398
  "4-bit bnb": {"attempts": 0, "correct": 0}}
399
  )
400
  is_example_state = gr.State(False)
401
- has_added_score_state = gr.State(False)
402
  prompt_state = gr.State("")
403
  seed_state = gr.State(None)
404
  results_state = gr.State([])
405
 
406
- def _load_example_and_update_dfs(sel):
407
- idx = int(sel.split()[-1]) - 1
 
 
 
 
 
408
  gallery_items, mapping, prompt = load_example(idx)
409
- quant_data, overall_user_data, _ = update_leaderboards_data()
410
- return gallery_items, mapping, prompt, True, quant_data, overall_user_data, "", None, []
411
 
412
  ex_selector.change(
413
  fn=_load_example_and_update_dfs,
414
  inputs=ex_selector,
415
- outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df, user_df,
416
- prompt_state, seed_state, results_state],
417
  ).then(
418
  lambda: (gr.update(interactive=True), gr.update(interactive=True)),
419
  outputs=[image1_btn, image2_btn],
@@ -423,50 +357,39 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
423
  fn=generate_images,
424
  inputs=[prompt_input, quantization_choice_radio],
425
  outputs=[output_gallery, correct_mapping_state, prompt_state, seed_state, results_state,
426
- feedback_box] #, quantization_choice_radio, generate_button, prompt_input]
427
  ).then(
428
- lambda: (False, # for is_example_state
429
- False, # for has_added_score_state
430
- gr.update(visible=False, value="", interactive=True), # username_input reset
431
- gr.update(visible=False), # add_score_button reset
432
- gr.update(visible=False, value="")), # add_score_feedback reset
433
- outputs=[is_example_state,
434
- has_added_score_state,
435
- username_input,
436
- add_score_button,
437
- add_score_feedback]
438
  ).then(
439
  lambda: (gr.update(interactive=True),
440
- gr.update(interactive=True),
441
- ""),
442
  outputs=[image1_btn, image2_btn, feedback_box],
443
  )
444
 
445
- def choose(choice_string, mapping, session_stats, is_example, has_added_score_curr,
446
- prompt, seed, results, username):
447
  feedback = check_guess(choice_string, mapping)
448
 
449
  if not mapping:
450
- return feedback, gr.update(), gr.update(), "", session_stats, [], [], gr.update(), gr.update(), gr.update()
451
 
452
  quant_label_from_mapping = next((label for label in mapping.values() if "Quantized" in label), None)
453
  if not quant_label_from_mapping:
454
  print("Error: Could not determine quantization label from mapping:", mapping)
455
  return ("Internal Error: Could not process results.", gr.update(interactive=False), gr.update(interactive=False),
456
- "", session_stats, [], [], gr.update(), gr.update(), gr.update())
457
 
458
  quant_key = "8-bit bnb" if "8-bit bnb" in quant_label_from_mapping else "4-bit bnb"
459
-
460
  got_it_right = "Correct!" in feedback
461
-
462
  sess = session_stats.copy()
463
- should_log_and_update_stats = not is_example and not has_added_score_curr
464
 
465
- if should_log_and_update_stats:
466
  sess[quant_key]["attempts"] += 1
467
  if got_it_right:
468
  sess[quant_key]["correct"] += 1
469
- session_stats = sess
470
 
471
  AGG_STATS = _load_agg_stats()
472
  AGG_STATS[quant_key]["attempts"] += 1
@@ -478,6 +401,8 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
478
  print("Warning: HF_TOKEN not set. Skipping dataset logging.")
479
  elif not results:
480
  print("Warning: Results state is empty. Skipping dataset logging.")
 
 
481
  else:
482
  print(f"Logging guess to HF Dataset: {HF_DATASET_REPO_ID}")
483
  original_image = None
@@ -516,32 +441,22 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
516
  "quantized_image_displayed_position": [f"Image {quantized_image_pos + 1}"],
517
  "user_guess_displayed_position": [choice_string],
518
  "correct_guess": [got_it_right],
519
- "username": [username.strip() if username else None],
520
  }
521
-
522
  try:
523
- # Attempt to load existing dataset
524
  existing_ds = load_dataset(
525
  HF_DATASET_REPO_ID,
526
  split="train",
527
  token=HF_TOKEN,
528
  features=expected_features,
529
- # verification_mode="no_checks" # Consider removing or using default
530
- # download_mode="force_redownload" # For debugging cache issues
531
  )
532
- # Create a new dataset from the new item, casting to the expected features
533
  new_row_ds = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
534
- # Concatenate
535
  combined_ds = concatenate_datasets([existing_ds, new_row_ds])
536
- # Push the combined dataset
537
  combined_ds.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
538
  print(f"Successfully appended guess to {HF_DATASET_REPO_ID} (train split)")
539
-
540
  except Exception as e:
541
  print(f"Could not load or append to existing dataset/split. Creating 'train' split with the new item. Error: {e}")
542
- # Create dataset from only the new item, with explicit features
543
  ds_new = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
544
- # Push this new dataset as the 'train' split
545
  ds_new.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
546
  print(f"Successfully created and logged new 'train' split to {HF_DATASET_REPO_ID}")
547
  else:
@@ -555,136 +470,45 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as d
555
  session_msg = ", ".join(
556
  f"{k}: {_fmt(v)}" for k, v in sess.items()
557
  )
558
- current_agg_stats = _load_agg_stats()
559
-
560
- username_input_update = gr.update(visible=False, interactive=True)
561
- add_score_button_update = gr.update(visible=False)
562
- current_feedback_text = add_score_feedback.value if hasattr(add_score_feedback, 'value') and add_score_feedback.value else ""
563
- add_score_feedback_update = gr.update(visible=has_added_score_curr, value=current_feedback_text)
564
-
565
- session_total_attempts = sum(stats["attempts"] for stats in sess.values())
566
-
567
- if not is_example and not has_added_score_curr:
568
- if session_total_attempts >= 1 :
569
- username_input_update = gr.update(visible=True, interactive=True)
570
- add_score_button_update = gr.update(visible=True, interactive=True)
571
- add_score_feedback_update = gr.update(visible=False, value="")
572
- else:
573
- username_input_update = gr.update(visible=False, value=username_input.value if hasattr(username_input, 'value') else "")
574
- add_score_button_update = gr.update(visible=False)
575
- add_score_feedback_update = gr.update(visible=False, value="")
576
- elif has_added_score_curr:
577
- username_input_update = gr.update(visible=True, interactive=False, value=username_input.value if hasattr(username_input, 'value') else "")
578
- add_score_button_update = gr.update(visible=True, interactive=False)
579
- add_score_feedback_update = gr.update(visible=True)
580
-
581
- quant_data, overall_user_data, _ = update_leaderboards_data()
582
  return (feedback,
583
  gr.update(interactive=False),
584
  gr.update(interactive=False),
585
  session_msg,
586
- session_stats,
587
- quant_data,
588
- overall_user_data,
589
- username_input_update,
590
- add_score_button_update,
591
- add_score_feedback_update)
592
 
593
  image1_btn.click(
594
- fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 1", mapping, sess, is_ex, has_added, p, s, r, uname),
595
- inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
596
- prompt_state, seed_state, results_state, username_input],
597
  outputs=[feedback_box, image1_btn, image2_btn,
598
- session_score_box, session_stats_state,
599
- quant_df, user_df,
600
- username_input, add_score_button, add_score_feedback],
601
  )
602
  image2_btn.click(
603
- fn=lambda mapping, sess, is_ex, has_added, p, s, r, uname: choose("Image 2", mapping, sess, is_ex, has_added, p, s, r, uname),
604
- inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state,
605
- prompt_state, seed_state, results_state, username_input],
606
  outputs=[feedback_box, image1_btn, image2_btn,
607
- session_score_box, session_stats_state,
608
- quant_df, user_df,
609
- username_input, add_score_button, add_score_feedback],
610
  )
611
 
612
- def handle_add_score_to_leaderboard(username_str, current_session_stats_dict):
613
- if not username_str or not username_str.strip():
614
- return ("Username is required.",
615
- gr.update(interactive=True),
616
- gr.update(interactive=True),
617
- False,
618
- None, None)
619
-
620
- user_stats = _load_user_stats()
621
- user_key = username_str.strip()
622
-
623
- session_total_session_attempts = sum(stats["attempts"] for stats in current_session_stats_dict.values())
624
- if session_total_session_attempts == 0:
625
- return ("No attempts made in this session to add to leaderboard.",
626
- gr.update(interactive=True),
627
- gr.update(interactive=True),
628
- False, None, None)
629
-
630
- if user_key not in user_stats:
631
- user_stats[user_key] = {}
632
-
633
- for method, stats in current_session_stats_dict.items():
634
- session_method_correct = stats["correct"]
635
- session_method_attempts = stats["attempts"]
636
-
637
- if session_method_attempts == 0:
638
- continue
639
-
640
- if method not in user_stats[user_key]:
641
- user_stats[user_key][method] = {"correct": 0, "attempts": 0}
642
-
643
- user_stats[user_key][method]["correct"] += session_method_correct
644
- user_stats[user_key][method]["attempts"] += session_method_attempts
645
-
646
- _save_user_stats(user_stats)
647
-
648
- new_quant_data, new_overall_user_data, _ = update_leaderboards_data()
649
- feedback_msg = f"Score for '{user_key}' submitted to leaderboard!"
650
- return (feedback_msg,
651
- gr.update(interactive=False),
652
- gr.update(interactive=False),
653
- True,
654
- new_quant_data,
655
- new_overall_user_data)
656
- add_score_button.click(
657
- fn=handle_add_score_to_leaderboard,
658
- inputs=[username_input, session_stats_state],
659
- outputs=[add_score_feedback, username_input, add_score_button, has_added_score_state, quant_df, user_df]
660
- )
661
  with gr.TabItem("Leaderboard"):
662
  gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
663
  leaderboard_tab_quant_df = gr.DataFrame(
664
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
665
  interactive=False, col_count=(4, "fixed"), label="Quantization Method Leaderboard"
666
  )
667
- gr.Markdown("---")
668
-
669
- leaderboard_tab_user_df_8bit = gr.DataFrame(
670
- headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
671
- interactive=False, col_count=(4, "fixed"), label="8-bit bnb User Leaderboard"
672
- )
673
- leaderboard_tab_user_df_4bit = gr.DataFrame(
674
- headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
675
- interactive=False, col_count=(4, "fixed"), label="4-bit bnb User Leaderboard"
676
- )
677
 
678
  def update_all_leaderboards_for_tab():
679
- q_rows, _, per_method_u_dict = update_leaderboards_data()
680
- user_rows_8bit = per_method_u_dict.get("8-bit bnb", [])
681
- user_rows_4bit = per_method_u_dict.get("4-bit bnb", [])
682
- return q_rows, user_rows_8bit, user_rows_4bit
683
 
684
  demo.load(update_all_leaderboards_for_tab, outputs=[
685
- leaderboard_tab_quant_df,
686
- leaderboard_tab_user_df_8bit,
687
- leaderboard_tab_user_df_4bit
688
  ])
689
 
690
  if __name__ == "__main__":
 
35
  with open(AGG_FILE, "w") as f:
36
  json.dump(stats, f, indent=2)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
  print(f"Using device: {DEVICE}")
40
 
 
44
  DEFAULT_NUM_INFERENCE_STEPS = 15
45
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
46
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
47
+ HF_DATASET_REPO_ID = "diffusers/flux-quant-challenge-submissions"
48
 
49
  CACHED_PIPES = {}
50
  def load_bf16_pipeline():
 
59
  torch_dtype=torch.bfloat16,
60
  token=HF_TOKEN
61
  )
62
+ # pipe.to(DEVICE)
63
+ pipe.enable_model_cpu_offload()
64
  end_time = time.time()
65
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
66
  print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
 
81
  MODEL_ID,
82
  torch_dtype=torch.bfloat16
83
  )
84
+ # pipe.to(DEVICE)
85
+ pipe.enable_model_cpu_offload()
86
  end_time = time.time()
87
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
88
  print(f"8-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
 
103
  MODEL_ID,
104
  torch_dtype=torch.bfloat16
105
  )
106
+ # pipe.to(DEVICE)
107
+ pipe.enable_model_cpu_offload()
108
  end_time = time.time()
109
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
110
  print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
 
117
  @spaces.GPU(duration=240)
118
  def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
119
  if not prompt:
120
+ return None, {}, gr.update(value="Please enter a prompt.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
121
 
122
  if not quantization_choice:
123
+ return None, {}, gr.update(value="Please select a quantization method.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
124
 
125
  if quantization_choice == "8-bit bnb":
126
  quantized_load_func = load_bnb_8bit_pipeline
 
129
  quantized_load_func = load_bnb_4bit_pipeline
130
  quantized_label = "Quantized (4-bit bnb)"
131
  else:
132
+ return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
133
 
134
  model_configs = [
135
  ("Original", load_bf16_pipeline),
 
171
 
172
  except Exception as e:
173
  print(f"Error during {label} model processing: {e}")
174
+ return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
175
 
176
 
177
  if len(results) != len(model_configs):
178
+ return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), None, [], gr.update(interactive=True), gr.update(interactive=True)
179
 
180
  shuffled_results = results.copy()
181
  random.shuffle(shuffled_results)
182
  shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
183
  correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
184
+ print("Correct mapping (hidden):", correct_mapping)
185
 
186
  return shuffled_data_for_gallery, correct_mapping, prompt, seed, results, "Generation complete! Make your guess.", None, gr.update(interactive=True), gr.update(interactive=True)
187
 
 
216
  "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
217
  "quantized_idx": 1,
218
  "quant_method": "8-bit bnb",
219
+ "summary": "Astronaut on Mars",
220
  },
221
  {
222
  "prompt": "Water-color painting of a cat wearing sunglasses",
223
  "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
224
  "quantized_idx": 0,
225
  "quant_method": "8-bit bnb",
226
+ "summary": "Cat with Sunglasses",
227
  },
228
  # {
229
  # "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
 
246
  return f"{pct:.1f}%", pct
247
  return "N/A", -1.0
248
 
 
 
 
 
 
 
 
249
  def update_leaderboards_data():
250
  agg = _load_agg_stats()
251
  quant_rows = []
 
258
  acc_str
259
  ])
260
  quant_rows.sort(key=lambda r: r[1]/r[2] if r[2] != 0 else 1e9)
261
+ return quant_rows
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  quant_df = gr.DataFrame(
264
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
265
  interactive=False, col_count=(4, "fixed")
266
  )
 
 
 
 
267
 
268
  with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
269
  gr.Markdown("# FLUX Model Quantization Challenge")
 
277
 
278
  gr.Markdown("### Examples")
279
  ex_selector = gr.Radio(
280
+ choices=[ex["summary"] for ex in EXAMPLES],
281
  label="Choose an example prompt",
282
  interactive=True,
283
  )
 
310
 
311
  with gr.Row():
312
  session_score_box = gr.Textbox(label="Your accuracy this session", interactive=False)
313
+
314
+ gr.Markdown("""
315
+ ### Dataset Information
316
+ Unless you opt out below, your submissions will be recorded in a dataset. This dataset contains anonymized challenge results including prompts, images, quantization methods,
317
+ and whether guesses were correct.
318
+ """)
319
+
320
+ opt_out_checkbox = gr.Checkbox(
321
+ label="Opt out of data collection (don't record my submissions to the dataset)",
322
+ value=False
 
 
 
 
 
 
 
 
 
 
323
  )
324
 
325
  correct_mapping_state = gr.State({})
 
328
  "4-bit bnb": {"attempts": 0, "correct": 0}}
329
  )
330
  is_example_state = gr.State(False)
 
331
  prompt_state = gr.State("")
332
  seed_state = gr.State(None)
333
  results_state = gr.State([])
334
 
335
+ def _load_example_and_update_dfs(sel_summary):
336
+ idx = next((i for i, ex in enumerate(EXAMPLES) if ex["summary"] == sel_summary), -1)
337
+ if idx == -1:
338
+ print(f"Error: Example with summary '{sel_summary}' not found.")
339
+ return (gr.update(), gr.update(), gr.update(), False, gr.update(), "", None, [])
340
+
341
+ ex = EXAMPLES[idx]
342
  gallery_items, mapping, prompt = load_example(idx)
343
+ quant_data = update_leaderboards_data()
344
+ return gallery_items, mapping, prompt, True, quant_data, "", None, []
345
 
346
  ex_selector.change(
347
  fn=_load_example_and_update_dfs,
348
  inputs=ex_selector,
349
+ outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df,
350
+ prompt_state, seed_state, results_state],
351
  ).then(
352
  lambda: (gr.update(interactive=True), gr.update(interactive=True)),
353
  outputs=[image1_btn, image2_btn],
 
357
  fn=generate_images,
358
  inputs=[prompt_input, quantization_choice_radio],
359
  outputs=[output_gallery, correct_mapping_state, prompt_state, seed_state, results_state,
360
+ feedback_box]
361
  ).then(
362
+ lambda: False, # for is_example_state
363
+ outputs=[is_example_state]
 
 
 
 
 
 
 
 
364
  ).then(
365
  lambda: (gr.update(interactive=True),
366
+ gr.update(interactive=True),
367
+ ""),
368
  outputs=[image1_btn, image2_btn, feedback_box],
369
  )
370
 
371
+ def choose(choice_string, mapping, session_stats, is_example,
372
+ prompt, seed, results, opt_out):
373
  feedback = check_guess(choice_string, mapping)
374
 
375
  if not mapping:
376
+ return feedback, gr.update(), gr.update(), "", session_stats, gr.update()
377
 
378
  quant_label_from_mapping = next((label for label in mapping.values() if "Quantized" in label), None)
379
  if not quant_label_from_mapping:
380
  print("Error: Could not determine quantization label from mapping:", mapping)
381
  return ("Internal Error: Could not process results.", gr.update(interactive=False), gr.update(interactive=False),
382
+ "", session_stats, gr.update())
383
 
384
  quant_key = "8-bit bnb" if "8-bit bnb" in quant_label_from_mapping else "4-bit bnb"
 
385
  got_it_right = "Correct!" in feedback
 
386
  sess = session_stats.copy()
 
387
 
388
+ if not is_example: # Only log and update stats if it's not an example run
389
  sess[quant_key]["attempts"] += 1
390
  if got_it_right:
391
  sess[quant_key]["correct"] += 1
392
+ session_stats = sess # Update the state for the UI
393
 
394
  AGG_STATS = _load_agg_stats()
395
  AGG_STATS[quant_key]["attempts"] += 1
 
401
  print("Warning: HF_TOKEN not set. Skipping dataset logging.")
402
  elif not results:
403
  print("Warning: Results state is empty. Skipping dataset logging.")
404
+ elif opt_out:
405
+ print("User opted out of dataset logging. Skipping.")
406
  else:
407
  print(f"Logging guess to HF Dataset: {HF_DATASET_REPO_ID}")
408
  original_image = None
 
441
  "quantized_image_displayed_position": [f"Image {quantized_image_pos + 1}"],
442
  "user_guess_displayed_position": [choice_string],
443
  "correct_guess": [got_it_right],
444
+ "username": [None], # Log None for username
445
  }
 
446
  try:
 
447
  existing_ds = load_dataset(
448
  HF_DATASET_REPO_ID,
449
  split="train",
450
  token=HF_TOKEN,
451
  features=expected_features,
 
 
452
  )
 
453
  new_row_ds = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
 
454
  combined_ds = concatenate_datasets([existing_ds, new_row_ds])
 
455
  combined_ds.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
456
  print(f"Successfully appended guess to {HF_DATASET_REPO_ID} (train split)")
 
457
  except Exception as e:
458
  print(f"Could not load or append to existing dataset/split. Creating 'train' split with the new item. Error: {e}")
 
459
  ds_new = Dataset.from_dict(new_data_dict_of_lists, features=expected_features)
 
460
  ds_new.push_to_hub(HF_DATASET_REPO_ID, token=HF_TOKEN, split="train")
461
  print(f"Successfully created and logged new 'train' split to {HF_DATASET_REPO_ID}")
462
  else:
 
470
  session_msg = ", ".join(
471
  f"{k}: {_fmt(v)}" for k, v in sess.items()
472
  )
473
+
474
+ quant_data = update_leaderboards_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  return (feedback,
476
  gr.update(interactive=False),
477
  gr.update(interactive=False),
478
  session_msg,
479
+ session_stats, # Return the potentially updated session_stats
480
+ quant_data)
 
 
 
 
481
 
482
  image1_btn.click(
483
+ fn=lambda mapping, sess, is_ex, p, s, r, opt_out: choose("Image 1", mapping, sess, is_ex, p, s, r, opt_out),
484
+ inputs=[correct_mapping_state, session_stats_state, is_example_state,
485
+ prompt_state, seed_state, results_state, opt_out_checkbox],
486
  outputs=[feedback_box, image1_btn, image2_btn,
487
+ session_score_box, session_stats_state,
488
+ quant_df],
 
489
  )
490
  image2_btn.click(
491
+ fn=lambda mapping, sess, is_ex, p, s, r, opt_out: choose("Image 2", mapping, sess, is_ex, p, s, r, opt_out),
492
+ inputs=[correct_mapping_state, session_stats_state, is_example_state,
493
+ prompt_state, seed_state, results_state, opt_out_checkbox],
494
  outputs=[feedback_box, image1_btn, image2_btn,
495
+ session_score_box, session_stats_state,
496
+ quant_df],
 
497
  )
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  with gr.TabItem("Leaderboard"):
500
  gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
501
  leaderboard_tab_quant_df = gr.DataFrame(
502
  headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
503
  interactive=False, col_count=(4, "fixed"), label="Quantization Method Leaderboard"
504
  )
 
 
 
 
 
 
 
 
 
 
505
 
506
  def update_all_leaderboards_for_tab():
507
+ q_rows = update_leaderboards_data()
508
+ return q_rows # Only return quantization method data
 
 
509
 
510
  demo.load(update_all_leaderboards_for_tab, outputs=[
511
+ leaderboard_tab_quant_df,
 
 
512
  ])
513
 
514
  if __name__ == "__main__":