import torch

from models.ram import RAM


def get_model(ckpt_path="ckpt/ram.pth.tar", *, map_location="cpu"):
    """Build a RAM model and load its weights from a checkpoint file.

    :param str ckpt_path: path of the checkpoint to load; a relative path is
        resolved against the current working directory. Defaults to the
        previously hard-coded ``'ckpt/ram.pth.tar'``.
    :param map_location: forwarded to ``torch.load`` to remap stored tensors
        while reading. Defaults to ``"cpu"`` so a checkpoint saved from GPU
        tensors can still be read on a machine without CUDA; the weights are
        copied into the freshly built model's own parameters either way.
    :return: a ``RAM`` instance with the checkpoint's weights loaded
    """
    model = RAM()
    # NOTE(review): assumes the file holds the bare state dict rather than a
    # wrapper such as {"state_dict": ..., "epoch": ...}; with the default
    # strict=True, load_state_dict raises if the keys do not match the model.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model