Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from utils.network import UNet | |
def _load_unet(unet_args, weights_path, device):
    """Build ``UNet(*unet_args)``, load its weights, and return it ready for inference.

    Parameters:
        unet_args (tuple): Positional arguments forwarded to the ``UNet`` constructor.
        weights_path (str): Path of the ``.pth`` state-dict file to load.
        device (torch.device): Device the model is moved to after loading.

    Returns:
        UNet: The model with weights loaded, moved to ``device``, in eval mode.
    """
    model = UNet(*unet_args)
    # map_location remaps the stored tensors while deserializing, so a
    # checkpoint that was saved from a GPU can still be loaded on a CPU-only
    # host instead of failing inside torch.load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def load_model(model_dir, device):
    """
    Load the pre-trained U-Net models and set them to evaluation mode.

    The models loaded are (weights path relative to ``model_dir``):
        1. CNet:         ``UNet(1, 1)``   from ``CNet/CNet.pth``
        2. SSNet:        ``UNet(1, 1)``   from ``SSNet/SSNet.pth``
        3. PNet axial:   ``UNet(3, 142)`` from ``PNet/axial.pth``
        4. HNet coronal: ``UNet(1, 3)``   from ``HNet/coronal.pth``
        5. HNet axial:   ``UNet(1, 3)``   from ``HNet/axial.pth``

    NOTE: the PNet coronal and PNet sagittal models (``UNet(3, 142)`` with
    ``PNet/coronal.pth`` and ``PNet/sagittal.pth``) are currently not loaded
    and are not part of the returned tuple.

    Parameters:
        model_dir (str): Root directory containing the per-model weight folders.
        device (torch.device): The device on which to load the models (CPU or GPU).

    Returns:
        tuple: ``(cnet, ssnet, pnet_a, hnet_c, hnet_a)``, in that order.
    """
    cnet = _load_unet((1, 1), os.path.join(model_dir, "CNet", "CNet.pth"), device)
    ssnet = _load_unet((1, 1), os.path.join(model_dir, "SSNet", "SSNet.pth"), device)
    pnet_a = _load_unet((3, 142), os.path.join(model_dir, "PNet", "axial.pth"), device)
    hnet_c = _load_unet((1, 3), os.path.join(model_dir, "HNet", "coronal.pth"), device)
    hnet_a = _load_unet((1, 3), os.path.join(model_dir, "HNet", "axial.pth"), device)

    return cnet, ssnet, pnet_a, hnet_c, hnet_a