import os
import gzip
import struct
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as TF
from PIL import Image
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm


def log_standardize(x):
    # Standardize in log-space; the clamps guard against log(0) and
    # division by a zero standard deviation.
    log_x = torch.log(x.clamp(min=1e-12))
    return (log_x - log_x.mean()) / log_x.std().clamp(min=1e-12)


def normalize(x, x_min=None, x_max=None, zero_one=False):
    # Min-max scale x to [0, 1]; the default output range is [-1, 1].
    if x_min is None:
        x_min = x.min()
    if x_max is None:
        x_max = x.max()
    print(f"max: {x_max}, min: {x_min}")
    x = (x - x_min) / (x_max - x_min)
    return x if zero_one else 2 * x - 1
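

# A minimal usage sketch of the two normalizers above (hypothetical values,
# not taken from any dataset):
#
#   age = torch.tensor([44.0, 60.0, 73.0])
#   normalize(age)                 # maps to [-1, 1] using the tensor's own min/max
#   normalize(age, zero_one=True)  # maps to [0, 1]
#   log_standardize(age)           # zero-mean, unit-std in log-space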


class UKBBDataset(Dataset):
    def __init__(
        self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True
    ):
        super().__init__()
        self.root = root
        self.transform = transform
        self.concat_pa = concat_pa

        print(f"\nLoading csv data: {csv_file}")
        self.df = pd.read_csv(csv_file)
        self.columns = columns
        if self.columns is None:
            # Use all csv columns except the first (the index column).
            self.columns = list(self.df.columns)
            self.columns.pop(0)
        print(f"columns: {self.columns}")
        self.samples = {i: torch.as_tensor(self.df[i]).float() for i in self.columns}

        for k in ["age", "brain_volume", "ventricle_volume"]:
            if k in self.columns:
                print(f"{k} normalization: {norm}")
                if norm == "[-1,1]":
                    self.samples[k] = normalize(self.samples[k])
                elif norm == "[0,1]":
                    self.samples[k] = normalize(self.samples[k], zero_one=True)
                elif norm == "log_standard":
                    self.samples[k] = log_standardize(self.samples[k])
                elif norm is None:
                    pass
                else:
                    raise NotImplementedError(f"{norm} not implemented.")
        print(f"#samples: {len(self.df)}")
        self.return_x = "eid" in self.columns

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = {k: v[idx] for k, v in self.samples.items()}

        if self.return_x:
            mri_seq = "T1" if sample["mri_seq"] == 0.0 else "T2_FLAIR"
            # Reconstruct the image filename from the subject id and MRI sequence.
            filename = (
                f'{int(sample["eid"])}_' + mri_seq + "_unbiased_brain_rigid_to_mni.png"
            )
            x = Image.open(os.path.join(self.root, "thumbs_192x192", filename))

            if self.transform is not None:
                sample["x"] = self.transform(x)
            sample.pop("eid", None)

        if self.concat_pa:
            sample["pa"] = torch.cat(
                [torch.tensor([sample[k]]) for k in self.columns if k != "eid"], dim=0
            )

        return sample


def get_attr_max_min(attr):
    # Dataset-wide (max, min) values used for fixed-range normalization.
    if attr == "age":
        return 73, 44
    elif attr == "brain_volume":
        return 1629520, 841919
    elif attr == "ventricle_volume":
        return 157075, 7613.27001953125
    else:
        raise NotImplementedError(f"{attr} not implemented.")


def ukbb(args):
    csv_dir = args.data_dir
    augmentation = {
        "train": TF.Compose(
            [
                TF.Resize((args.input_res, args.input_res), antialias=None),
                TF.RandomCrop(
                    size=(args.input_res, args.input_res),
                    padding=[2 * args.pad, args.pad],
                ),
                TF.RandomHorizontalFlip(p=args.hflip),
                TF.PILToTensor(),
            ]
        ),
        "eval": TF.Compose(
            [
                TF.Resize((args.input_res, args.input_res), antialias=None),
                TF.PILToTensor(),
            ]
        ),
    }

    datasets = {}
    for split in ["test"]:
        datasets[split] = UKBBDataset(
            root=args.data_dir,
            csv_file=os.path.join(csv_dir, split + ".csv"),
            transform=augmentation[("eval" if split != "train" else split)],
            columns=(None if not args.parents_x else ["eid"] + args.parents_x),
            norm=(None if not hasattr(args, "context_norm") else args.context_norm),
            concat_pa=False,
        )
    return datasets
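

# A rough usage sketch for ukbb() (assumed attribute names; any argparse
# namespace with these fields would do):
#
#   from argparse import Namespace
#   args = Namespace(
#       data_dir="path/to/ukbb",         # hypothetical path containing test.csv
#       input_res=192,
#       pad=4,
#       hflip=0.5,
#       parents_x=["mri_seq", "age"],    # illustrative parent attributes
#       context_norm="[-1,1]",
#   )
#   datasets = ukbb(args)
#   sample = datasets["test"][0]         # dict with "x" and the parent values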


def _load_uint8(f):
    # IDX header: two zero bytes, then a dtype code and the number of dimensions.
    idx_dtype, ndim = struct.unpack("BBBB", f.read(4))[2:]
    # Each dimension size is stored as a big-endian uint32.
    shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
    buffer_length = int(np.prod(shape))
    data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
    return data


def load_idx(path: str) -> np.ndarray:
    """Reads an array in IDX format from disk.

    Parameters
    ----------
    path : str
        Path of the input file. Will uncompress with `gzip` if path ends in '.gz'.

    Returns
    -------
    np.ndarray
        Output array of dtype ``uint8``.

    References
    ----------
    http://yann.lecun.com/exdb/mnist/
    """
    open_fcn = gzip.open if path.endswith(".gz") else open
    with open_fcn(path, "rb") as f:
        return _load_uint8(f)
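

# Example (hypothetical paths, following the standard MNIST naming scheme):
#
#   images = load_idx("data/train-images-idx3-ubyte.gz")  # shape (60000, 28, 28)
#   labels = load_idx("data/train-labels-idx1-ubyte.gz")  # shape (60000,)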


def _get_paths(root_dir, train):
    prefix = "train" if train else "t10k"
    images_filename = prefix + "-images-idx3-ubyte.gz"
    labels_filename = prefix + "-labels-idx1-ubyte.gz"
    metrics_filename = prefix + "-morpho.csv"
    images_path = os.path.join(root_dir, images_filename)
    labels_path = os.path.join(root_dir, labels_filename)
    metrics_path = os.path.join(root_dir, metrics_filename)
    return images_path, labels_path, metrics_path


def load_morphomnist_like(
    root_dir, train: bool = True, columns=None
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Args:
        root_dir: path to data directory
        train: whether to load the training subset (``True``, ``'train-*'`` files) or
            the test subset (``False``, ``'t10k-*'`` files)
        columns: list of morphometrics to load; by default (``None``) loads the image
            index and all available metrics: area, length, thickness, slant, width,
            and height
    Returns:
        images, labels, metrics
    """
    images_path, labels_path, metrics_path = _get_paths(root_dir, train)
    images = load_idx(images_path)
    labels = load_idx(labels_path)

    if columns is not None and "index" not in columns:
        usecols = ["index"] + list(columns)
    else:
        usecols = columns
    metrics = pd.read_csv(metrics_path, usecols=usecols, index_col="index")
    return images, labels, metrics
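

# Example (hypothetical directory containing the 'train-*'/'t10k-*' files above):
#
#   images, labels, metrics = load_morphomnist_like(
#       "data/morphomnist", train=False, columns=["thickness", "intensity"]
#   )
#   # images: uint8 array (N, 28, 28); labels: uint8 array (N,);
#   # metrics: DataFrame indexed by image index with the requested columns.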


class MorphoMNIST(Dataset):
    def __init__(
        self,
        root_dir,
        train=True,
        transform=None,
        columns=None,
        norm=None,
        concat_pa=True,
    ):
        self.train = train
        self.transform = transform
        self.columns = columns
        self.concat_pa = concat_pa
        self.norm = norm

        # "digit" comes from the label file, not the morphometrics csv.
        cols_not_digit = (
            None if self.columns is None
            else [c for c in self.columns if c != "digit"]
        )
        images, labels, metrics_df = load_morphomnist_like(
            root_dir, train, cols_not_digit
        )
        if self.columns is None:
            self.columns = list(metrics_df.columns)
            cols_not_digit = list(self.columns)
        # Copy via np.array so the buffer is writable for torch.from_numpy.
        self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
        self.labels = F.one_hot(
            torch.from_numpy(np.array(labels)).long(), num_classes=10
        )
        self.samples = {
            k: torch.as_tensor(metrics_df[k].values) for k in cols_not_digit
        }

        # Fixed dataset-wide ranges for min-max normalization.
        self.min_max = {
            "thickness": [0.87598526, 6.255515],
            "intensity": [66.601204, 254.90317],
        }

        for k, v in self.samples.items():
            print(f"{k} normalization: {norm}")
            if norm == "[-1,1]":
                self.samples[k] = normalize(
                    v, x_min=self.min_max[k][0], x_max=self.min_max[k][1]
                )
            elif norm == "[0,1]":
                self.samples[k] = normalize(
                    v, x_min=self.min_max[k][0], x_max=self.min_max[k][1], zero_one=True
                )
            elif norm is None:
                pass
            else:
                raise NotImplementedError(f"{norm} not implemented.")
        print(f"#samples: {len(metrics_df)}\n")

        self.samples.update({"digit": self.labels})

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = {}
        sample["x"] = self.images[idx]

        if self.transform is not None:
            sample["x"] = self.transform(sample["x"])

        if self.concat_pa:
            # "digit" is already one-hot; scalar metrics are wrapped to 1-D first.
            sample["pa"] = torch.cat(
                [
                    v[idx] if k == "digit" else torch.tensor([v[idx]])
                    for k, v in self.samples.items()
                ],
                dim=0,
            )
        else:
            sample.update({k: v[idx] for k, v in self.samples.items()})
        return sample


def morphomnist(args):
    augmentation = {
        "train": TF.Compose(
            [
                TF.RandomCrop((args.input_res, args.input_res), padding=args.pad),
            ]
        ),
        "eval": TF.Compose(
            [
                TF.Pad(padding=2),
            ]
        ),
    }

    datasets = {}
    for split in ["test"]:
        datasets[split] = MorphoMNIST(
            root_dir=args.data_dir,
            train=(split == "train"),
            transform=augmentation[("eval" if split != "train" else split)],
            columns=args.parents_x,
            norm=args.context_norm,
            concat_pa=False,
        )
    return datasets


def preproc_mimic(batch):
    for k, v in batch.items():
        if k == "x":
            # Map uint8 pixel values in [0, 255] to [-1, 1].
            batch["x"] = (batch["x"].float() - 127.5) / 127.5
        elif k in ["age"]:
            # Scale age (assumed < 100 years) to [-1, 1].
            batch[k] = batch[k].float().unsqueeze(-1)
            batch[k] = batch[k] / 100.0
            batch[k] = batch[k] * 2 - 1
        elif k in ["race"]:
            batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
        elif k in ["finding"]:
            batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
        else:
            batch[k] = batch[k].float().unsqueeze(-1)
    return batch
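

# A minimal sketch of what preproc_mimic expects (hypothetical values):
#
#   batch = {
#       "x": torch.randint(0, 256, (1, 1, 224, 224)),  # uint8-range image
#       "age": torch.tensor([63]),
#       "race": torch.tensor([1]),
#       "finding": torch.tensor([2]),
#       "sex": torch.tensor([0]),
#   }
#   batch = preproc_mimic(batch)  # "x" in [-1, 1]; "race"/"finding" one-hot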


class MIMICDataset(Dataset):
    def __init__(
        self,
        root,
        csv_file,
        transform=None,
        columns=None,
        concat_pa=True,
        only_pleural_eff=True,
    ):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.disease_labels = [
            "No Finding",
            "Pleural Effusion",
            "Pneumonia",
        ]
        self.samples = {
            "age": [],
            "sex": [],
            "finding": [],
            "x": [],
            "race": [],
        }

        for idx in tqdm(range(len(self.data)), desc="Loading MIMIC Data"):
            if only_pleural_eff and self.data.loc[idx, "disease"] == "Other":
                continue
            img_path = os.path.join(root, self.data.loc[idx, "path_preproc"])

            disease = self.data.loc[idx, "disease"]
            if disease == "No Finding":
                finding = 0
            elif disease == "Pleural Effusion":
                finding = 1
            elif disease == "Pneumonia":
                finding = 2
            else:
                finding = 0  # any other label falls back to class 0

            self.samples["x"].append(img_path)
            self.samples["finding"].append(finding)
            self.samples["age"].append(self.data.loc[idx, "age"])
            self.samples["race"].append(self.data.loc[idx, "race_label"])
            self.samples["sex"].append(self.data.loc[idx, "sex_label"])

        self.columns = columns
        if self.columns is None:
            # Use all csv columns except the first (the index column).
            self.columns = list(self.data.columns)
            self.columns.pop(0)
        self.concat_pa = concat_pa

    def __len__(self):
        return len(self.samples["x"])

    def __getitem__(self, idx):
        sample = {k: v[idx] for k, v in self.samples.items()}
        # Load the image lazily and add a channel dimension: (1, H, W).
        sample["x"] = imread(sample["x"]).astype(np.float32)[None, ...]

        for k, v in sample.items():
            sample[k] = torch.tensor(v)

        if self.transform:
            sample["x"] = self.transform(sample["x"])

        sample = preproc_mimic(sample)
        if self.concat_pa:
            sample["pa"] = torch.cat([sample[k] for k in self.columns], dim=0)
        return sample


def mimic(args):
    args.csv_dir = args.data_dir
    datasets = {}
    datasets["test"] = MIMICDataset(
        root=args.data_dir,
        csv_file=os.path.join(args.csv_dir, "mimic.sample.test.csv"),
        columns=args.parents_x,
        transform=TF.Compose(
            [
                TF.Resize((args.input_res, args.input_res), antialias=None),
            ]
        ),
        concat_pa=False,
    )
    return datasets
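

# A rough usage sketch (assumed attribute names, mirroring ukbb()/morphomnist()):
#
#   from argparse import Namespace
#   args = Namespace(
#       data_dir="path/to/mimic",                     # hypothetical path
#       input_res=224,
#       parents_x=["age", "race", "sex", "finding"],  # illustrative parents
#   )
#   datasets = mimic(args)
#   sample = datasets["test"][0]  # dict with preprocessed "x" and parent values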