import torch
from huggingface_hub import hf_hub_download

from models.ram import RAM

# Target passed as ``map_location`` when deserialising the checkpoint:
# CUDA when PyTorch reports a usable GPU, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_model() -> RAM:
    """
    Build a RAM model and load its pretrained weights.

    The checkpoint file ``ram.pth.tar`` is fetched from the Hugging Face Hub
    repository ``mterris/ram``, deserialised with ``torch.load`` using the
    module-level ``device`` as ``map_location``, and loaded into a freshly
    constructed ``RAM()`` instance via ``load_state_dict``.

    The function takes no arguments; the repository, filename and device are
    fixed by this module.

    NOTE(review): this function never calls ``model.to(device)`` or
    ``model.eval()``. Where the returned module's parameters live depends on
    the ``RAM`` constructor -- callers should move/configure the model
    themselves if required.

    :return: the ``RAM`` instance with the downloaded state dict loaded
    :raises RuntimeError: presumably when the checkpoint keys do not match the
        model, since ``load_state_dict`` is called with its default (strict)
        settings -- confirm against the PyTorch documentation
    """
    model = RAM()

    # Resolve the checkpoint through the Hub client; the returned value is
    # handed straight to torch.load as the file to read.
    checkpoint_path = hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar")

    # map_location remaps the stored tensors onto `device` while loading.
    state_dict = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(state_dict)
    return model