OpenMAP-T1 / src /utils /postprocessing.py
西牧慧
add: first commit
01f75cf
import os
import pickle
import numpy as np
import torch
# このファイル(postprocessing.py)のあるディレクトリを取得
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# split_map.pkl は同じディレクトリ内にあるため、CURRENT_DIR を基にパスを作成
SPLIT_MAP_PATH = os.path.join(CURRENT_DIR, "split_map.pkl")
def postprocessing(parcellated, separated, shift, device):
# 絶対パスを用いて split_map.pkl を読み込む
with open(SPLIT_MAP_PATH, "rb") as tf:
dictionary = pickle.load(tf)
pmap = torch.tensor(parcellated.astype("int16"), requires_grad=False).to(device)
hmap = torch.tensor(separated.astype("int16"), requires_grad=False).to(device)
combined = torch.stack((torch.flatten(hmap), torch.flatten(pmap)), axis=-1)
output = torch.zeros_like(hmap).ravel()
for key, value in dictionary.items():
key = torch.tensor(key, requires_grad=False).to(device)
mask = torch.all(combined == key, axis=1)
output[mask] = value
output = output.reshape(hmap.shape)
output = output.cpu().detach().numpy()
output = output * (
np.logical_or(np.logical_or(separated > 0, parcellated == 87), parcellated == 138)
)
output = np.pad(output, [(32, 32), (16, 16), (32, 32)], "constant", constant_values=0)
output = np.roll(output, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))
return output