OpenMAP-T1 / src /utils /load_model.py
西牧慧
update: parcellation
dcbe128
import os
import torch
from utils.network import UNet
def _load_unet(model_dir, subdir, filename, first_arg, second_arg, device):
    """Build one ``UNet``, load its weights from disk and ready it for inference.

    Parameters:
        model_dir (str): Root directory that contains the per-model sub-folders.
        subdir (str): Sub-folder of ``model_dir`` holding the checkpoint.
        filename (str): Checkpoint file name inside ``subdir``.
        first_arg (int): First positional argument passed to ``UNet``.
        second_arg (int): Second positional argument passed to ``UNet``.
        device (torch.device): Device the model is moved to after loading.

    Returns:
        UNet: The model with weights loaded, moved to ``device`` and set to
        evaluation mode.
    """
    net = UNet(first_arg, second_arg)
    weights_path = os.path.join(model_dir, subdir, filename)
    # Deserialize onto the CPU first: without map_location, a checkpoint that
    # was saved from a CUDA device cannot be loaded on a CPU-only machine.
    # The model is moved to the requested device right after.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


def load_model(model_dir, device):
    """
    Load the pre-trained U-Net models and set them to evaluation mode.

    The models loaded are, in order:
    1. CNet: ``UNet(1, 1)`` from ``<model_dir>/CNet/CNet.pth``.
    2. SSNet: ``UNet(1, 1)`` from ``<model_dir>/SSNet/SSNet.pth``.
    3. PNet axial: ``UNet(3, 142)`` from ``<model_dir>/PNet/axial.pth``.
    4. HNet coronal: ``UNet(1, 3)`` from ``<model_dir>/HNet/coronal.pth``.
    5. HNet axial: ``UNet(1, 3)`` from ``<model_dir>/HNet/axial.pth``.

    The PNet coronal and sagittal models are not loaded by this function;
    only the axial PNet is returned.

    Parameters:
        model_dir (str): Directory containing the ``CNet``, ``SSNet``,
            ``PNet`` and ``HNet`` sub-folders with the ``.pth`` files.
        device (torch.device): The device on which to place the models
            (CPU or GPU).

    Returns:
        tuple: ``(cnet, ssnet, pnet_a, hnet_c, hnet_a)``.
    """
    cnet = _load_unet(model_dir, "CNet", "CNet.pth", 1, 1, device)
    ssnet = _load_unet(model_dir, "SSNet", "SSNet.pth", 1, 1, device)
    pnet_a = _load_unet(model_dir, "PNet", "axial.pth", 3, 142, device)
    hnet_c = _load_unet(model_dir, "HNet", "coronal.pth", 1, 3, device)
    hnet_a = _load_unet(model_dir, "HNet", "axial.pth", 1, 3, device)

    return cnet, ssnet, pnet_a, hnet_c, hnet_a