Spaces:
Sleeping
Sleeping
File size: 1,425 Bytes
01f75cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import os
import pickle
import numpy as np
import torch
# このファイル(postprocessing.py)のあるディレクトリを取得
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# split_map.pkl は同じディレクトリ内にあるため、CURRENT_DIR を基にパスを作成
SPLIT_MAP_PATH = os.path.join(CURRENT_DIR, "split_map.pkl")
def postprocessing(parcellated, separated, shift, device):
# 絶対パスを用いて split_map.pkl を読み込む
with open(SPLIT_MAP_PATH, "rb") as tf:
dictionary = pickle.load(tf)
pmap = torch.tensor(parcellated.astype("int16"), requires_grad=False).to(device)
hmap = torch.tensor(separated.astype("int16"), requires_grad=False).to(device)
combined = torch.stack((torch.flatten(hmap), torch.flatten(pmap)), axis=-1)
output = torch.zeros_like(hmap).ravel()
for key, value in dictionary.items():
key = torch.tensor(key, requires_grad=False).to(device)
mask = torch.all(combined == key, axis=1)
output[mask] = value
output = output.reshape(hmap.shape)
output = output.cpu().detach().numpy()
output = output * (
np.logical_or(np.logical_or(separated > 0, parcellated == 87), parcellated == 138)
)
output = np.pad(output, [(32, 32), (16, 16), (32, 32)], "constant", constant_values=0)
output = np.roll(output, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))
return output
|