Yeefei commited on
Commit
c2e2c90
·
verified ·
1 Parent(s): c565e5a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -187
app.py CHANGED
@@ -427,149 +427,149 @@ def infer_chest_cf(*args):
427
 
428
  with gr.Blocks(theme=gr.themes.Default()) as demo:
429
  with gr.Tabs():
430
- with gr.TabItem("Morpho-MNIST") as mnist_tab:
431
- mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)
432
-
433
- with gr.Row(): #.style(equal_height=True):
434
- idx = gr.Number(value=0, visible=False)
435
- with gr.Column(scale=1, min_width=200):
436
- x = gr.Image(label="Observation", interactive=False, height=HEIGHT) #.style(
437
- #height=HEIGHT
438
- #)
439
- with gr.Column(scale=1, min_width=200):
440
- cf_x = gr.Image(label="Counterfactual", interactive=False, height=HEIGHT) #).style(
441
- # height=HEIGHT
442
- # )
443
- with gr.Column(scale=1, min_width=200):
444
- cf_x_std = gr.Image(
445
- label="Counterfactual Uncertainty", interactive=False
446
- , height=HEIGHT) #).style(height=HEIGHT)
447
- with gr.Column(scale=1, min_width=200):
448
- effect = gr.Image(
449
- label="Direct Causal Effect", interactive=False
450
- , height=HEIGHT) #).style(height=HEIGHT)
451
- with gr.Row(): #.style(equal_height=True):
452
- with gr.Column(scale=1):#.75):
453
- gr.Markdown(
454
- "**Intervention**"
455
- + 20 * " "
456
- + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
457
- + "  |   Hint: try 90% zoom"
458
- )
459
- with gr.Column():
460
- do_y = gr.Checkbox(label="do(digit)", value=False)
461
- y = gr.Radio(DIGITS, label="", interactive=False)
462
- with gr.Row():
463
- with gr.Column(min_width=100):
464
- do_t = gr.Checkbox(label="do(thickness)", value=False)
465
- t = gr.Slider(
466
- label="\u00A0",
467
- minimum=0.9,
468
- maximum=5.5,
469
- step=0.01,
470
- interactive=False,
471
- )
472
- with gr.Column(min_width=100):
473
- do_i = gr.Checkbox(label="do(intensity)", value=False)
474
- i = gr.Slider(
475
- label="\u00A0",
476
- minimum=50,
477
- maximum=255,
478
- step=0.01,
479
- interactive=False,
480
- )
481
- with gr.Row():
482
- new = gr.Button("New Observation")
483
- reset = gr.Button("Reset", variant="stop")
484
- submit = gr.Button("Submit", variant="primary")
485
- with gr.Column(scale=1):
486
- gr.Markdown("###  ")
487
- causal_graph = gr.Image(
488
- label="Causal Graph", interactive=False
489
- , height=300) #).style(height=300)
490
-
491
- with gr.TabItem("Brain MRI") as brain_tab:
492
- brain_id = gr.Textbox(value=brain_tab.label, visible=False)
493
-
494
- with gr.Row(): #.style(equal_height=True):
495
- idx_brain = gr.Number(value=0, visible=False)
496
- with gr.Column(scale=1, min_width=200):
497
- x_brain = gr.Image(label="Observation", interactive=False, height=HEIGHT) #).style(
498
- # height=HEIGHT
499
- # )
500
- with gr.Column(scale=1, min_width=200):
501
- cf_x_brain = gr.Image(
502
- label="Counterfactual", interactive=False
503
- , height=HEIGHT) #).style(height=HEIGHT)
504
- with gr.Column(scale=1, min_width=200):
505
- cf_x_std_brain = gr.Image(
506
- label="Counterfactual Uncertainty", interactive=False
507
- , height=HEIGHT) #).style(height=HEIGHT)
508
- with gr.Column(scale=1, min_width=200):
509
- effect_brain = gr.Image(
510
- label="Direct Causal Effect", interactive=False
511
- , height=HEIGHT) #).style(height=HEIGHT)
512
- with gr.Row():
513
- with gr.Column(scale=2):#.55):
514
- gr.Markdown(
515
- "**Intervention**"
516
- + 20 * " "
517
- + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
518
- + "  |   Hint: try 90% zoom"
519
- )
520
- with gr.Row():
521
- with gr.Column(min_width=200):
522
- do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
523
- m = gr.Radio(
524
- ["T1", "T2-FLAIR"], label="", interactive=False
525
- )
526
- with gr.Column(min_width=200):
527
- do_s = gr.Checkbox(label="do(sex)", value=False)
528
- s = gr.Radio(
529
- ["female", "male"], label="", interactive=False
530
- )
531
- with gr.Row():
532
- with gr.Column(min_width=100):
533
- do_a = gr.Checkbox(label="do(age)", value=False)
534
- a = gr.Slider(
535
- label="\u00A0",
536
- value=50,
537
- minimum=44,
538
- maximum=73,
539
- step=1,
540
- interactive=False,
541
- )
542
- with gr.Column(min_width=100):
543
- do_b = gr.Checkbox(label="do(brain volume)", value=False)
544
- b = gr.Slider(
545
- label="\u00A0",
546
- value=1000,
547
- minimum=850,
548
- maximum=1550,
549
- step=20,
550
- interactive=False,
551
- )
552
- with gr.Column(min_width=100):
553
- do_v = gr.Checkbox(
554
- label="do(ventricle volume)", value=False
555
- )
556
- v = gr.Slider(
557
- label="\u00A0",
558
- value=40,
559
- minimum=10,
560
- maximum=125,
561
- step=2,
562
- interactive=False,
563
- )
564
- with gr.Row():
565
- new_brain = gr.Button("New Observation")
566
- reset_brain = gr.Button("Reset", variant="stop")
567
- submit_brain = gr.Button("Submit", variant="primary")
568
- with gr.Column(scale=1):
569
- # gr.Markdown("###  ")
570
- causal_graph_brain = gr.Image(
571
- label="Causal Graph", interactive=False
572
- , height=340) #).style(height=340)
573
 
574
  with gr.TabItem("Chest X-ray") as chest_tab:
575
  chest_id = gr.Textbox(value=chest_tab.label, visible=False)
@@ -647,47 +647,53 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
647
  cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]
648
 
649
  # on start: load new observations & causal graph
650
- demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
651
- demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
652
- demo.load(fn=load_model, inputs=mnist_id, outputs=None)
653
- demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
 
 
654
  demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
 
 
655
 
656
- demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
657
  demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
658
 
659
  # on tab select: load models
660
- brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
661
- chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)
662
 
663
  # "new" button: load new observations
664
- new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
665
  new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
666
- new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)
667
 
668
  # "new" button: reset causal graphs
669
- new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
670
- new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
671
  new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
672
 
673
  # "new" button: reset cf output panels
674
- for _k, _v in zip(
675
- [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
676
- ):
677
- _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
 
 
678
 
679
  # "reset" button: reload current observations
680
- reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
681
- reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
682
  reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
683
 
684
  # "reset" button: deselect intervention checkboxes
685
- reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
686
- reset_brain.click(
687
- fn=lambda: (gr.update(value=False),) * len(do_brain),
688
- inputs=None,
689
- outputs=do_brain,
690
- )
691
  reset_chest.click(
692
  fn=lambda: (gr.update(value=False),) * len(do_chest),
693
  inputs=None,
@@ -695,21 +701,23 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
695
  )
696
 
697
  # "reset" button: reset cf output panels
698
- for _k, _v in zip(
699
- [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
700
- ):
701
- _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
702
- _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
703
-
704
- # enable mnist interventions when checkbox is selected & update graph
705
- for _k, _v in zip(do, [t, i, y]):
706
- _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
707
- _k.change(mnist_graph, inputs=do, outputs=causal_graph)
708
-
709
- # enable brain interventions when checkbox is selected & update graph
710
- for _k, _v in zip(do_brain, [m, s, a, b, v]):
711
- _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
712
- _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)
 
 
713
 
714
  # enable chest interventions when checkbox is selected & update graph
715
  for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
@@ -717,12 +725,12 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
717
  _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
718
 
719
  # "submit" button: infer countefactuals
720
- submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
721
- submit_brain.click(
722
- fn=infer_brain_cf,
723
- inputs=obs_brain + do_brain,
724
- outputs=cf_out_brain + [m, s, a, b, v],
725
- )
726
  submit_chest.click(
727
  fn=infer_chest_cf,
728
  inputs=obs_chest + do_chest,
 
427
 
428
  with gr.Blocks(theme=gr.themes.Default()) as demo:
429
  with gr.Tabs():
430
+ # with gr.TabItem("Morpho-MNIST") as mnist_tab:
431
+ # mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)
432
+ #
433
+ # with gr.Row(): #.style(equal_height=True):
434
+ # idx = gr.Number(value=0, visible=False)
435
+ # with gr.Column(scale=1, min_width=200):
436
+ # x = gr.Image(label="Observation", interactive=False, height=HEIGHT) #.style(
437
+ # #height=HEIGHT
438
+ # #)
439
+ # with gr.Column(scale=1, min_width=200):
440
+ # cf_x = gr.Image(label="Counterfactual", interactive=False, height=HEIGHT) #).style(
441
+ # # height=HEIGHT
442
+ # # )
443
+ # with gr.Column(scale=1, min_width=200):
444
+ # cf_x_std = gr.Image(
445
+ # label="Counterfactual Uncertainty", interactive=False
446
+ # , height=HEIGHT) #).style(height=HEIGHT)
447
+ # with gr.Column(scale=1, min_width=200):
448
+ # effect = gr.Image(
449
+ # label="Direct Causal Effect", interactive=False
450
+ # , height=HEIGHT) #).style(height=HEIGHT)
451
+ # with gr.Row(): #.style(equal_height=True):
452
+ # with gr.Column(scale=1):#.75):
453
+ # gr.Markdown(
454
+ # "**Intervention**"
455
+ # + 20 * " "
456
+ # + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
457
+ # + "  |   Hint: try 90% zoom"
458
+ # )
459
+ # with gr.Column():
460
+ # do_y = gr.Checkbox(label="do(digit)", value=False)
461
+ # y = gr.Radio(DIGITS, label="", interactive=False)
462
+ # with gr.Row():
463
+ # with gr.Column(min_width=100):
464
+ # do_t = gr.Checkbox(label="do(thickness)", value=False)
465
+ # t = gr.Slider(
466
+ # label="\u00A0",
467
+ # minimum=0.9,
468
+ # maximum=5.5,
469
+ # step=0.01,
470
+ # interactive=False,
471
+ # )
472
+ # with gr.Column(min_width=100):
473
+ # do_i = gr.Checkbox(label="do(intensity)", value=False)
474
+ # i = gr.Slider(
475
+ # label="\u00A0",
476
+ # minimum=50,
477
+ # maximum=255,
478
+ # step=0.01,
479
+ # interactive=False,
480
+ # )
481
+ # with gr.Row():
482
+ # new = gr.Button("New Observation")
483
+ # reset = gr.Button("Reset", variant="stop")
484
+ # submit = gr.Button("Submit", variant="primary")
485
+ # with gr.Column(scale=1):
486
+ # gr.Markdown("###  ")
487
+ # causal_graph = gr.Image(
488
+ # label="Causal Graph", interactive=False
489
+ # , height=300) #).style(height=300)
490
+ #
491
+ # with gr.TabItem("Brain MRI") as brain_tab:
492
+ # brain_id = gr.Textbox(value=brain_tab.label, visible=False)
493
+ #
494
+ # with gr.Row(): #.style(equal_height=True):
495
+ # idx_brain = gr.Number(value=0, visible=False)
496
+ # with gr.Column(scale=1, min_width=200):
497
+ # x_brain = gr.Image(label="Observation", interactive=False, height=HEIGHT) #).style(
498
+ # # height=HEIGHT
499
+ # # )
500
+ # with gr.Column(scale=1, min_width=200):
501
+ # cf_x_brain = gr.Image(
502
+ # label="Counterfactual", interactive=False
503
+ # , height=HEIGHT) #).style(height=HEIGHT)
504
+ # with gr.Column(scale=1, min_width=200):
505
+ # cf_x_std_brain = gr.Image(
506
+ # label="Counterfactual Uncertainty", interactive=False
507
+ # , height=HEIGHT) #).style(height=HEIGHT)
508
+ # with gr.Column(scale=1, min_width=200):
509
+ # effect_brain = gr.Image(
510
+ # label="Direct Causal Effect", interactive=False
511
+ # , height=HEIGHT) #).style(height=HEIGHT)
512
+ # with gr.Row():
513
+ # with gr.Column(scale=2):#.55):
514
+ # gr.Markdown(
515
+ # "**Intervention**"
516
+ # + 20 * " "
517
+ # + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
518
+ # + "  |   Hint: try 90% zoom"
519
+ # )
520
+ # with gr.Row():
521
+ # with gr.Column(min_width=200):
522
+ # do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
523
+ # m = gr.Radio(
524
+ # ["T1", "T2-FLAIR"], label="", interactive=False
525
+ # )
526
+ # with gr.Column(min_width=200):
527
+ # do_s = gr.Checkbox(label="do(sex)", value=False)
528
+ # s = gr.Radio(
529
+ # ["female", "male"], label="", interactive=False
530
+ # )
531
+ # with gr.Row():
532
+ # with gr.Column(min_width=100):
533
+ # do_a = gr.Checkbox(label="do(age)", value=False)
534
+ # a = gr.Slider(
535
+ # label="\u00A0",
536
+ # value=50,
537
+ # minimum=44,
538
+ # maximum=73,
539
+ # step=1,
540
+ # interactive=False,
541
+ # )
542
+ # with gr.Column(min_width=100):
543
+ # do_b = gr.Checkbox(label="do(brain volume)", value=False)
544
+ # b = gr.Slider(
545
+ # label="\u00A0",
546
+ # value=1000,
547
+ # minimum=850,
548
+ # maximum=1550,
549
+ # step=20,
550
+ # interactive=False,
551
+ # )
552
+ # with gr.Column(min_width=100):
553
+ # do_v = gr.Checkbox(
554
+ # label="do(ventricle volume)", value=False
555
+ # )
556
+ # v = gr.Slider(
557
+ # label="\u00A0",
558
+ # value=40,
559
+ # minimum=10,
560
+ # maximum=125,
561
+ # step=2,
562
+ # interactive=False,
563
+ # )
564
+ # with gr.Row():
565
+ # new_brain = gr.Button("New Observation")
566
+ # reset_brain = gr.Button("Reset", variant="stop")
567
+ # submit_brain = gr.Button("Submit", variant="primary")
568
+ # with gr.Column(scale=1):
569
+ # # gr.Markdown("###  ")
570
+ # causal_graph_brain = gr.Image(
571
+ # label="Causal Graph", interactive=False
572
+ # , height=340) #).style(height=340)
573
 
574
  with gr.TabItem("Chest X-ray") as chest_tab:
575
  chest_id = gr.Textbox(value=chest_tab.label, visible=False)
 
647
  cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]
648
 
649
  # on start: load new observations & causal graph
650
+ # demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
651
+ # demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
652
+ # demo.load(fn=load_model, inputs=mnist_id, outputs=None)
653
+ # demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
654
+ # demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
655
+
656
  demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
657
+ demo.load(fn=load_model, inputs=chest_id, output=None)
658
+
659
 
660
+ # demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
661
  demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
662
 
663
  # on tab select: load models
664
+ # brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
665
+ # chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)
666
 
667
  # "new" button: load new observations
668
+ # new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
669
  new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
670
+ # new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)
671
 
672
  # "new" button: reset causal graphs
673
+ # new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
674
+ # new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
675
  new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
676
 
677
  # "new" button: reset cf output panels
678
+ # for _k, _v in zip(
679
+ # [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
680
+ # ):
681
+ # _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
682
+ new_chest.click(fn=lambda:(gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest)
683
+
684
 
685
  # "reset" button: reload current observations
686
+ # reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
687
+ # reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
688
  reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
689
 
690
  # "reset" button: deselect intervention checkboxes
691
+ # reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
692
+ # reset_brain.click(
693
+ # fn=lambda: (gr.update(value=False),) * len(do_brain),
694
+ # inputs=None,
695
+ # outputs=do_brain,
696
+ # )
697
  reset_chest.click(
698
  fn=lambda: (gr.update(value=False),) * len(do_chest),
699
  inputs=None,
 
701
  )
702
 
703
  # "reset" button: reset cf output panels
704
+ # for _k, _v in zip(
705
+ # [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
706
+ # ):
707
+ # _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
708
+ # _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
709
+ reset_chest.lick(fn=lambda: plt.close("all"), inputs=None, outputs=None)
710
+ reset_chest.lick(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest)
711
+
712
+ # # enable mnist interventions when checkbox is selected & update graph
713
+ # for _k, _v in zip(do, [t, i, y]):
714
+ # _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
715
+ # _k.change(mnist_graph, inputs=do, outputs=causal_graph)
716
+ #
717
+ # # enable brain interventions when checkbox is selected & update graph
718
+ # for _k, _v in zip(do_brain, [m, s, a, b, v]):
719
+ # _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
720
+ # _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)
721
 
722
  # enable chest interventions when checkbox is selected & update graph
723
  for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
 
725
  _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
726
 
727
  # "submit" button: infer countefactuals
728
+ # submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
729
+ # submit_brain.click(
730
+ # fn=infer_brain_cf,
731
+ # inputs=obs_brain + do_brain,
732
+ # outputs=cf_out_brain + [m, s, a, b, v],
733
+ # )
734
  submit_chest.click(
735
  fn=infer_chest_cf,
736
  inputs=obs_chest + do_chest,