Upload app.py
Browse files
app.py
CHANGED
@@ -427,149 +427,149 @@ def infer_chest_cf(*args):
|
|
427 |
|
428 |
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
429 |
with gr.Tabs():
|
430 |
-
with gr.TabItem("Morpho-MNIST") as mnist_tab:
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
with gr.TabItem("Brain MRI") as brain_tab:
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
|
574 |
with gr.TabItem("Chest X-ray") as chest_tab:
|
575 |
chest_id = gr.Textbox(value=chest_tab.label, visible=False)
|
@@ -647,47 +647,53 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
647 |
cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]
|
648 |
|
649 |
# on start: load new observations & causal graph
|
650 |
-
demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
|
651 |
-
demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
|
652 |
-
demo.load(fn=load_model, inputs=mnist_id, outputs=None)
|
653 |
-
demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
|
|
|
|
|
654 |
demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
|
|
|
|
|
655 |
|
656 |
-
demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
|
657 |
demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
|
658 |
|
659 |
# on tab select: load models
|
660 |
-
brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
|
661 |
-
chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)
|
662 |
|
663 |
# "new" button: load new observations
|
664 |
-
new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
|
665 |
new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
|
666 |
-
new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)
|
667 |
|
668 |
# "new" button: reset causal graphs
|
669 |
-
new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
|
670 |
-
new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
|
671 |
new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
|
672 |
|
673 |
# "new" button: reset cf output panels
|
674 |
-
for _k, _v in zip(
|
675 |
-
|
676 |
-
):
|
677 |
-
|
|
|
|
|
678 |
|
679 |
# "reset" button: reload current observations
|
680 |
-
reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
|
681 |
-
reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
|
682 |
reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
|
683 |
|
684 |
# "reset" button: deselect intervention checkboxes
|
685 |
-
reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
|
686 |
-
reset_brain.click(
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
)
|
691 |
reset_chest.click(
|
692 |
fn=lambda: (gr.update(value=False),) * len(do_chest),
|
693 |
inputs=None,
|
@@ -695,21 +701,23 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
695 |
)
|
696 |
|
697 |
# "reset" button: reset cf output panels
|
698 |
-
for _k, _v in zip(
|
699 |
-
|
700 |
-
):
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
#
|
710 |
-
|
711 |
-
|
712 |
-
|
|
|
|
|
713 |
|
714 |
# enable chest interventions when checkbox is selected & update graph
|
715 |
for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
|
@@ -717,12 +725,12 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
717 |
_k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
|
718 |
|
719 |
# "submit" button: infer countefactuals
|
720 |
-
submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
|
721 |
-
submit_brain.click(
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
)
|
726 |
submit_chest.click(
|
727 |
fn=infer_chest_cf,
|
728 |
inputs=obs_chest + do_chest,
|
|
|
427 |
|
428 |
with gr.Blocks(theme=gr.themes.Default()) as demo:
|
429 |
with gr.Tabs():
|
430 |
+
# with gr.TabItem("Morpho-MNIST") as mnist_tab:
|
431 |
+
# mnist_id = gr.Textbox(value=mnist_tab.label, visible=False)
|
432 |
+
#
|
433 |
+
# with gr.Row(): #.style(equal_height=True):
|
434 |
+
# idx = gr.Number(value=0, visible=False)
|
435 |
+
# with gr.Column(scale=1, min_width=200):
|
436 |
+
# x = gr.Image(label="Observation", interactive=False, height=HEIGHT) #.style(
|
437 |
+
# #height=HEIGHT
|
438 |
+
# #)
|
439 |
+
# with gr.Column(scale=1, min_width=200):
|
440 |
+
# cf_x = gr.Image(label="Counterfactual", interactive=False, height=HEIGHT) #).style(
|
441 |
+
# # height=HEIGHT
|
442 |
+
# # )
|
443 |
+
# with gr.Column(scale=1, min_width=200):
|
444 |
+
# cf_x_std = gr.Image(
|
445 |
+
# label="Counterfactual Uncertainty", interactive=False
|
446 |
+
# , height=HEIGHT) #).style(height=HEIGHT)
|
447 |
+
# with gr.Column(scale=1, min_width=200):
|
448 |
+
# effect = gr.Image(
|
449 |
+
# label="Direct Causal Effect", interactive=False
|
450 |
+
# , height=HEIGHT) #).style(height=HEIGHT)
|
451 |
+
# with gr.Row(): #.style(equal_height=True):
|
452 |
+
# with gr.Column(scale=1):#.75):
|
453 |
+
# gr.Markdown(
|
454 |
+
# "**Intervention**"
|
455 |
+
# + 20 * " "
|
456 |
+
# + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
|
457 |
+
# + "  |   Hint: try 90% zoom"
|
458 |
+
# )
|
459 |
+
# with gr.Column():
|
460 |
+
# do_y = gr.Checkbox(label="do(digit)", value=False)
|
461 |
+
# y = gr.Radio(DIGITS, label="", interactive=False)
|
462 |
+
# with gr.Row():
|
463 |
+
# with gr.Column(min_width=100):
|
464 |
+
# do_t = gr.Checkbox(label="do(thickness)", value=False)
|
465 |
+
# t = gr.Slider(
|
466 |
+
# label="\u00A0",
|
467 |
+
# minimum=0.9,
|
468 |
+
# maximum=5.5,
|
469 |
+
# step=0.01,
|
470 |
+
# interactive=False,
|
471 |
+
# )
|
472 |
+
# with gr.Column(min_width=100):
|
473 |
+
# do_i = gr.Checkbox(label="do(intensity)", value=False)
|
474 |
+
# i = gr.Slider(
|
475 |
+
# label="\u00A0",
|
476 |
+
# minimum=50,
|
477 |
+
# maximum=255,
|
478 |
+
# step=0.01,
|
479 |
+
# interactive=False,
|
480 |
+
# )
|
481 |
+
# with gr.Row():
|
482 |
+
# new = gr.Button("New Observation")
|
483 |
+
# reset = gr.Button("Reset", variant="stop")
|
484 |
+
# submit = gr.Button("Submit", variant="primary")
|
485 |
+
# with gr.Column(scale=1):
|
486 |
+
# gr.Markdown("### ")
|
487 |
+
# causal_graph = gr.Image(
|
488 |
+
# label="Causal Graph", interactive=False
|
489 |
+
# , height=300) #).style(height=300)
|
490 |
+
#
|
491 |
+
# with gr.TabItem("Brain MRI") as brain_tab:
|
492 |
+
# brain_id = gr.Textbox(value=brain_tab.label, visible=False)
|
493 |
+
#
|
494 |
+
# with gr.Row(): #.style(equal_height=True):
|
495 |
+
# idx_brain = gr.Number(value=0, visible=False)
|
496 |
+
# with gr.Column(scale=1, min_width=200):
|
497 |
+
# x_brain = gr.Image(label="Observation", interactive=False, height=HEIGHT) #).style(
|
498 |
+
# # height=HEIGHT
|
499 |
+
# # )
|
500 |
+
# with gr.Column(scale=1, min_width=200):
|
501 |
+
# cf_x_brain = gr.Image(
|
502 |
+
# label="Counterfactual", interactive=False
|
503 |
+
# , height=HEIGHT) #).style(height=HEIGHT)
|
504 |
+
# with gr.Column(scale=1, min_width=200):
|
505 |
+
# cf_x_std_brain = gr.Image(
|
506 |
+
# label="Counterfactual Uncertainty", interactive=False
|
507 |
+
# , height=HEIGHT) #).style(height=HEIGHT)
|
508 |
+
# with gr.Column(scale=1, min_width=200):
|
509 |
+
# effect_brain = gr.Image(
|
510 |
+
# label="Direct Causal Effect", interactive=False
|
511 |
+
# , height=HEIGHT) #).style(height=HEIGHT)
|
512 |
+
# with gr.Row():
|
513 |
+
# with gr.Column(scale=2):#.55):
|
514 |
+
# gr.Markdown(
|
515 |
+
# "**Intervention**"
|
516 |
+
# + 20 * " "
|
517 |
+
# + "[arXiv paper](https://arxiv.org/abs/2306.15764)   |   [GitHub code](https://github.com/biomedia-mira/causal-gen)"
|
518 |
+
# + "  |   Hint: try 90% zoom"
|
519 |
+
# )
|
520 |
+
# with gr.Row():
|
521 |
+
# with gr.Column(min_width=200):
|
522 |
+
# do_m = gr.Checkbox(label="do(MRI sequence)", value=False)
|
523 |
+
# m = gr.Radio(
|
524 |
+
# ["T1", "T2-FLAIR"], label="", interactive=False
|
525 |
+
# )
|
526 |
+
# with gr.Column(min_width=200):
|
527 |
+
# do_s = gr.Checkbox(label="do(sex)", value=False)
|
528 |
+
# s = gr.Radio(
|
529 |
+
# ["female", "male"], label="", interactive=False
|
530 |
+
# )
|
531 |
+
# with gr.Row():
|
532 |
+
# with gr.Column(min_width=100):
|
533 |
+
# do_a = gr.Checkbox(label="do(age)", value=False)
|
534 |
+
# a = gr.Slider(
|
535 |
+
# label="\u00A0",
|
536 |
+
# value=50,
|
537 |
+
# minimum=44,
|
538 |
+
# maximum=73,
|
539 |
+
# step=1,
|
540 |
+
# interactive=False,
|
541 |
+
# )
|
542 |
+
# with gr.Column(min_width=100):
|
543 |
+
# do_b = gr.Checkbox(label="do(brain volume)", value=False)
|
544 |
+
# b = gr.Slider(
|
545 |
+
# label="\u00A0",
|
546 |
+
# value=1000,
|
547 |
+
# minimum=850,
|
548 |
+
# maximum=1550,
|
549 |
+
# step=20,
|
550 |
+
# interactive=False,
|
551 |
+
# )
|
552 |
+
# with gr.Column(min_width=100):
|
553 |
+
# do_v = gr.Checkbox(
|
554 |
+
# label="do(ventricle volume)", value=False
|
555 |
+
# )
|
556 |
+
# v = gr.Slider(
|
557 |
+
# label="\u00A0",
|
558 |
+
# value=40,
|
559 |
+
# minimum=10,
|
560 |
+
# maximum=125,
|
561 |
+
# step=2,
|
562 |
+
# interactive=False,
|
563 |
+
# )
|
564 |
+
# with gr.Row():
|
565 |
+
# new_brain = gr.Button("New Observation")
|
566 |
+
# reset_brain = gr.Button("Reset", variant="stop")
|
567 |
+
# submit_brain = gr.Button("Submit", variant="primary")
|
568 |
+
# with gr.Column(scale=1):
|
569 |
+
# # gr.Markdown("### ")
|
570 |
+
# causal_graph_brain = gr.Image(
|
571 |
+
# label="Causal Graph", interactive=False
|
572 |
+
# , height=340) #).style(height=340)
|
573 |
|
574 |
with gr.TabItem("Chest X-ray") as chest_tab:
|
575 |
chest_id = gr.Textbox(value=chest_tab.label, visible=False)
|
|
|
647 |
cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]
|
648 |
|
649 |
# on start: load new observations & causal graph
|
650 |
+
# demo.load(fn=get_mnist_obs, inputs=None, outputs=obs)
|
651 |
+
# demo.load(fn=mnist_graph, inputs=do, outputs=causal_graph)
|
652 |
+
# demo.load(fn=load_model, inputs=mnist_id, outputs=None)
|
653 |
+
# demo.load(fn=get_brain_obs, inputs=None, outputs=obs_brain)
|
654 |
+
# demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
|
655 |
+
|
656 |
demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
|
657 |
+
demo.load(fn=load_model, inputs=chest_id, output=None)
|
658 |
+
|
659 |
|
660 |
+
# demo.load(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
|
661 |
demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
|
662 |
|
663 |
# on tab select: load models
|
664 |
+
# brain_tab.select(fn=load_model, inputs=brain_id, outputs=None)
|
665 |
+
# chest_tab.select(fn=load_model, inputs=chest_id, outputs=None)
|
666 |
|
667 |
# "new" button: load new observations
|
668 |
+
# new.click(fn=get_mnist_obs, inputs=None, outputs=obs)
|
669 |
new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
|
670 |
+
# new_brain.click(fn=get_brain_obs, inputs=None, outputs=obs_brain)
|
671 |
|
672 |
# "new" button: reset causal graphs
|
673 |
+
# new.click(fn=mnist_graph, inputs=do, outputs=causal_graph)
|
674 |
+
# new_brain.click(fn=brain_graph, inputs=do_brain, outputs=causal_graph_brain)
|
675 |
new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
|
676 |
|
677 |
# "new" button: reset cf output panels
|
678 |
+
# for _k, _v in zip(
|
679 |
+
# [new, new_brain, new_chest], [cf_out, cf_out_brain, cf_out_chest]
|
680 |
+
# ):
|
681 |
+
# _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
|
682 |
+
new_chest.click(fn=lambda:(gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest)
|
683 |
+
|
684 |
|
685 |
# "reset" button: reload current observations
|
686 |
+
# reset.click(fn=get_mnist_obs, inputs=idx, outputs=obs)
|
687 |
+
# reset_brain.click(fn=get_brain_obs, inputs=idx_brain, outputs=obs_brain)
|
688 |
reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
|
689 |
|
690 |
# "reset" button: deselect intervention checkboxes
|
691 |
+
# reset.click(fn=lambda: (gr.update(value=False),) * len(do), inputs=None, outputs=do)
|
692 |
+
# reset_brain.click(
|
693 |
+
# fn=lambda: (gr.update(value=False),) * len(do_brain),
|
694 |
+
# inputs=None,
|
695 |
+
# outputs=do_brain,
|
696 |
+
# )
|
697 |
reset_chest.click(
|
698 |
fn=lambda: (gr.update(value=False),) * len(do_chest),
|
699 |
inputs=None,
|
|
|
701 |
)
|
702 |
|
703 |
# "reset" button: reset cf output panels
|
704 |
+
# for _k, _v in zip(
|
705 |
+
# [reset, reset_brain, reset_chest], [cf_out, cf_out_brain, cf_out_chest]
|
706 |
+
# ):
|
707 |
+
# _k.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
|
708 |
+
# _k.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=_v)
|
709 |
+
reset_chest.lick(fn=lambda: plt.close("all"), inputs=None, outputs=None)
|
710 |
+
reset_chest.lick(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest)
|
711 |
+
|
712 |
+
# # enable mnist interventions when checkbox is selected & update graph
|
713 |
+
# for _k, _v in zip(do, [t, i, y]):
|
714 |
+
# _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
|
715 |
+
# _k.change(mnist_graph, inputs=do, outputs=causal_graph)
|
716 |
+
#
|
717 |
+
# # enable brain interventions when checkbox is selected & update graph
|
718 |
+
# for _k, _v in zip(do_brain, [m, s, a, b, v]):
|
719 |
+
# _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
|
720 |
+
# _k.change(brain_graph, inputs=do_brain, outputs=causal_graph_brain)
|
721 |
|
722 |
# enable chest interventions when checkbox is selected & update graph
|
723 |
for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
|
|
|
725 |
_k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)
|
726 |
|
727 |
# "submit" button: infer countefactuals
|
728 |
+
# submit.click(fn=infer_mnist_cf, inputs=obs + do, outputs=cf_out + [t, i, y])
|
729 |
+
# submit_brain.click(
|
730 |
+
# fn=infer_brain_cf,
|
731 |
+
# inputs=obs_brain + do_brain,
|
732 |
+
# outputs=cf_out_brain + [m, s, a, b, v],
|
733 |
+
# )
|
734 |
submit_chest.click(
|
735 |
fn=infer_chest_cf,
|
736 |
inputs=obs_chest + do_chest,
|