|
from typing import Dict |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn, Tensor |
|
import torch.nn.functional as F |
|
|
|
import pyro |
|
import pyro.distributions as dist |
|
import pyro.distributions.transforms as T |
|
|
|
from pyro.nn import DenseNN |
|
from pyro.infer.reparam.transform import TransformReparam |
|
from pyro.distributions.conditional import ConditionalTransformedDistribution |
|
|
|
from .layers import ( |
|
ConditionalTransformedDistributionGumbelMax, |
|
ConditionalGumbelMax, |
|
ConditionalAffineTransform, |
|
MLP, |
|
CNN, |
|
) |
|
|
|
|
|
class Hparams:
    """Attribute bag for hyper-parameters.

    An instance starts empty; :meth:`update` copies entries of a mapping onto
    the instance as attributes (e.g. ``args.widths``, ``args.std_fixed``).
    """

    def update(self, dict):
        """Copy every key/value pair of ``dict`` onto ``self`` via ``setattr``.

        Existing attributes with the same name are overwritten.

        Args:
            dict: mapping from attribute name to value. (The name shadows the
                builtin but is kept so keyword callers keep working.)
        """
        for attr_name, attr_value in dict.items():
            setattr(self, attr_name, attr_value)
|
|
|
|
|
def is_one_hot(tensor):
    """
    Check if the given tensor is a valid one-hot tensor.

    A tensor is one-hot when it is 2-D (rows x classes), every entry is
    exactly 0 or 1, and every row contains exactly one 1. Under this
    definition a single-column tensor of ones is a valid one-hot encoding of
    one class, and a tensor with zero rows is (vacuously) one-hot.

    Args:
        tensor (torch.Tensor): A tensor to check.

    Returns:
        bool: True if tensor is one-hot, False otherwise. Always a Python
        ``bool`` (never a 0-d tensor), on every code path.
    """
    if tensor.ndim != 2:
        return False

    # Every entry must be exactly 0 or 1 (this also rejects NaN) ...
    binary = torch.all((tensor == 0) | (tensor == 1))
    # ... and each row must hold exactly one 1. Summing (rather than max/min)
    # also works for zero-column inputs, where max(dim=1) would raise.
    one_per_row = torch.all(tensor.sum(dim=1) == 1)
    return bool(binary) and bool(one_per_row)
|
|
|
|
|
class BasePGM(nn.Module):
    """Shared sampling and counterfactual machinery for the Pyro PGMs.

    Subclasses must provide ``self.model()`` (a Pyro program returning a dict
    of sampled variables) and ``self.variables`` (a dict keyed by the DAG
    variable names). A subclass that also defines ``self.discrete_variables``
    gets the hard-coded ``"finding"`` handling in :meth:`counterfactual`.
    """

    def __init__(self):
        super().__init__()

    def scm(self, *args, **kwargs):
        """Run ``self.model`` with ``TransformReparam`` applied to every site
        whose distribution is a ``dist.TransformedDistribution``.

        The reparameterisation exposes each such site's base noise as a
        separate ``"<name>_base"`` sample site, which :meth:`counterfactual`
        conditions on.
        """

        def config(msg):
            if isinstance(msg["fn"], dist.TransformedDistribution):
                return TransformReparam()
            else:
                return None

        return pyro.poutine.reparam(self.model, config=config)(*args, **kwargs)

    def sample_scm(self, n_samples: int = 1):
        """Draw ``n_samples`` joint samples from the reparameterised SCM."""
        with pyro.plate("obs", n_samples):
            samples = self.scm()
        return samples

    def sample(self, n_samples: int = 1):
        """Draw ``n_samples`` joint samples from ``self.model``."""
        with pyro.plate("obs", n_samples):
            samples = self.model()
        return samples

    def infer_exogeneous(self, obs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Abduction: recover the base noise of every invertible mechanism.

        Traces ``self.sample`` conditioned on ``obs``. For every sample site
        whose distribution is a ``TransformedDistribution`` (possibly wrapped
        in ``dist.Independent``), it inverts the site's transforms on the
        observed value.

        Args:
            obs: observed values keyed by site name. The first dimension of
                the first value is used as the batch size.

        Returns:
            Dict mapping ``"<site>_base"`` to the inferred base noise. Sites
            whose distribution is not a transformed distribution (e.g. the
            Bernoulli / one-hot categorical roots) are absent.
        """
        batch_size = list(obs.values())[0].shape[0]

        cond_model = pyro.condition(self.sample, data=obs)
        cond_trace = pyro.poutine.trace(cond_model).get_trace(batch_size)

        output = {}
        for name, node in cond_trace.nodes.items():
            # NOTE(review): any site whose name contains the letter "z" is
            # skipped, presumably to ignore latent sites; confirm that no
            # observed variable name ever contains "z".
            if "z" in name or "fn" not in node.keys():
                continue
            fn = node["fn"]
            if isinstance(fn, dist.Independent):
                fn = fn.base_dist
            if isinstance(fn, dist.TransformedDistribution):
                # x = T(u)  =>  u = T^{-1}(x) for the whole transform chain.
                output[name + "_base"] = T.ComposeTransform(fn.transforms).inv(
                    node["value"]
                )
        return output

    def counterfactual(
        self,
        obs: Dict[str, Tensor],
        intervention: Dict[str, Tensor],
        num_particles: int = 1,
        detach: bool = True,
    ) -> Dict[str, Tensor]:
        """Compute counterfactuals by abduction, action and prediction.

        Each of ``num_particles`` rounds runs these steps and then averages
        the results:

        1. Abduct exogenous noise from ``obs`` with :meth:`infer_exogeneous`.
        2. Condition the reparameterised SCM on that noise. Non-intervened
           variables with no abducted noise are conditioned on their
           observed value instead.
        3. Apply ``intervention`` with ``pyro.poutine.do`` and sample.

        Args:
            obs: observed values for exactly the variables in
                ``self.variables``.
            intervention: values to fix with ``do``.
            num_particles: number of abduction/prediction rounds averaged.
            detach: if True, the abducted noise is detached from autograd.

        Returns:
            Dict of counterfactual values averaged over particles. Discrete
            variables are averaged too, so with ``num_particles > 1`` they may
            be non-integral. Non-floating observations get a floating
            accumulator (default dtype); floating ones keep their dtype.
        """
        dag_variables = self.variables.keys()
        assert set(obs.keys()) == set(dag_variables), (
            "obs must contain exactly the DAG variables"
        )
        # The running mean is accumulated in place; an integer/bool
        # accumulator would make ``+= v / num_particles`` raise, so use the
        # default floating dtype for non-floating observations.
        avg_cfs = {
            k: torch.zeros_like(
                v, dtype=None if v.is_floating_point() else torch.get_default_dtype()
            )
            for k, v in obs.items()
        }
        batch_size = list(obs.values())[0].shape[0]

        for _ in range(num_particles):
            # Abduction.
            exo_noise = self.infer_exogeneous(obs)
            exo_noise = {k: v.detach() if detach else v for k, v in exo_noise.items()}

            # Variables that already have abducted "<k>_base" noise. Computed
            # once: the keys added below are DAG variable names themselves and
            # are only ever checked once each.
            has_noise = {name.split("_base")[0] for name in exo_noise}
            for k in dag_variables:
                if k not in intervention and k not in has_noise:
                    # No invertible mechanism: hold the observed value fixed.
                    exo_noise[k] = obs[k]

            abducted_scm = pyro.poutine.condition(self.sample_scm, data=exo_noise)
            # Action.
            counterfactual_scm = pyro.poutine.do(abducted_scm, data=intervention)
            # Prediction.
            counterfactuals = counterfactual_scm(batch_size)

            if hasattr(self, "discrete_variables"):
                # NOTE(review): hard-coded for ChestPGM, whose model samples
                # "finding" conditioned on "age". Unless age or finding itself
                # is intervened on, the observed finding is returned as-is.
                if (
                    "age" not in intervention.keys()
                    and "finding" not in intervention.keys()
                ):
                    counterfactuals["finding"] = obs["finding"]

            for k, v in counterfactuals.items():
                avg_cfs[k] += v / num_particles
        return avg_cfs
|
|
|
|
|
class FlowPGM(BasePGM):
    """Flow-based causal PGM over sex, mri_seq, age, brain and ventricle volume.

    Generative structure (from ``model``): sex, mri_seq and age are roots,
    brain_volume depends on (sex, age), and ventricle_volume depends on
    (brain_volume, age). Image encoders for amortised inference / anticausal
    prediction are built in ``__init__`` and used by ``guide``,
    ``model_anticausal`` and ``predict``.
    """

    def __init__(self, args: Hparams):
        super().__init__()
        # Observed DAG variables; BasePGM.counterfactual requires ``obs`` to
        # have exactly these keys.
        self.variables = {
            "sex": "binary",
            "mri_seq": "binary",
            "age": "continuous",
            "brain_volume": "continuous",
            "ventricle_volume": "continuous",
        }

        # Learnable logits of the Bernoulli priors p(sex) and p(mri_seq).
        self.s_logit = nn.Parameter(torch.zeros(1))
        self.m_logit = nn.Parameter(torch.zeros(1))
        # Standard-normal base distributions for age (a), brain volume (b)
        # and ventricle volume (v); buffers, so they follow .to(device).
        for k in ["a", "b", "v"]:
            self.register_buffer(f"{k}_base_loc", torch.zeros(1))
            self.register_buffer(f"{k}_base_scale", torch.ones(1))

        # p(age): base Normal pushed through a learnable element-wise spline
        # (4 bins, linear order).
        self.age_module = T.ComposeTransformModule(
            [T.Spline(1, count_bins=4, order="linear")]
        )
        self.age_flow = T.ComposeTransform([self.age_module])

        # p(brain_volume | sex, age): conditional affine flow; the context
        # net takes the 2-d (sex, age) context and has two 1-d outputs,
        # presumably the affine loc and scale parameters.
        bvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
        self.bvol_flow = ConditionalAffineTransform(context_nn=bvol_net, event_dim=0)

        # p(ventricle_volume | brain_volume, age): same construction with the
        # 2-d (brain_volume, age) context.
        vvol_net = DenseNN(2, args.widths, [1, 1], nonlinearity=nn.LeakyReLU(0.1))
        self.vvol_flow = ConditionalAffineTransform(context_nn=vvol_net, event_dim=0)

        # Anticausal image encoders (x = image of this shape).
        input_shape = (args.input_channels, args.input_res, args.input_res)
        # q(sex | x, brain_volume): 1 logit, 1-d context.
        self.encoder_s = CNN(input_shape, num_outputs=1, context_dim=1)
        # q(mri_seq | x): 1 logit.
        self.encoder_m = CNN(input_shape, num_outputs=1)
        # q(age | brain_volume, ventricle_volume): loc and raw scale.
        self.encoder_a = MLP(num_inputs=2, num_outputs=2)
        # q(brain_volume | x, ventricle_volume): loc and raw scale.
        self.encoder_b = CNN(input_shape, num_outputs=2, context_dim=1)
        # q(ventricle_volume | x): loc and raw scale.
        self.encoder_v = CNN(input_shape, num_outputs=2)
        # Maps an encoder's raw scale output to a positive std: a constant
        # args.std_fixed when positive, otherwise softplus. ``args`` is read
        # at call time, not captured once.
        self.f = (
            lambda x: args.std_fixed * torch.ones_like(x)
            if args.std_fixed > 0
            else F.softplus(x)
        )

    def model(self) -> Dict[str, Tensor]:
        """Generative model; returns a dict of samples keyed by variable name.

        NOTE(review): unlike MorphoMNISTPGM.model and ChestPGM.model, this
        does not call ``pyro.module``; confirm the training loop registers
        or optimises these parameters by other means.
        """
        # p(sex)
        ps = dist.Bernoulli(logits=self.s_logit).to_event(1)
        sex = pyro.sample("sex", ps)

        # p(mri_seq)
        pm = dist.Bernoulli(logits=self.m_logit).to_event(1)
        mri_seq = pyro.sample("mri_seq", pm)

        # p(age): spline flow on a standard normal.
        pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
        pa = dist.TransformedDistribution(pa_base, self.age_flow)
        age = pyro.sample("age", pa)

        # p(brain_volume | sex, age). The dim=1 concat assumes sex and age
        # are 2-D (batch, 1) under a plate; TODO confirm for unbatched use.
        pb_sa_base = dist.Normal(self.b_base_loc, self.b_base_scale).to_event(1)
        pb_sa = ConditionalTransformedDistribution(
            pb_sa_base, [self.bvol_flow]
        ).condition(torch.cat([sex, age], dim=1))
        bvol = pyro.sample("brain_volume", pb_sa)

        # p(ventricle_volume | brain_volume, age)
        pv_ba_base = dist.Normal(self.v_base_loc, self.v_base_scale).to_event(1)
        pv_ba = ConditionalTransformedDistribution(
            pv_ba_base, [self.vvol_flow]
        ).condition(torch.cat([bvol, age], dim=1))
        vvol = pyro.sample("ventricle_volume", pv_ba)

        return {
            "sex": sex,
            "mri_seq": mri_seq,
            "age": age,
            "brain_volume": bvol,
            "ventricle_volume": vvol,
        }

    def guide(self, **obs) -> None:
        """Amortised variational posterior for variables passed as ``None``.

        ``obs`` must contain ``"x"`` and every variable key (missing keys
        raise KeyError). Sampled ventricle/brain volumes are written back
        into ``obs`` so that later encoders can condition on them.
        """
        pyro.module("FlowPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(mri_seq | x)
            if obs["mri_seq"] is None:
                m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
                # NOTE(review): ``m`` is never used.
                m = pyro.sample("mri_seq", dist.Bernoulli(probs=m_prob).to_event(1))

            # q(ventricle_volume | x)
            if obs["ventricle_volume"] is None:
                v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
                qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
                obs["ventricle_volume"] = pyro.sample("ventricle_volume", qv_x)

            # q(brain_volume | x, ventricle_volume)
            if obs["brain_volume"] is None:
                b_loc, b_logscale = self.encoder_b(
                    obs["x"], y=obs["ventricle_volume"]
                ).chunk(2, dim=-1)
                qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
                obs["brain_volume"] = pyro.sample("brain_volume", qb_xv)

            # q(sex | x, brain_volume)
            if obs["sex"] is None:
                s_prob = torch.sigmoid(
                    self.encoder_s(obs["x"], y=obs["brain_volume"])
                )
                pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))

            # q(age | brain_volume, ventricle_volume)
            if obs["age"] is None:
                ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
                a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
                pyro.sample("age", dist.Normal(a_loc, self.f(a_logscale)).to_event(1))

    def model_anticausal(self, **obs) -> None:
        """Score observed values under the anticausal encoders.

        Each encoder distribution is observed at an ``"<name>_aux"`` site,
        so a trace of this model contains the encoders' log-likelihoods of
        the labels in ``obs``. All keys (and ``"x"``) must be present.
        """
        pyro.module("FlowPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(ventricle_volume | x)
            v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)
            qv_x = dist.Normal(v_loc, self.f(v_logscale)).to_event(1)
            pyro.sample("ventricle_volume_aux", qv_x, obs=obs["ventricle_volume"])

            # q(brain_volume | x, ventricle_volume)
            b_loc, b_logscale = self.encoder_b(
                obs["x"], y=obs["ventricle_volume"]
            ).chunk(2, dim=-1)
            qb_xv = dist.Normal(b_loc, self.f(b_logscale)).to_event(1)
            pyro.sample("brain_volume_aux", qb_xv, obs=obs["brain_volume"])

            # q(age | brain_volume, ventricle_volume)
            ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
            a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)
            pyro.sample(
                "age_aux",
                dist.Normal(a_loc, self.f(a_logscale)).to_event(1),
                obs=obs["age"],
            )

            # q(sex | x, brain_volume)
            s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))
            qs_xb = dist.Bernoulli(probs=s_prob).to_event(1)
            pyro.sample("sex_aux", qs_xb, obs=obs["sex"])

            # q(mri_seq | x)
            m_prob = torch.sigmoid(self.encoder_m(obs["x"]))
            qm_x = dist.Bernoulli(probs=m_prob).to_event(1)
            pyro.sample("mri_seq_aux", qm_x, obs=obs["mri_seq"])

    def predict(self, **obs) -> Dict[str, Tensor]:
        """Point predictions from the anticausal encoders.

        Returns probabilities for the binary variables and predicted means
        for the continuous ones. Encoders that take a context use the
        *observed* values from ``obs``, not other predictions.
        """
        # q(ventricle_volume | x)
        v_loc, v_logscale = self.encoder_v(obs["x"]).chunk(2, dim=-1)

        # q(brain_volume | x, ventricle_volume)
        b_loc, b_logscale = self.encoder_b(obs["x"], y=obs["ventricle_volume"]).chunk(
            2, dim=-1
        )

        # q(age | brain_volume, ventricle_volume)
        ctx = torch.cat([obs["brain_volume"], obs["ventricle_volume"]], dim=-1)
        a_loc, a_logscale = self.encoder_a(ctx).chunk(2, dim=-1)

        # q(sex | x, brain_volume)
        s_prob = torch.sigmoid(self.encoder_s(obs["x"], y=obs["brain_volume"]))

        # q(mri_seq | x)
        m_prob = torch.sigmoid(self.encoder_m(obs["x"]))

        return {
            "sex": s_prob,
            "mri_seq": m_prob,
            "age": a_loc,
            "brain_volume": b_loc,
            "ventricle_volume": v_loc,
        }

    def svi_model(self, **obs) -> None:
        """Run ``self.model`` conditioned on ``obs`` inside a batch plate
        sized by ``obs["x"]``."""
        with pyro.plate("observations", obs["x"].shape[0]):
            pyro.condition(self.model, data=obs)()

    def guide_pass(self, **obs) -> None:
        """Empty guide, presumably for fully observed training with
        ``svi_model``."""
        pass
|
|
|
|
|
class MorphoMNISTPGM(BasePGM):
    """Causal PGM over Morpho-MNIST digit, stroke thickness and intensity.

    Generative structure (from ``model``): digit and thickness are roots,
    and intensity depends on thickness. Thickness and intensity are pushed
    through a sigmoid + affine map, so they lie in (-1, 1). Image encoders
    are only built when ``args.setup != "sup_pgm"``.
    """

    def __init__(self, args):
        super().__init__()
        # Observed DAG variables; BasePGM.counterfactual requires ``obs`` to
        # have exactly these keys.
        self.variables = {
            "thickness": "continuous",
            "intensity": "continuous",
            "digit": "categorical",
        }

        # Learnable logits of the 10-class digit prior.
        self.digit_logits = nn.Parameter(torch.zeros(1, 10))
        # Standard-normal base distributions for thickness (t) and
        # intensity (i).
        for k in ["t", "i"]:
            self.register_buffer(f"{k}_base_loc", torch.zeros(1))
            self.register_buffer(f"{k}_base_scale", torch.ones(1))

        # Squash to (-1, 1): sigmoid to (0, 1), then 2 * u - 1.
        normalize_transform = T.ComposeTransform(
            [T.SigmoidTransform(), T.AffineTransform(loc=-1, scale=2)]
        )

        # p(thickness): learnable spline (4 bins, linear order), then squash.
        self.thickness_module = T.ComposeTransformModule(
            [T.Spline(1, count_bins=4, order="linear")]
        )
        self.thickness_flow = T.ComposeTransform(
            [self.thickness_module, normalize_transform]
        )

        # p(intensity | thickness): conditional affine flow with a 1-d
        # (thickness) context, then squash. ``context_nn`` is stored as an
        # attribute so its parameters are registered on this module;
        # ``intensity_flow`` itself is a plain list.
        intensity_net = DenseNN(1, args.widths, [1, 1], nonlinearity=nn.GELU())
        self.context_nn = ConditionalAffineTransform(
            context_nn=intensity_net, event_dim=0
        )
        self.intensity_flow = [self.context_nn, normalize_transform]

        if args.setup != "sup_pgm":
            # Anticausal image encoders (x = image of this shape).
            input_shape = (args.input_channels, args.input_res, args.input_res)
            # q(thickness | x, intensity): loc and raw scale, 1-d context.
            self.encoder_t = CNN(input_shape, num_outputs=2, context_dim=1, width=8)
            # q(intensity | x): loc and raw scale.
            self.encoder_i = CNN(input_shape, num_outputs=2, width=8)
            # q(digit | x): 10 logits.
            self.encoder_y = CNN(input_shape, num_outputs=10, width=8)
            # Maps an encoder's raw scale output to a positive std: constant
            # args.std_fixed when positive, otherwise softplus. ``args`` is
            # read at call time.
            self.f = (
                lambda x: args.std_fixed * torch.ones_like(x)
                if args.std_fixed > 0
                else F.softplus(x)
            )

    def model(self) -> Dict[str, Tensor]:
        """Generative model; returns a dict of samples keyed by variable name."""
        pyro.module("MorphoMNISTPGM", self)

        # p(digit)
        py = dist.OneHotCategorical(
            probs=F.softmax(self.digit_logits, dim=-1)
        )

        digit = pyro.sample("digit", py)

        # p(thickness)
        pt_base = dist.Normal(self.t_base_loc, self.t_base_scale).to_event(1)
        pt = dist.TransformedDistribution(pt_base, self.thickness_flow)
        thickness = pyro.sample("thickness", pt)

        # p(intensity | thickness)
        pi_t_base = dist.Normal(self.i_base_loc, self.i_base_scale).to_event(1)
        pi_t = ConditionalTransformedDistribution(
            pi_t_base, self.intensity_flow
        ).condition(thickness)
        intensity = pyro.sample("intensity", pi_t)
        # NOTE(review): no-op attribute access with no runtime effect; the
        # module is already registered through ``self.context_nn``.
        _ = self.context_nn

        return {"thickness": thickness, "intensity": intensity, "digit": digit}

    def guide(self, **obs) -> None:
        """Amortised variational posterior for variables passed as ``None``.

        ``obs`` must contain ``"x"`` and every variable key. Sampled
        intensity/thickness are written back into ``obs``. Means go through
        ``tanh`` to match the (-1, 1) range of the model variables.
        NOTE(review): unlike FlowPGM.guide, this does not call
        ``pyro.module``.
        """
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(intensity | x); despite the name, ``qi_t`` depends only on x.
            if obs["intensity"] is None:
                i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
                qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
                obs["intensity"] = pyro.sample("intensity", qi_t)

            # q(thickness | x, intensity)
            if obs["thickness"] is None:
                t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
                    2, dim=-1
                )
                qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
                obs["thickness"] = pyro.sample("thickness", qt_x)

            # q(digit | x)
            if obs["digit"] is None:
                y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
                qy_x = dist.OneHotCategorical(probs=y_prob)
                pyro.sample("digit", qy_x)

    def model_anticausal(self, **obs) -> None:
        """Score observed values under the anticausal encoders at
        ``"<name>_aux"`` sites. All keys (and ``"x"``) must be present."""
        pyro.module("MorphoMNISTPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(thickness | x, intensity)
            t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
                2, dim=-1
            )
            qt_x = dist.Normal(torch.tanh(t_loc), self.f(t_logscale)).to_event(1)
            pyro.sample("thickness_aux", qt_x, obs=obs["thickness"])

            # q(intensity | x)
            i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
            qi_t = dist.Normal(torch.tanh(i_loc), self.f(i_logscale)).to_event(1)
            pyro.sample("intensity_aux", qi_t, obs=obs["intensity"])

            # q(digit | x)
            y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
            qy_x = dist.OneHotCategorical(probs=y_prob)
            pyro.sample("digit_aux", qy_x, obs=obs["digit"])

    def predict(self, **obs) -> Dict[str, Tensor]:
        """Point predictions from the anticausal encoders.

        Returns tanh-squashed means for thickness and intensity and class
        probabilities for digit. The thickness encoder conditions on the
        *observed* intensity from ``obs``.
        """
        # q(thickness | x, intensity)
        t_loc, t_logscale = self.encoder_t(obs["x"], y=obs["intensity"]).chunk(
            2, dim=-1
        )
        t_loc = torch.tanh(t_loc)
        # q(intensity | x)
        i_loc, i_logscale = self.encoder_i(obs["x"]).chunk(2, dim=-1)
        i_loc = torch.tanh(i_loc)
        # q(digit | x)
        y_prob = F.softmax(self.encoder_y(obs["x"]), dim=-1)
        return {"thickness": t_loc, "intensity": i_loc, "digit": y_prob}

    def svi_model(self, **obs) -> None:
        """Run ``self.model`` conditioned on ``obs`` inside a batch plate
        sized by ``obs["x"]``."""
        with pyro.plate("observations", obs["x"].shape[0]):
            pyro.condition(self.model, data=obs)()

    def guide_pass(self, **obs) -> None:
        """Empty guide, presumably for fully observed training with
        ``svi_model``."""
        pass
|
|
|
|
|
class ChestPGM(BasePGM):
    """Causal PGM over chest-X-ray attributes: sex, age, race and finding.

    Generative structure (from ``model``): sex, age and race are sampled
    independently, and finding is sampled conditionally on age through a
    Gumbel-max conditional transform. Image encoders are only built when
    ``args.setup != "sup_pgm"``.
    """

    def __init__(self, args: Hparams):
        super().__init__()
        # Observed DAG variables; BasePGM.counterfactual requires ``obs`` to
        # have exactly these keys.
        self.variables = {
            "race": "categorical",
            "sex": "binary",
            "finding": "categorical",
            "age": "continuous",
        }

        # Its presence enables the "finding" special case in
        # BasePGM.counterfactual.
        self.discrete_variables = {"finding": "categorical"}

        # Standard-normal base distributions. Only the "a" (age) pair is used
        # by ``model``; the "f" pair is kept so existing state_dicts load.
        for k in ["a", "f"]:
            self.register_buffer(f"{k}_base_loc", torch.zeros(1))
            self.register_buffer(f"{k}_base_scale", torch.ones(1))

        # p(age): base Normal pushed through a learnable element-wise spline.
        self.age_flow_components = T.ComposeTransformModule([T.Spline(1)])
        self.age_flow = T.ComposeTransform(
            [
                self.age_flow_components,
            ]
        )

        # Context net for p(finding | age): 1-d age context -> 3 values.
        # The softmax dim is explicit: the implicit choice is deprecated and
        # picks dim 0 for 3-D inputs. For the 2-D inputs used here, dim=-1 is
        # the same as the implicit dim=1.
        finding_net = DenseNN(
            1, [8, 16], param_dims=[3], nonlinearity=nn.Softmax(dim=-1)
        )
        self.finding_transform_GumbelMax = ConditionalGumbelMax(
            context_nn=finding_net, event_dim=0
        )

        # NOTE(review): Bernoulli(logits=log(1/2)) starts at p = 1/3, not 1/2
        # (FlowPGM starts its sex logit at 0); confirm the intended init.
        self.sex_logit = nn.Parameter(np.log(1 / 2) * torch.ones(1))
        # Equal logits -> uniform initial priors over the 3 classes.
        self.race_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))
        self.finding_logits = nn.Parameter(np.log(1 / 3) * torch.ones(1, 3))

        if args.setup != "sup_pgm":
            # Imported here, presumably so the "sup_pgm" setup does not need
            # the resnet module.
            from resnet import CustomBlock, ResNet, ResNet18

            shared_model = ResNet(
                CustomBlock,
                layers=[2, 2, 2, 2],
                widths=[64, 128, 256, 512],
                norm_layer=lambda c: nn.GroupNorm(min(32, c // 4), c),
            )
            # Replace the stem so it accepts args.input_channels channels.
            shared_model.conv1 = nn.Conv2d(
                args.input_channels,
                64,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False,
            )
            # The same ``shared_model`` instance is passed to every encoder;
            # whether weights end up shared depends on ResNet18 (TODO confirm).
            kwargs = {
                "in_shape": (args.input_channels, *(args.input_res,) * 2),
                "base_model": shared_model,
            }

            # q(sex | x): 1 logit.
            self.encoder_s = ResNet18(num_outputs=1, **kwargs)
            # q(race | x): 3 logits.
            self.encoder_r = ResNet18(num_outputs=3, **kwargs)
            # q(finding | x): 3 logits.
            self.encoder_f = ResNet18(num_outputs=3, **kwargs)
            # q(age | x, finding): loc and raw scale, 3-d (one-hot) context.
            self.encoder_a = ResNet18(num_outputs=2, context_dim=3, **kwargs)
            # Maps a raw scale output to a positive std: constant
            # args.std_fixed when positive, otherwise softplus. ``args`` is
            # read at call time.
            self.f = (
                lambda x: args.std_fixed * torch.ones_like(x)
                if args.std_fixed > 0
                else F.softplus(x)
            )

    def model(self) -> Dict[str, Tensor]:
        """Generative model; returns a dict of samples keyed by variable name."""
        pyro.module("ChestPGM", self)

        # p(sex)
        ps = dist.Bernoulli(logits=self.sex_logit).to_event(1)
        sex = pyro.sample("sex", ps)

        # p(age): spline flow on a standard normal.
        pa_base = dist.Normal(self.a_base_loc, self.a_base_scale).to_event(1)
        pa = dist.TransformedDistribution(pa_base, self.age_flow)
        age = pyro.sample("age", pa)
        # NOTE(review): no-op attribute access with no runtime effect.
        _ = self.age_flow_components

        # p(race)
        pr = dist.OneHotCategorical(logits=self.race_logits)
        race = pyro.sample("race", pr)

        # p(finding | age): one-hot categorical base transformed by the
        # age-conditioned Gumbel-max transform.
        finding_dist_base = dist.OneHotCategorical(logits=self.finding_logits)
        finding_dist = ConditionalTransformedDistributionGumbelMax(
            finding_dist_base, [self.finding_transform_GumbelMax]
        ).condition(age)
        finding = pyro.sample("finding", finding_dist)

        return {
            "sex": sex,
            "race": race,
            "age": age,
            "finding": finding,
        }

    def guide(self, **obs) -> None:
        """Amortised variational posterior for variables passed as ``None``.

        ``obs`` must contain ``"x"`` and every variable key. Each variable
        whose value is ``None`` is sampled at a site with the same name as
        in ``model``. A sampled finding is written back into ``obs`` so the
        age encoder can condition on it.
        """
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(sex | x)
            if obs["sex"] is None:
                s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
                pyro.sample("sex", dist.Bernoulli(probs=s_prob).to_event(1))

            # q(race | x)
            if obs["race"] is None:
                r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
                qr_x = dist.OneHotCategorical(probs=r_probs)
                pyro.sample("race", qr_x)

            # q(finding | x). OneHotCategorical already has a (3,) event; an
            # extra ``.to_event(1)`` would fold the plate's batch dimension
            # into the event and break the plate's batch shape.
            if obs["finding"] is None:
                f_prob = F.softmax(self.encoder_f(obs["x"]), dim=-1)
                qf_x = dist.OneHotCategorical(probs=f_prob)
                obs["finding"] = pyro.sample("finding", qf_x)

            # q(age | x, finding). The site is named "age" so that it pairs
            # with the model's latent "age" site.
            if obs["age"] is None:
                a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
                    2, dim=-1
                )
                qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
                pyro.sample("age", qa_xf)

    def model_anticausal(self, **obs) -> None:
        """Score observed values under the anticausal encoders at
        ``"<name>_aux"`` sites. All keys (and ``"x"``) must be present."""
        pyro.module("ChestPGM", self)
        with pyro.plate("observations", obs["x"].shape[0]):
            # q(sex | x)
            s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
            qs_x = dist.Bernoulli(probs=s_prob).to_event(1)
            pyro.sample("sex_aux", qs_x, obs=obs["sex"])

            # q(race | x)
            r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
            qr_x = dist.OneHotCategorical(probs=r_probs)
            pyro.sample("race_aux", qr_x, obs=obs["race"])

            # q(finding | x)
            f_probs = F.softmax(self.encoder_f(obs["x"]), dim=-1)
            qf_x = dist.OneHotCategorical(probs=f_probs)
            pyro.sample("finding_aux", qf_x, obs=obs["finding"])

            # q(age | x, finding)
            a_loc, a_logscale = self.encoder_a(obs["x"], y=obs["finding"]).chunk(
                2, dim=-1
            )
            qa_xf = dist.Normal(a_loc, self.f(a_logscale)).to_event(1)
            pyro.sample("age_aux", qa_xf, obs=obs["age"])

    def predict(self, **obs) -> Dict[str, Tensor]:
        """Point predictions from the anticausal encoders.

        Returns probabilities for sex, race and finding and the predicted
        mean for age. The age encoder conditions on the *observed* finding
        in ``obs``.
        """
        s_prob = torch.sigmoid(self.encoder_s(obs["x"]))
        r_probs = F.softmax(self.encoder_r(obs["x"]), dim=-1)
        f_prob = F.softmax(self.encoder_f(obs["x"]), dim=-1)
        a_loc, _ = self.encoder_a(obs["x"], y=obs["finding"]).chunk(2, dim=-1)
        return {
            "sex": s_prob,
            "race": r_probs,
            "finding": f_prob,
            "age": a_loc,
        }

    def svi_model(self, **obs) -> None:
        """Run ``self.model`` conditioned on ``obs`` inside a batch plate
        sized by ``obs["x"]``."""
        with pyro.plate("observations", obs["x"].shape[0]):
            pyro.condition(self.model, data=obs)()

    def guide_pass(self, **obs) -> None:
        """Empty guide, presumably for fully observed training with
        ``svi_model``."""
        pass
|
|