File size: 1,425 Bytes
01f75cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import functools
import os
import pickle

import numpy as np
import torch

# Absolute path of the directory that holds this module (postprocessing.py).
CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]
# split_map.pkl ships next to this module; anchoring it on CURRENT_DIR makes
# the lookup independent of the process's working directory.
SPLIT_MAP_PATH = os.path.join(CURRENT_DIR, "split_map.pkl")


@functools.lru_cache(maxsize=1)
def _load_split_map():
    """Load the label-split lookup table from ``SPLIT_MAP_PATH`` once per process.

    ``pickle`` can execute arbitrary code while loading, so this must only ever
    read the trusted ``split_map.pkl`` shipped alongside this module.
    """
    with open(SPLIT_MAP_PATH, "rb") as tf:
        return pickle.load(tf)


def postprocessing(parcellated, separated, shift, device, split_map=None):
    """Relabel voxels by their (separated, parcellated) label pair, then pad and un-shift.

    Each voxel's pair ``(separated[v], parcellated[v])`` (both cast to int16)
    is looked up in ``split_map``. Voxels whose pair is a key get the mapped
    value; every other voxel becomes 0. The result is then zeroed wherever
    ``separated`` is not positive, except where ``parcellated`` equals 87 or
    138. Finally it is zero-padded by 32/16/32 voxels on both sides of axes
    0/1/2 and rolled by ``-shift[:3]``.

    Args:
        parcellated: 3-D numpy array of parcellation labels (cast to int16 for
            the lookup).
        separated: numpy array with the same shape as ``parcellated``
            (cast to int16 for the lookup).
        shift: indexable with at least three integers; the padded volume is
            rolled by ``(-shift[0], -shift[1], -shift[2])`` along axes 0, 1, 2.
        device: torch device on which the label lookup runs.
        split_map: optional mapping from ``(separated_label, parcellated_label)``
            to the output label. When ``None`` (default), the table stored in
            ``split_map.pkl`` next to this module is used (loaded once, cached).

    Returns:
        int16 numpy array of shape ``parcellated.shape`` enlarged by
        (64, 32, 64).

    Raises:
        ValueError: if ``parcellated`` and ``separated`` differ in shape.
    """
    if split_map is None:
        split_map = _load_split_map()

    if parcellated.shape != separated.shape:
        # Arrays with equal element counts but different shapes would otherwise
        # stack silently and then broadcast incorrectly in the final mask.
        raise ValueError(
            "parcellated and separated must have the same shape, "
            f"got {parcellated.shape} and {separated.shape}"
        )

    pmap = torch.tensor(parcellated.astype("int16"), device=device)
    hmap = torch.tensor(separated.astype("int16"), device=device)
    # One row per voxel: (separated label, parcellated label), matching key order.
    combined = torch.stack((torch.flatten(hmap), torch.flatten(pmap)), dim=-1)
    output = torch.zeros_like(hmap).ravel()

    if split_map:
        keys, values = zip(*split_map.items())
        # Transfer every key to the device in a single copy instead of one per key.
        key_tensor = torch.tensor(keys, device=device)
        for key, value in zip(key_tensor, values):
            # A voxel matches only when both of its labels equal the key's labels.
            output[torch.all(combined == key, dim=1)] = value

    output = output.reshape(hmap.shape).cpu().numpy()
    # Keep labels where `separated` is positive; parcellation labels 87 and 138
    # are kept even where `separated` is 0.
    keep = (separated > 0) | (parcellated == 87) | (parcellated == 138)
    output = output * keep
    output = np.pad(output, [(32, 32), (16, 16), (32, 32)], "constant", constant_values=0)
    return np.roll(output, (-shift[0], -shift[1], -shift[2]), axis=(0, 1, 2))