Spaces:
Sleeping
Sleeping
import os | |
import pickle | |
import numpy as np | |
import torch | |
# Absolute path of the directory that contains this module (postprocessing.py).
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# split_map.pkl lives in the same directory as this module, so its path is
# built from CURRENT_DIR.
SPLIT_MAP_PATH = os.path.join(CURRENT_DIR, "split_map.pkl")
def postprocessing(
    parcellated,
    separated,
    shift,
    device,
    *,
    pad_width=((32, 32), (16, 16), (32, 32)),
):
    """Relabel a volume through the split map, mask it, then pad and roll it.

    Every element is described by the pair ``(separated value, parcellated
    value)``. Each key of the mapping stored in ``split_map.pkl`` is compared
    against those pairs; where a pair matches, the output element is set to
    the mapping's value. Elements whose pair matches no key stay 0.

    The relabelled array is then zeroed everywhere except where
    ``separated > 0`` or ``parcellated`` equals 87 or 138, zero-padded by
    ``pad_width`` and rolled by ``-shift`` along axes 0, 1 and 2.

    Args:
        parcellated: NumPy array of labels; cast to int16 before matching.
        separated: NumPy array with the same shape as ``parcellated``; cast
            to int16 before matching.
        shift: Indexable holding three integer offsets, one per axis. The
            output is rolled by the negated offsets (presumably to undo a
            shift applied upstream -- confirm against the caller).
        device: Torch device (or device string) on which the matching runs.
        pad_width: Per-axis ``(before, after)`` zero padding passed to
            ``numpy.pad``. The default reproduces the previously hard-coded
            widths (presumably the inverse of an upstream crop -- confirm).

    Returns:
        numpy.ndarray: int16 array of the padded shape holding the new labels.
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # SPLIT_MAP_PATH must only ever point at the trusted split_map.pkl that
    # is shipped next to this module.
    with open(SPLIT_MAP_PATH, "rb") as tf:
        dictionary = pickle.load(tf)

    # astype() already returns a fresh copy, so as_tensor can wrap it (or move
    # it to the device) without the extra host-side copy torch.tensor makes.
    pmap = torch.as_tensor(parcellated.astype("int16"), device=device)
    hmap = torch.as_tensor(separated.astype("int16"), device=device)

    # One row per element: (separated value, parcellated value).
    combined = torch.stack((hmap.flatten(), pmap.flatten()), dim=-1)
    output = torch.zeros_like(hmap).flatten()
    for key, value in dictionary.items():
        # Build the key on the target device directly instead of creating it
        # on the CPU and transferring it on every iteration.
        key_tensor = torch.tensor(key, device=device)
        mask = torch.all(combined == key_tensor, dim=1)
        output[mask] = value
    output = output.reshape(hmap.shape).cpu().numpy()

    # Keep relabelled values only where `separated` is positive or where the
    # parcellation label is 87 or 138; everything else is reset to 0.
    # Multiplying by the boolean mask (rather than np.where) keeps int16.
    keep = (separated > 0) | (parcellated == 87) | (parcellated == 138)
    output = output * keep

    output = np.pad(output, pad_width, "constant", constant_values=0)
    output = np.roll(output, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))
    return output