|
import torch |
|
import numpy as np |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
|
|
from PIL import Image |
|
|
|
from matplotlib import rc, patches, colors |
|
|
|
# Global matplotlib styling used by every figure in this module: serif text
# typeset through LaTeX (amsmath/amssymb are loaded for the \mathbf and
# \boldsymbol node labels; a LaTeX installation is required), and no
# interpolation when images are shown with imshow.
rc("font", family="serif", serif=["Roman"])
rc("text", usetex=True)
rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
rc("image", interpolation="none")
|
|
|
from datasets import get_attr_max_min |
|
|
|
def _load_hammer(path="./hammer.png", size=(35, 35)):
    """Load the hammer icon, resize it and return it as a float array divided by 255.

    The image is opened in a ``with`` block so the file handle is closed
    deterministically instead of whenever the Image object is garbage-collected.
    Assumes an 8-bit image, so the result lies in [0, 1] (TODO confirm the icon's mode).
    """
    with Image.open(path) as img:
        return np.array(img.resize(size)) / 255


# Icon overlaid with ``fig.figimage`` on flagged nodes by the *_graph functions.
# NOTE(review): the path is resolved against the current working directory,
# not against this module's location.
HAMMER = _load_hammer()
|
|
|
|
|
class MidpointNormalize(colors.Normalize):
    """Normalize that maps ``midpoint`` to 0.5 on a range symmetric about zero.

    Values are linearly interpolated from ``[-v_ext, midpoint, v_ext]`` onto
    ``[0, 0.5, 1]``, where ``v_ext = max(|vmin|, |vmax|)``. Inputs outside
    ``[-v_ext, v_ext]`` are clamped to 0 / 1 by ``np.interp``.

    NOTE: ``np.interp`` requires increasing sample points and does not check
    this, so ``midpoint`` must lie strictly inside ``(-v_ext, v_ext)``.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        # midpoint=None is treated as 0 (the centre of the symmetric range).
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Map ``value`` into [0, 1] as a masked array that keeps ``value``'s mask."""
        # Fill unset limits from the data, as the base Normalize does; without
        # this, np.abs(None) below raises when vmin/vmax were not given.
        self.autoscale_None(value)
        v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
        midpoint = 0.0 if self.midpoint is None else self.midpoint
        x, y = [-v_ext, midpoint, v_ext], [0, 0.5, 1]
        # Propagate the input mask so masked pixels keep the colormap's "bad" colour.
        return np.ma.masked_array(np.interp(value, x, y), mask=np.ma.getmask(value))
|
|
|
|
|
def postprocess(x):
    """Apply ``(x + 1) * 127.5`` and return the squeezed result as a NumPy array.

    The affine map sends -1 to 0 and 1 to 255. The tensor is detached from the
    autograd graph and moved to CPU. No clipping or integer cast is performed.
    """
    pixels = (x + 1.0) * 127.5
    pixels = pixels.squeeze()
    return pixels.detach().cpu().numpy()
|
|
|
|
|
def mnist_graph(*args):
    """Draw the (Morpho-)MNIST causal graph and return it as an RGBA pixel array.

    Observed nodes (grey fill): image ``x`` and parents ``t``, ``i``, ``y``.
    Exogenous nodes (white fill): ``U_t``, ``U_i``, ``U_y`` and the image noise
    ``z_{1:L}`` / ``epsilon``. The two image-noise nodes are enclosed in a box
    labelled ``U_x``.
    NOTE(review): names suggest t=thickness, i=intensity, y=digit label — confirm.

    Args:
        *args: flags read positionally as ``args[0]`` (t), ``args[1]`` (i) and
            ``args[2]`` (y). At least three values must be passed. A truthy flag
            greys out the node's incoming edges, outlines the node in red and
            overlays the HAMMER icon (presumably depicting a do()-intervention).

    Returns:
        NumPy array built from ``fig.canvas.renderer.buffer_rgba()`` of the
        rendered figure. Assumes an Agg-based canvas that exposes ``renderer``.

    Side effects: closes every open pyplot figure before drawing. The new
    figure is left open.
    """
    # LaTeX label strings double as node identifiers in the networkx graph.
    x, t, i, y = r"$\mathbf{x}$", r"$t$", r"$i$", r"$y$"
    ut, ui, uy = r"$\mathbf{U}_t$", r"$\mathbf{U}_i$", r"$\mathbf{U}_y$"
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    G = nx.DiGraph()
    G.add_edge(t, x)
    G.add_edge(i, x)
    G.add_edge(y, x)
    G.add_edge(t, i)
    G.add_edge(ut, t)
    G.add_edge(ui, i)
    G.add_edge(uy, y)
    G.add_edge(zx, x)
    G.add_edge(ex, x)

    # Fixed node positions in data coordinates.
    pos = {
        y: (0, 0),
        uy: (-1, 0),
        t: (0, 0.5),
        ut: (0, 1),
        x: (1, 0),
        zx: (2, 0.375),
        ex: (2, 0),
        i: (1, 0.5),
        ui: (1, 1),
    }

    # Colour dicts are built by iterating G / G.edges. Their values are later
    # passed as lists, relying on networkx drawing nodes/edges in G's order.
    node_c = {}
    for node in G:
        node_c[node] = "lightgrey" if node in [x, t, i, y] else "white"
    node_line_c = {k: "black" for k, _ in node_c.items()}
    edge_c = {e: "black" for e in G.edges}

    # Flagged nodes: grey out every incoming edge and outline the node in red.
    if args[0]:
        edge_c[(ut, t)] = "lightgrey"

        node_line_c[t] = "red"
    if args[1]:
        edge_c[(ui, i)] = "lightgrey"
        edge_c[(t, i)] = "lightgrey"

        node_line_c[i] = "red"
    if args[2]:
        edge_c[(uy, y)] = "lightgrey"

        node_line_c[y] = "red"

    fs = 30
    options = {
        "font_size": fs,
        "node_size": 3000,
        "node_color": list(node_c.values()),
        "edgecolors": list(node_line_c.values()),
        "edge_color": list(edge_c.values()),
        "linewidths": 2,
        "width": 2,
    }
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(6, 4.1))

    ax.margins(x=0.06, y=0.15, tight=False)
    ax.axis("off")
    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)

    # Explicit limits (these override the margins above) keep the layout fixed,
    # which the pixel offsets of the hammer icons below depend on.
    x_lim = (-1.348, 2.348)
    y_lim = (-0.215, 1.215)
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    # Box around the z_{1:L} / epsilon nodes, labelled U_x.
    rect = patches.FancyBboxPatch(
        (1.75, -0.16),
        0.5,
        0.7,
        boxstyle="round, pad=0.05, rounding_size=0",
        linewidth=2,
        edgecolor="black",
        facecolor="none",
        linestyle="-",
    )
    ax.add_patch(rect)
    ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)

    # Hammer icons are placed in figure pixels as hard-coded fractions of the
    # figure size next to t, i and y. They must be retuned if figsize or pos change.
    if args[0]:
        fig.figimage(HAMMER, 0.26 * fig.bbox.xmax, 0.525 * fig.bbox.ymax, zorder=10)
    if args[1]:
        fig.figimage(HAMMER, 0.5175 * fig.bbox.xmax, 0.525 * fig.bbox.ymax, zorder=11)
    if args[2]:
        fig.figimage(HAMMER, 0.26 * fig.bbox.xmax, 0.2 * fig.bbox.ymax, zorder=12)

    fig.tight_layout()
    # Render now so the renderer buffer holds the finished image.
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
|
|
|
|
|
def brain_graph(*args):
    """Draw the brain causal graph and return it as an RGBA pixel array.

    Observed nodes (grey fill): image ``x`` and parents ``m``, ``s``, ``a``,
    ``b``, ``v``. Exogenous nodes (white fill): ``U_m`` .. ``U_v`` and the image
    noise ``z_{1:L}`` / ``epsilon`` (enclosed in a box). The ``b -> v``
    dependency is drawn as a hand-made curved arrow rather than a networkx edge.
    NOTE(review): names suggest m=MRI sequence, s=sex, a=age, b=brain volume,
    v=ventricle volume — confirm.

    Args:
        *args: flags read positionally as ``args[0]`` (m), ``args[1]`` (s),
            ``args[2]`` (a), ``args[3]`` (b) and ``args[4]`` (v). At least five
            values must be passed. A truthy flag greys out the node's incoming
            edges, outlines the node in red and overlays the HAMMER icon
            (presumably depicting a do()-intervention).

    Returns:
        NumPy array built from ``fig.canvas.renderer.buffer_rgba()`` of the
        rendered figure. Assumes an Agg-based canvas that exposes ``renderer``.

    Side effects: closes every open pyplot figure before drawing. The new
    figure is left open.
    """
    # LaTeX label strings double as node identifiers in the networkx graph.
    x, m, s, a, b, v = r"$\mathbf{x}$", r"$m$", r"$s$", r"$a$", r"$b$", r"$v$"
    um, us, ua, ub, uv = (
        r"$\mathbf{U}_m$",
        r"$\mathbf{U}_s$",
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_b$",
        r"$\mathbf{U}_v$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    G = nx.DiGraph()
    G.add_edge(m, x)
    G.add_edge(s, x)
    G.add_edge(b, x)
    G.add_edge(v, x)
    G.add_edge(zx, x)
    G.add_edge(ex, x)
    G.add_edge(a, b)
    G.add_edge(a, v)
    G.add_edge(s, b)
    G.add_edge(um, m)
    G.add_edge(us, s)
    G.add_edge(ua, a)
    G.add_edge(ub, b)
    G.add_edge(uv, v)
    # NOTE: b -> v is deliberately NOT a graph edge. With the layout below, a
    # straight arrow from b (1, 1) to v (-1, 1) would pass through a (0, 1), so
    # it is drawn separately as a curved FancyArrowPatch.

    # Fixed node positions in data coordinates.
    pos = {
        x: (0, 0),
        zx: (-0.25, -1),
        ex: (0.25, -1),
        a: (0, 1),
        ua: (0, 2),
        s: (1, 0),
        us: (1, -1),
        b: (1, 1),
        ub: (1, 2),
        m: (-1, 0),
        um: (-1, -1),
        v: (-1, 1),
        uv: (-1, 2),
    }

    node_c = {}
    for node in G:
        node_c[node] = "lightgrey" if node in [x, m, s, a, b, v] else "white"
    node_line_c = {k: "black" for k, _ in node_c.items()}
    edge_c = {e: "black" for e in G.edges}

    # Flagged nodes: grey out every incoming edge and outline the node in red.
    # Only keys that are real graph edges are written, so the colour lists
    # built below stay aligned one-to-one with G.edges.
    if args[0]:
        edge_c[(um, m)] = "lightgrey"
        node_line_c[m] = "red"
    if args[1]:
        edge_c[(us, s)] = "lightgrey"
        node_line_c[s] = "red"
    if args[2]:
        edge_c[(ua, a)] = "lightgrey"
        node_line_c[a] = "red"
    if args[3]:
        edge_c[(ub, b)] = "lightgrey"
        edge_c[(s, b)] = "lightgrey"
        edge_c[(a, b)] = "lightgrey"
        node_line_c[b] = "red"
    if args[4]:
        edge_c[(uv, v)] = "lightgrey"
        edge_c[(a, v)] = "lightgrey"
        # The b -> v arc has no networkx edge; it is simply not drawn below.
        node_line_c[v] = "red"

    fs = 30
    options = {
        "font_size": fs,
        "node_size": 3000,
        # Explicit lookups in G's own order, the order networkx draws in by default.
        "node_color": [node_c[n] for n in G.nodes],
        "edgecolors": [node_line_c[n] for n in G.nodes],
        "edge_color": [edge_c[e] for e in G.edges],
        "linewidths": 2,
        "width": 2,
    }

    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    ax.margins(x=0.1, y=0.08, tight=False)
    ax.axis("off")
    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)

    # Explicit limits (these override the margins above) keep the layout fixed,
    # which the pixel offsets of the hammer icons below depend on.
    x_lim = (-1.32, 1.32)
    y_lim = (-1.414, 2.414)
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    # Box around the z_{1:L} / epsilon nodes. There is no U_x label here; the
    # spot chest_graph uses for it is occupied by U_m in this layout.
    rect = patches.FancyBboxPatch(
        (-0.5, -1.325),
        1,
        0.65,
        boxstyle="round, pad=0.05, rounding_size=0",
        linewidth=2,
        edgecolor="black",
        facecolor="none",
        linestyle="-",
    )
    ax.add_patch(rect)

    # Hammer icons are placed in figure pixels as hard-coded fractions of the
    # figure size. They must be retuned if figsize or pos change.
    if args[0]:
        fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=10)
    if args[1]:
        fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
    if args[2]:
        fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
    if args[3]:
        fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)
    if args[4]:
        fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=14)
    else:
        # Curved b -> v arrow, drawn only while v is not flagged: flagging v
        # removes all of its incoming dependencies.
        a3 = patches.FancyArrowPatch(
            (0.86, 1.21),
            (-0.86, 1.21),
            connectionstyle="arc3,rad=.3",
            linewidth=2,
            arrowstyle="simple, head_width=10, head_length=10",
            color="k",
        )
        ax.add_patch(a3)

    fig.tight_layout()
    # Render now so the renderer buffer holds the finished image.
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
|
|
|
|
|
def chest_graph(*args):
    """Draw the chest causal graph and return it as an RGBA pixel array.

    Observed nodes (grey fill): image ``x`` and parents ``a``, ``d``, ``r``,
    ``s``. Exogenous nodes (white fill): ``U_a``, ``U_d``, ``U_r``, ``U_s`` and
    the image noise ``z_{1:L}`` / ``epsilon``. The two image-noise nodes are
    enclosed in a box labelled ``U_x``.
    NOTE(review): presumably a chest X-ray model with a=age, d=disease, r=race,
    s=sex — confirm.

    Args:
        *args: flags read positionally as ``args[0]`` (r), ``args[1]`` (s),
            ``args[2]`` (d) and ``args[3]`` (a). At least four values must be
            passed. A truthy flag greys out the node's incoming edges, outlines
            the node in red and overlays the HAMMER icon (presumably depicting a
            do()-intervention).

    Returns:
        NumPy array built from ``fig.canvas.renderer.buffer_rgba()`` of the
        rendered figure. Assumes an Agg-based canvas that exposes ``renderer``.

    Side effects: closes every open pyplot figure before drawing. The new
    figure is left open.
    """
    # LaTeX label strings double as node identifiers in the networkx graph.
    x, a, d, r, s = r"$\mathbf{x}$", r"$a$", r"$d$", r"$r$", r"$s$"
    ua, ud, ur, us = (
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_d$",
        r"$\mathbf{U}_r$",
        r"$\mathbf{U}_s$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    G = nx.DiGraph()
    G.add_edge(ua, a)
    G.add_edge(ud, d)
    G.add_edge(ur, r)
    G.add_edge(us, s)
    G.add_edge(a, d)
    G.add_edge(d, x)
    G.add_edge(r, x)
    G.add_edge(s, x)
    G.add_edge(ex, x)
    G.add_edge(zx, x)
    G.add_edge(a, x)

    # Fixed node positions in data coordinates.
    pos = {
        x: (0, 0),
        a: (-1, 1),
        d: (0, 1),
        r: (1, 1),
        s: (1, 0),
        ua: (-1, 2),
        ud: (0, 2),
        ur: (1, 2),
        us: (1, -1),
        zx: (-0.25, -1),
        ex: (0.25, -1),
    }

    # Colour dicts are built by iterating G / G.edges. Their values are later
    # passed as lists, relying on networkx drawing nodes/edges in G's order.
    node_c = {}
    for node in G:
        node_c[node] = "lightgrey" if node in [x, a, d, r, s] else "white"

    edge_c = {e: "black" for e in G.edges}
    node_line_c = {k: "black" for k, _ in node_c.items()}

    # Flagged nodes: grey out every incoming edge and outline the node in red.
    if args[0]:
        edge_c[(ur, r)] = "lightgrey"
        node_line_c[r] = "red"
    if args[1]:
        edge_c[(us, s)] = "lightgrey"
        node_line_c[s] = "red"
    if args[2]:
        edge_c[(ud, d)] = "lightgrey"
        edge_c[(a, d)] = "lightgrey"
        node_line_c[d] = "red"
    if args[3]:
        edge_c[(ua, a)] = "lightgrey"
        node_line_c[a] = "red"

    fs = 30
    options = {
        "font_size": fs,
        "node_size": 3000,
        "node_color": list(node_c.values()),
        "edgecolors": list(node_line_c.values()),
        "edge_color": list(edge_c.values()),
        "linewidths": 2,
        "width": 2,
    }
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    ax.margins(x=0.1, y=0.08, tight=False)
    ax.axis("off")
    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)

    # Explicit limits (these override the margins above) keep the layout fixed,
    # which the pixel offsets of the hammer icons below depend on.
    x_lim = (-1.32, 1.32)
    y_lim = (-1.414, 2.414)
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    # Box around the z_{1:L} / epsilon nodes, labelled U_x on its left.
    rect = patches.FancyBboxPatch(
        (-0.5, -1.325),
        1,
        0.65,
        boxstyle="round, pad=0.05, rounding_size=0",
        linewidth=2,
        edgecolor="black",
        facecolor="none",
        linestyle="-",
    )
    ax.add_patch(rect)
    ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=fs)

    # Hammer icons are placed in figure pixels as hard-coded fractions of the
    # figure size next to r, s, d and a. They must be retuned if figsize or pos change.
    if args[0]:
        fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=10)
    if args[1]:
        fig.figimage(HAMMER, 0.72 * fig.bbox.xmax, 0.395 * fig.bbox.ymax, zorder=11)
    if args[2]:
        fig.figimage(HAMMER, 0.363 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=12)
    if args[3]:
        fig.figimage(HAMMER, 0.0075 * fig.bbox.xmax, 0.64 * fig.bbox.ymax, zorder=13)

    fig.tight_layout()
    # Render now so the renderer buffer holds the finished image.
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())
|
|
|
|
|
def vae_preprocess(args, pa):
    """Turn a dict of parent tensors into a spatially tiled conditioning tensor.

    When ``"ukbb" in args.hps``, every parent except ``mri_seq`` and ``sex``
    is first mapped from [-1, 1] back to its raw range using
    ``get_attr_max_min``. Then ``age``, ``brain_volume`` and
    ``ventricle_volume`` (when present) are log-standardized with fixed
    constants. Afterwards the parents named in ``args.parents_x`` are
    concatenated along dim 1 and tiled over an ``input_res x input_res`` grid.

    The caller's ``pa`` dict is not modified.

    Args:
        args: object providing ``hps``, ``parents_x``, ``input_res`` and ``device``.
        pa: mapping from parent name to tensor. Assumes each value is shaped
            ``(N,)`` or ``(N, c)`` — TODO confirm with callers.

    Returns:
        float32 tensor of shape ``(N, C, input_res, input_res)`` on ``args.device``.
    """
    # Shallow copy: every transform below builds a new tensor and rebinds the
    # key, so copying the dict is enough to leave the caller's mapping untouched.
    pa = dict(pa)

    if "ukbb" in args.hps:
        # Undo the [-1, 1] scaling of continuous parents back to their raw range.
        for k, v in pa.items():
            if k not in ("mri_seq", "sex"):
                _max, _min = get_attr_max_min(k)
                pa[k] = (v + 1) / 2 * (_max - _min) + _min

        # (log(v) - shift) / scale with fixed constants (presumably the dataset
        # statistics the VAE was trained with — confirm).
        log_stats = {
            "age": (4.112339973449707, 0.11769197136163712),
            "brain_volume": (13.965583801269531, 0.09537758678197861),
            "ventricle_volume": (10.345998764038086, 0.43127763271331787),
        }
        for k, (shift, scale) in log_stats.items():
            if k in pa:
                # The clamp keeps log finite for zero or negative values.
                pa[k] = (torch.log(pa[k].clamp(min=1e-12)) - shift) / scale

    # Concatenate the requested parents along dim 1. 1-D tensors get a trailing
    # singleton dim so each contributes one column.
    cond = torch.cat(
        [pa[k] if len(pa[k].shape) > 1 else pa[k][..., None] for k in args.parents_x],
        dim=1,
    )
    # Tile every conditioning value over the input_res x input_res spatial grid.
    return (
        cond[..., None, None].repeat(1, 1, *(args.input_res,) * 2).to(args.device).float()
    )
|
|
|
|
|
def preprocess_brain(args, obs):
    """Rescale a single brain observation dict in place and return the same dict.

    ``obs["x"]`` gains a leading dim, is cast to float on ``args.device`` and is
    mapped through ``(x - 127.5) / 127.5`` (0 -> -1, 255 -> 1). Every other
    entry is cast to float on ``args.device`` and reshaped to ``(1, 1)``.
    ``age``, ``brain_volume`` and ``ventricle_volume`` are additionally
    min-max scaled to [-1, 1] using the ``(max, min)`` pair from
    ``get_attr_max_min``.
    """
    device = args.device
    image = obs["x"][None, ...].float().to(device)
    obs["x"] = (image - 127.5) / 127.5

    minmax_keys = ("age", "brain_volume", "ventricle_volume")
    attr_keys = [key for key in obs if key != "x"]
    for key in attr_keys:
        obs[key] = obs[key].float().to(device).view(1, 1)
        if key not in minmax_keys:
            continue
        hi, lo = get_attr_max_min(key)
        obs[key] = 2 * ((obs[key] - lo) / (hi - lo)) - 1
    return obs
|
|
|
|
|
def get_fig_arr(x, width=4, height=4, dpi=144, cmap="Greys_r", norm=None):
    """Render ``x`` with ``imshow`` on a borderless figure and return its RGBA pixels.

    With the default ``cmap="Greys_r"`` the colour range is fixed to [0, 255] and
    ``norm`` is ignored. For any other ``cmap`` the supplied ``norm`` is used.

    The figure is closed before returning. Previously every call left a figure
    registered with pyplot, so repeated calls accumulated open figures.

    Returns:
        NumPy array copied from ``fig.canvas.renderer.buffer_rgba()``. Assumes
        an Agg-based canvas that exposes ``renderer``.
    """
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    try:
        # One axes covering the whole figure, with no frame or ticks.
        ax = fig.add_axes([0, 0, 1, 1], frameon=False)
        if cmap == "Greys_r":
            ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
        else:
            ax.imshow(x, cmap=cmap, norm=norm)
        ax.axis("off")
        fig.canvas.draw()
        # np.array copies the buffer, so the result stays valid after closing.
        return np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        plt.close(fig)
|
|
|
|
|
def normalize(x, x_min=None, x_max=None, zero_one=False):
    """Min-max rescale ``x`` to [0, 1] (``zero_one=True``) or to [-1, 1] (default).

    ``x_min`` / ``x_max`` default to ``x.min()`` / ``x.max()``. Values outside
    explicitly supplied bounds map outside the target range. Note that a
    constant input with default bounds divides by zero.
    """
    lo = x.min() if x_min is None else x_min
    hi = x.max() if x_max is None else x_max
    unit = (x - lo) / (hi - lo)
    if zero_one:
        return unit
    return 2 * unit - 1
|
|