For convinience, I've unified all of the data preprocessing
notebooks from [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation.git) into a single notebook

In [1]:
%reset
%load_ext autoreload
%autoreload 2
import numpy as np
import os
import glob
import SimpleITK as sitk
import sys

sys.path.insert(0, '../')

In [2]:
# Create dirs for the SABS and CHAOS datasets
os.makedirs('./SABS', exist_ok=True)
os.makedirs('./CHAOST2', exist_ok=True)

In [6]:
def copy_spacing_ori(src, dst):
 dst.SetSpacing(src.GetSpacing())
 dst.SetOrigin(src.GetOrigin())
 dst.SetDirection(src.GetDirection())
 return dst

# helper functions copy pasted
def resample_by_res(mov_img_obj, new_spacing, interpolator = sitk.sitkLinear, logging = True):
 resample = sitk.ResampleImageFilter()
 resample.SetInterpolator(interpolator)
 resample.SetOutputDirection(mov_img_obj.GetDirection())
 resample.SetOutputOrigin(mov_img_obj.GetOrigin())
 mov_spacing = mov_img_obj.GetSpacing()

 resample.SetOutputSpacing(new_spacing)
 RES_COE = np.array(mov_spacing) * 1.0 / np.array(new_spacing)
 new_size = np.array(mov_img_obj.GetSize()) * RES_COE 

 resample.SetSize( [int(sz+1) for sz in new_size] )
 if logging:
 print("Spacing: {} -> {}".format(mov_spacing, new_spacing))
 print("Size {} -> {}".format( mov_img_obj.GetSize(), new_size ))

 return resample.Execute(mov_img_obj)

def resample_lb_by_res(mov_lb_obj, new_spacing, interpolator = sitk.sitkLinear, ref_img = None, logging = True):
 src_mat = sitk.GetArrayFromImage(mov_lb_obj)
 lbvs = np.unique(src_mat)
 if logging:
 print("Label values: {}".format(lbvs))
 for idx, lbv in enumerate(lbvs):
 _src_curr_mat = np.float32(src_mat == lbv) 
 _src_curr_obj = sitk.GetImageFromArray(_src_curr_mat)
 _src_curr_obj.CopyInformation(mov_lb_obj)
 _tar_curr_obj = resample_by_res( _src_curr_obj, new_spacing, interpolator, logging )
 _tar_curr_mat = np.rint(sitk.GetArrayFromImage(_tar_curr_obj)) * lbv
 if idx == 0:
 out_vol = _tar_curr_mat
 else:
 out_vol[_tar_curr_mat == lbv] = lbv
 out_obj = sitk.GetImageFromArray(out_vol)
 out_obj.SetSpacing( _tar_curr_obj.GetSpacing() )
 if ref_img != None:
 out_obj.CopyInformation(ref_img)
 return out_obj
 
## Then crop ROI
def get_label_center(label):
 nnz = np.sum(label > 1e-5)
 return np.int32(np.rint(np.sum(np.nonzero(label), axis = 1) * 1.0 / nnz))

def image_crop(ori_vol, crop_size, referece_ctr_idx, padval = 0., only_2d = True):
 """ crop a 3d matrix given the index of the new volume on the original volume
 Args:
 refernce_ctr_idx: the center of the new volume on the original volume (in indices)
 only_2d: only do cropping on first two dimensions
 """
 _expand_cropsize = [x + 1 for x in crop_size] # to deal with boundary case
 if only_2d:
 assert len(crop_size) == 2, "Actual len {}".format(len(crop_size))
 assert len(referece_ctr_idx) == 2, "Actual len {}".format(len(referece_ctr_idx))
 _expand_cropsize.append(ori_vol.shape[-1])
 
 image_patch = np.ones(tuple(_expand_cropsize)) * padval

 half_size = tuple( [int(x * 1.0 / 2) for x in _expand_cropsize] )
 _min_idx = [0,0,0]
 _max_idx = list(ori_vol.shape)

 # bias of actual cropped size to the beginning and the end of this volume
 _bias_start = [0,0,0]
 _bias_end = [0,0,0]

 for dim,hsize in enumerate(half_size):
 if dim == 2 and only_2d:
 break

 _bias_start[dim] = np.min([hsize, referece_ctr_idx[dim]])
 _bias_end[dim] = np.min([hsize, ori_vol.shape[dim] - referece_ctr_idx[dim]])

 _min_idx[dim] = referece_ctr_idx[dim] - _bias_start[dim]
 _max_idx[dim] = referece_ctr_idx[dim] + _bias_end[dim]
 
 if only_2d:
 image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \
 half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], ... ] = \
 ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \
 referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], ... ]

 image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], : ]
 # then goes back to original volume
 else:
 image_patch[ half_size[0] - _bias_start[0]: half_size[0] +_bias_end[0], \
 half_size[1] - _bias_start[1]: half_size[1] +_bias_end[1], \
 half_size[2] - _bias_start[2]: half_size[2] +_bias_end[2] ] = \
 ori_vol[ referece_ctr_idx[0] - _bias_start[0]: referece_ctr_idx[0] +_bias_end[0], \
 referece_ctr_idx[1] - _bias_start[1]: referece_ctr_idx[1] +_bias_end[1], \
 referece_ctr_idx[2] - _bias_start[2]: referece_ctr_idx[2] +_bias_end[2] ]

 image_patch = image_patch[ 0: crop_size[0], 0: crop_size[1], 0: crop_size[2] ]
 return image_patch

s2n = sitk.GetArrayFromImage


def resample_imgs(imgs, segs, pids, scan_dir, BD_BIAS, SPA_FAC, required_res=512):
 spa_fac = SPA_FAC
 for img_fid, seg_fid, pid in zip(imgs, segs, pids):

 # lb_n = nio.read_nii_bysitk(seg_fid)

 img_obj = sitk.ReadImage( img_fid )
 seg_obj = sitk.ReadImage( seg_fid )
 print(img_fid, seg_fid)
 ## image
 array = sitk.GetArrayFromImage(img_obj)
 H = W = array.shape[-1]
 if SPA_FAC is None:
 spa_fac = (H - 2 * BD_BIAS) / required_res
 print(array.shape, f"label shape {sitk.GetArrayFromImage(seg_obj).shape}")
 # cropping
 array = array[:, BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]
 cropped_img_o = sitk.GetImageFromArray(array)
 cropped_img_o = copy_spacing_ori(img_obj, cropped_img_o)

 # resampling
 img_spa_ori = img_obj.GetSpacing()
 res_img_o = resample_by_res(cropped_img_o, [img_spa_ori[0] * spa_fac, img_spa_ori[1] * spa_fac, img_spa_ori[-1]], interpolator = sitk.sitkLinear,
 logging = True)

 ## label
 lb_arr = sitk.GetArrayFromImage(seg_obj)
 # cropping
 lb_arr = lb_arr[:,BD_BIAS: -BD_BIAS, BD_BIAS: -BD_BIAS]
 cropped_lb_o = sitk.GetImageFromArray(lb_arr)
 cropped_lb_o = copy_spacing_ori(seg_obj, cropped_lb_o)

 lb_spa_ori = seg_obj.GetSpacing()

 # resampling
 res_lb_o = resample_lb_by_res(cropped_lb_o, [lb_spa_ori[0] * spa_fac, lb_spa_ori[1] * spa_fac, lb_spa_ori[-1] ], interpolator = sitk.sitkLinear,
 ref_img = res_img_o, logging = True)

 
 out_img_fid = os.path.join( scan_dir, f'image_{pid}.nii.gz' )
 out_lb_fid = os.path.join( scan_dir, f'label_{pid}.nii.gz' ) 
 
 # then save
 sitk.WriteImage(res_img_o, out_img_fid, True) 
 sitk.WriteImage(res_lb_o, out_lb_fid, True) 
 print(f"{out_img_fid} has been saved, shape: {res_img_o.GetSize()}")
 print(f"{out_lb_fid} has been saved")

# Intensitiy Normalization for CT Images

In [7]:
# set up directories for images
IMG_FOLDER="./miccai2015/RawData/Training/img"
SEG_FOLDER="./miccai2015/RawData/Training/label"
OUT_FOLDER="./SABS/tmp_normalized/"

In [8]:
imgs = sorted(glob.glob(IMG_FOLDER + "/*.nii.gz"))
segs = sorted(glob.glob(SEG_FOLDER + "/*.nii.gz"))
pids = [pid.split("img")[-1].split(".")[0] for pid in imgs]
print(sorted(pids))
assert len(imgs) == len(segs)
for img, seg in zip(imgs, segs):
 print(img, seg)

[]


In [73]:
import copy
scan_dir = OUT_FOLDER
LIR = -125
HIR = 275
os.makedirs(scan_dir, exist_ok = True)

reindex = 0
for img_fid, seg_fid, pid in zip(imgs, segs, pids):

 img_obj = sitk.ReadImage( img_fid )
 seg_obj = sitk.ReadImage( seg_fid )

 array = sitk.GetArrayFromImage(img_obj)
 print(array.shape, f"label shape {sitk.GetArrayFromImage(seg_obj).shape}")
 array[array > HIR] = HIR
 array[array < LIR] = LIR
 
 array = (array - array.min()) / (array.max() - array.min()) * 255.0
 
 # then normalize this
 
 wined_img = sitk.GetImageFromArray(array)
 wined_img = copy_spacing_ori(img_obj, wined_img)
 
 out_img_fid = os.path.join( scan_dir, f'image_{str(reindex)}.nii.gz' )
 out_lb_fid = os.path.join( scan_dir, f'label_{str(reindex)}.nii.gz' ) 
 
 # then save
 sitk.WriteImage(wined_img, out_img_fid, True) 
 sitk.WriteImage(seg_obj, out_lb_fid, True) 
 print("{} has been save".format(out_img_fid))
 print("{} has been save".format(out_lb_fid))
 reindex += 1

(147, 512, 512) label shape (147, 512, 512)
./SABS/tmp_normalized/image_0.nii.gz has been save
./SABS/tmp_normalized/label_0.nii.gz has been save
(139, 512, 512) label shape (139, 512, 512)
./SABS/tmp_normalized/image_1.nii.gz has been save
./SABS/tmp_normalized/label_1.nii.gz has been save
(198, 512, 512) label shape (198, 512, 512)
./SABS/tmp_normalized/image_2.nii.gz has been save
./SABS/tmp_normalized/label_2.nii.gz has been save
(140, 512, 512) label shape (140, 512, 512)
./SABS/tmp_normalized/image_3.nii.gz has been save
./SABS/tmp_normalized/label_3.nii.gz has been save
(117, 512, 512) label shape (117, 512, 512)
./SABS/tmp_normalized/image_4.nii.gz has been save
./SABS/tmp_normalized/label_4.nii.gz has been save
(131, 512, 512) label shape (131, 512, 512)
./SABS/tmp_normalized/image_5.nii.gz has been save
./SABS/tmp_normalized/label_5.nii.gz has been save
(163, 512, 512) label shape (163, 512, 512)
./SABS/tmp_normalized/image_6.nii.gz has been save
./SABS/tmp_normalized/label_6

Overview

This is the second step of preprocessing

Cut out irrelevant empty boundary and resample to 512x512 in axial plane.

Input: intensity-normalized images

Output: spacially resampled images

In [9]:
IMG_FOLDER = "./SABS/tmp_normalized"

SEG_FOLDER = IMG_FOLDER
imgs = glob.glob(IMG_FOLDER + "/image_*.nii.gz")
imgs = sorted([ fid for fid in sorted(imgs) ])
segs = sorted([ fid for fid in glob.glob(SEG_FOLDER + "/label_*.nii.gz")])

pids = [pid.split("_")[-1].split(".")[0] for pid in imgs]
print(pids)
for img, seg in zip(imgs, segs):
 print(img, seg)

['0', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '4', '5', '6', '7', '8', '9']
./SABS/tmp_normalized/image_0.nii.gz ./SABS/tmp_normalized/label_0.nii.gz
./SABS/tmp_normalized/image_1.nii.gz ./SABS/tmp_normalized/label_1.nii.gz
./SABS/tmp_normalized/image_10.nii.gz ./SABS/tmp_normalized/label_10.nii.gz
./SABS/tmp_normalized/image_11.nii.gz ./SABS/tmp_normalized/label_11.nii.gz
./SABS/tmp_normalized/image_12.nii.gz ./SABS/tmp_normalized/label_12.nii.gz
./SABS/tmp_normalized/image_13.nii.gz ./SABS/tmp_normalized/label_13.nii.gz
./SABS/tmp_normalized/image_14.nii.gz ./SABS/tmp_normalized/label_14.nii.gz
./SABS/tmp_normalized/image_15.nii.gz ./SABS/tmp_normalized/label_15.nii.gz
./SABS/tmp_normalized/image_16.nii.gz ./SABS/tmp_normalized/label_16.nii.gz
./SABS/tmp_normalized/image_17.nii.gz ./SABS/tmp_normalized/label_17.nii.gz
./SABS/tmp_normalized/image_18.nii.gz ./SABS/tmp_normalized/label_18.nii.

In [10]:
import copy
OUT_FOLDER = "./SABS/sabs_CT_normalized"
BD_BIAS = 32 # cut irrelavent empty boundary to make roi stands out

# SPA_FAC = (512 - 2 * BD_BIAS) / 512 # spacing factor
for res in (256, 672):
 if res == 672:
 OUT_FOLDER += "_672"
 scan_dir = OUT_FOLDER
 os.makedirs(OUT_FOLDER, exist_ok = True)

 resample_imgs(imgs, segs, pids, scan_dir, BD_BIAS, SPA_FAC=None, required_res=res)

./SABS/tmp_normalized/image_0.nii.gz ./SABS/tmp_normalized/label_0.nii.gz
(147, 512, 512) label shape (147, 512, 512)
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Label values: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13]
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Spacing: (0.66796875, 0.66796875, 3.0) -> [0.66796875, 0.66796875, 3.0]
Size (448, 448, 147) -> [448. 448. 147.]
Spacing: (0.66796875, 0.66796875, 3.0) -> 

## Synapse Classmap Generation

In [12]:
import json
# import niftiio as nio
import SimpleITK as sitk

# normalization: cut top 2% of histogram, then doing volume-wise normalization
IMG_BNAMES = ("./SABS/sabs_CT_normalized/image_*.nii.gz", "./SABS/sabs_CT_normalized_672/image_*.nii.gz")
SEG_NAMES = ("./SABS/sabs_CT_normalized/label_*.nii.gz", "./SABS/sabs_CT_normalized_672/label_*.nii.gz")
for IMG_BNAME, SEG_BNAME in zip(IMG_BNAMES, SEG_NAMES):
 imgs = glob.glob(IMG_BNAME)
 segs = glob.glob(SEG_BNAME)
 imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) ) ]
 segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) ) ]
 for img, seg in zip(imgs, segs):
 print(img, seg)

 classmap = {}
 LABEL_NAME = ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC", "PS_VEIN", "PANCREAS", "AG_R", "AG_L"] 

 MIN_TP=1 # minimum number of true positive pixels in a slice

 fid = os.path.dirname(IMG_BNAME) + f'/classmap_{MIN_TP}.json'
 for _lb in LABEL_NAME:
 classmap[_lb] = {}
 for pid in range(len(segs)):
 classmap[_lb][str(pid)] = []

 for pid, seg in enumerate(segs):
 # lb_vol = nio.read_nii_bysitk(seg)
 lb_vol = sitk.GetArrayFromImage(sitk.ReadImage(seg))
 n_slice = lb_vol.shape[0]
 for slc in range(n_slice):
 for cls in range(len(LABEL_NAME)):
 if cls in lb_vol[slc, ...]:
 if np.sum( lb_vol[slc, ...] == cls) >= MIN_TP:
 classmap[LABEL_NAME[cls]][str(pid)].append(slc)
 print(f'pid {str(pid)} finished!')
 
 with open(fid, 'w') as fopen:
 json.dump(classmap, fopen)
 fopen.close() 
 

pid 0 finished!
pid 1 finished!
pid 2 finished!
pid 3 finished!
pid 4 finished!
pid 5 finished!
pid 6 finished!
pid 7 finished!
pid 8 finished!
pid 9 finished!
pid 10 finished!
pid 11 finished!
pid 12 finished!
pid 13 finished!
pid 14 finished!
pid 15 finished!
pid 16 finished!
pid 17 finished!
pid 18 finished!
pid 19 finished!
pid 20 finished!
pid 21 finished!
pid 22 finished!
pid 23 finished!
pid 24 finished!
pid 25 finished!
pid 26 finished!
pid 27 finished!
pid 28 finished!
pid 29 finished!
pid 30 finished!
pid 31 finished!
pid 32 finished!
pid 33 finished!
pid 34 finished!
pid 35 finished!
pid 36 finished!
pid 37 finished!


# MRI Image Normalization

In [None]:
## PLEASE RUN dcm_img_to_nii.sh to convert dicom to nii.gz
! ./dcm_img_to_nii.sh

In [99]:
import copy

IMG_FOLDER = "./CHAOST2/niis/T2SPIR" #, path of nii-like images from step 1
OUT_FOLDER="./CHAOST2/chaos_MR_T2_normalized/" # output directory

imgs = glob.glob(IMG_FOLDER + f'/image_*.nii.gz')
imgs = [ fid for fid in sorted(imgs) ]
segs = [ fid for fid in sorted(glob.glob(IMG_FOLDER + f'/label_*.nii.gz')) ]

pids = [pid.split("_")[-1].split(".")[0] for pid in imgs]
for img, seg in zip(imgs, segs):
 print(img, seg)

os.makedirs(OUT_FOLDER, exist_ok = True)
 
HIST_CUT_TOP = 0.5 # cut top 0.5% of intensity historgam to alleviate off-resonance effect

NEW_SPA = [1.25, 1.25, 7.70] # unified voxel spacing

for img_fid, seg_fid, pid in zip(imgs, segs, pids):

 resample_flg = True

 img_obj = sitk.ReadImage( img_fid )
 seg_obj = sitk.ReadImage( seg_fid )

 array = sitk.GetArrayFromImage(img_obj)

 # cut histogram
 hir = float(np.percentile(array, 100.0 - HIST_CUT_TOP))
 array[array > hir] = hir

 his_img_o = sitk.GetImageFromArray(array)
 his_img_o = copy_spacing_ori(img_obj, his_img_o)

 # resampling
 img_spa_ori = img_obj.GetSpacing()
 res_img_o = resample_by_res(his_img_o, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2]],
 interpolator = sitk.sitkLinear, logging = True)
 ## label
 lb_arr = sitk.GetArrayFromImage(seg_obj)

 # resampling
 res_lb_o = resample_lb_by_res(seg_obj, [NEW_SPA[0], NEW_SPA[1], NEW_SPA[2] ], interpolator = sitk.sitkLinear,
 ref_img = None, logging = True)

 # crop out rois
 res_img_a = s2n(res_img_o)

 crop_img_a = image_crop(res_img_a.transpose(1,2,0), [256, 256],
 referece_ctr_idx = [res_img_a.shape[1] // 2, res_img_a.shape[2] //2],
 padval = res_img_a.min(), only_2d = True).transpose(2,0,1)

 out_img_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_img_a))

 res_lb_a = s2n(res_lb_o)

 crop_lb_a = image_crop(res_lb_a.transpose(1,2,0), [256, 256],
 referece_ctr_idx = [res_lb_a.shape[1] // 2, res_lb_a.shape[2] //2],
 padval = 0, only_2d = True).transpose(2,0,1)

 out_lb_obj = copy_spacing_ori(res_img_o, sitk.GetImageFromArray(crop_lb_a))


 out_img_fid = os.path.join( OUT_FOLDER, f'image_{pid}.nii.gz' )
 out_lb_fid = os.path.join( OUT_FOLDER, f'label_{pid}.nii.gz' ) 

 # then save pre-processed images
 sitk.WriteImage(out_img_obj, out_img_fid, True) 
 sitk.WriteImage(out_lb_obj, out_lb_fid, True) 
 print("{} has been saved".format(out_img_fid))

Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [316. 316. 35.99999911]
Label values: [0 1 2 3 4]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [316. 316. 35.99999911]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [316. 316. 35.99999911]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [316. 316. 35.99999911]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [316. 316. 35.99999911]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [316. 316. 35.99999911]
./CHAOST2/chaos_MR_T2_normalized/image_1.nii.gz has been saved
Spacing: (1.69921875, 1.69921875, 7.699999809265137) -> [1.25, 1.25, 7.7]
Size (256, 256, 36) -> [348. 348. 35.99999911]
Label values: [0 1 2 3 4]
Spacing: (1.69921875, 1.69921875, 7.69

## MRI Resampling and ROI

In [111]:
# SPA_FAC = (256 - 2 * BD_BIAS) / 512 # spacing factor
BD_BIAS = 1
scan_dir = OUT_FOLDER
for res in (256, 672):
 if res == 672:
 scan_dir += "_672"
 resample_imgs(imgs, segs, pids, scan_dir,
 BD_BIAS, SPA_FAC=None, required_res=res)

./CHAOST2/niis/T2SPIR/image_1.nii.gz ./CHAOST2/niis/T2SPIR/label_1.nii.gz
(36, 256, 256) label shape (36, 256, 256)
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]
Size (254, 254, 36) -> [672. 672. 36.]
Label values: [0 1 2 3 4]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]
Size (254, 254, 36) -> [672. 672. 36.]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]
Size (254, 254, 36) -> [672. 672. 36.]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]
Size (254, 254, 36) -> [672. 672. 36.]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832054501488095, 7.699999809265137]
Size (254, 254, 36) -> [672. 672. 36.]
Spacing: (1.54296875, 1.54296875, 7.699999809265137) -> [0.5832054501488095, 0.5832

## MRI Classmap Generation

In [89]:
IMG_BNAMES = ("./CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz", "./CHAOST2/chaos_MR_T2_normalized_672/image_*.nii.gz")
SEG_NAMES = ("./CHAOST2/chaos_MR_T2_normalized/label_*.nii.gz", "./CHAOST2/chaos_MR_T2_normalized_672/label_*.nii.gz")

for IMG_BNAME, SEG_BNAME in zip(IMG_BNAMES, SEG_NAMES):
 imgs = glob.glob(IMG_BNAME)
 segs = glob.glob(SEG_BNAME)
 imgs = [ fid for fid in sorted(imgs, key = lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) ) ]
 segs = [ fid for fid in sorted(segs, key = lambda x: int(x.split("_")[-1].split(".nii.gz")[0]) ) ]


 classmap = {}
 LABEL_NAME = ["BG", "LIVER", "RK", "LK", "SPLEEN"] 

 MIN_TP = 1 # minimum number of positive label pixels to be recorded. Use >100 when training with manual annotations for more stable training

 fid = os.path.join(OUT_FOLDER,f'.classmap_{MIN_TP}.json') # name of the output file. 
 for _lb in LABEL_NAME:
 classmap[_lb] = {}
 for _sid in segs:
 pid = _sid.split("_")[-1].split(".nii.gz")[0]
 classmap[_lb][pid] = []

 for seg in segs:
 pid = seg.split("_")[-1].split(".nii.gz")[0]
 lb_vol = sitk.GetArrayFromImage(sitk.ReadImage(seg))
 n_slice = lb_vol.shape[0]
 for slc in range(n_slice):
 for cls in range(len(LABEL_NAME)):
 if cls in lb_vol[slc, ...]:
 if np.sum( lb_vol[slc, ...]) >= MIN_TP:
 classmap[LABEL_NAME[cls]][str(pid)].append(slc)
 print(f'pid {str(pid)} finished!')
 
 with open(fid, 'w') as fopen:
 json.dump(classmap, fopen)
 fopen.close() 

 

pid 1 finished!
pid 2 finished!
pid 3 finished!
pid 5 finished!
pid 8 finished!
pid 10 finished!
pid 13 finished!
pid 15 finished!
pid 19 finished!
pid 20 finished!
pid 21 finished!
pid 22 finished!
pid 31 finished!
pid 32 finished!
pid 33 finished!
pid 34 finished!
pid 36 finished!
pid 37 finished!
pid 38 finished!
pid 39 finished!


# Psuedo label generation for Encoder Finetuning

In [90]:
import matplotlib.pyplot as plt
import copy
import skimage

from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.measure import label 
import scipy.ndimage.morphology as snm
from skimage import io
import argparse


to01 = lambda x: (x - x.min()) / (x.max() - x.min())

Summary

a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background

b. Generate superpixels as pseudolabels

Configurations of pseudlabels

default setting of minimum superpixel sizes
`segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)`

you can also try other configs
`segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)`

In [91]:
MODE = 'MIDDLE' # minimum size of pesudolabels. 'MIDDLE' is the default setting

# wrapper for process 3d image in 2d
def superpix_vol(img, method = 'fezlen', **kwargs):
 """
 loop through the entire volume
 assuming image with axis z, x, y
 """
 if method =='fezlen':
 seg_func = skimage.segmentation.felzenszwalb
 else:
 raise NotImplementedError
 
 out_vol = np.zeros(img.shape)
 for ii in range(img.shape[0]):
 if MODE == 'MIDDLE':
 segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)
 else:
 raise NotImplementedError
 out_vol[ii, ...] = segs
 
 return out_vol

# thresholding the intensity values to get a binary mask of the patient
def fg_mask2d(img_2d, thresh): # change this by your need
 mask_map = np.float32(img_2d > thresh)
 
 def getLargestCC(segmentation): # largest connected components
 labels = label(segmentation)
 assert( labels.max() != 0 ) # assume at least 1 CC
 largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
 return largestCC
 if mask_map.max() < 0.999:
 return mask_map
 else:
 post_mask = getLargestCC(mask_map)
 fill_mask = snm.binary_fill_holes(post_mask)
 return fill_mask

# remove superpixels within the empty regions
def superpix_masking(raw_seg2d, mask2d):
 raw_seg2d = np.int32(raw_seg2d)
 lbvs = np.unique(raw_seg2d)
 max_lb = lbvs.max()
 raw_seg2d[raw_seg2d == 0] = max_lb + 1
 lbvs = list(lbvs)
 lbvs.append( max_lb )
 raw_seg2d = raw_seg2d * mask2d
 lb_new = 1
 out_seg2d = np.zeros(raw_seg2d.shape)
 for lbv in lbvs:
 if lbv == 0:
 continue
 else:
 out_seg2d[raw_seg2d == lbv] = lb_new
 lb_new += 1
 
 return out_seg2d
 
def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):
 raw_seg = superpix_vol(img)
 fg_mask_vol = np.zeros(raw_seg.shape)
 processed_seg_vol = np.zeros(raw_seg.shape)
 for ii in range(raw_seg.shape[0]):
 if verbose:
 print("doing {} slice".format(ii))
 _fgm = fg_mask2d(img[ii, ...], fg_thresh )
 _out_seg = superpix_masking(raw_seg[ii, ...], _fgm)
 fg_mask_vol[ii] = _fgm
 processed_seg_vol[ii] = _out_seg
 return fg_mask_vol, processed_seg_vol
 
# copy spacing and orientation info between sitk objects
def copy_info(src, dst):
 dst.SetSpacing(src.GetSpacing())
 dst.SetOrigin(src.GetOrigin())
 dst.SetDirection(src.GetDirection())
 # dst.CopyInfomation(src)
 return dst


def strip_(img, lb):
 img = np.int32(img)
 if isinstance(lb, float):
 lb = int(lb)
 return np.float32(img == lb) * float(lb)
 elif isinstance(lb, list):
 out = np.zeros(img.shape)
 for _lb in lb:
 out += np.float32(img == int(_lb)) * float(_lb)
 
 return out
 else:
 raise Exception

In [None]:
DATASET_CONFIG = {'SABS':{
 'img_bname': f'./SABS/sabs_CT_normalized/image_*.nii.gz',
 'out_dir': './SABS/sabs_CT_normalized',
 'fg_thresh': 1e-4
 },
 'CHAOST2':{
 'img_bname': f'./CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',
 'out_dir': './CHAOST2/chaos_MR_T2_normalized',
 'fg_thresh': 1e-4 + 50
 },
 'SABS_672':{
 'img_bname': f'./SABS/sabs_CT_normalized_672/image_*.nii.gz',
 'out_dir': './SABS/sabs_CT_normalized_672',
 'fg_thresh': 1e-4
 },
 'CHAOST2_672':{
 'img_bname': f'./CHAOST2/chaos_MR_T2_normalized_672/image_*.nii.gz',
 'out_dir': './CHAOST2/chaos_MR_T2_normalized_672',
 'fg_thresh': 1e-4 + 50
 }
}

for DOMAIN in DATASET_CONFIG.keys():
 img_bname = DATASET_CONFIG[DOMAIN]['img_bname']
 imgs = glob.glob(img_bname)
 out_dir = DATASET_CONFIG[DOMAIN]['out_dir']

 imgs = sorted(imgs, key = lambda x: int(x.split('_')[-1].split('.nii.gz')[0]) )
 print(imgs)

 # Generate pseudolabels for every image and save them
 for img_fid in imgs:
 # img_fid = imgs[0]

 idx = os.path.basename(img_fid).split("_")[-1].split(".nii.gz")[0]
 im_obj = sitk.ReadImage(img_fid)

 out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj), fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh'] )
 out_fg_o = sitk.GetImageFromArray(out_fg ) 
 out_seg_o = sitk.GetImageFromArray(out_seg )

 out_fg_o = copy_info(im_obj, out_fg_o)
 out_seg_o = copy_info(im_obj, out_seg_o)
 seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')
 msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')
 sitk.WriteImage(out_fg_o, msk_fid)
 sitk.WriteImage(out_seg_o, seg_fid)
 print(f'image with id {idx} has finished')
