import os
import json
import logging

import onnxruntime
from huggingface_hub import hf_hub_download

# Module-level logger (PEP 282 convention); replaces repeated
# logging.getLogger(__name__) calls scattered through the functions.
logger = logging.getLogger(__name__)


def _download_json_config(hf_model_id, filename, local_dir):
    """Best-effort download and parse of a JSON config file from the Hub.

    Parameters
    ----------
    hf_model_id : str
        Hugging Face repo id (e.g. "org/model").
    filename : str
        Name of the JSON file in the repo root (e.g. "config.json").
    local_dir : str
        Local directory to download into.

    Returns
    -------
    dict
        The parsed JSON content, or {} if the file is missing or
        unreadable. The broad except is deliberate: some repos simply
        do not ship these config files, so failure is logged, not raised.
    """
    try:
        config_path = hf_hub_download(
            repo_id=hf_model_id,
            filename=filename,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        with open(config_path, "r") as f:
            return json.load(f)
    except Exception as e:
        # Best-effort fallback: caller gets an empty config dict.
        logger.warning(
            "Could not download or load %s for %s: %s", filename, hf_model_id, e
        )
        return {}


def load_onnx_model_and_preprocessor(hf_model_id):
    """Download a quantized ONNX model and its configs; build a session.

    Parameters
    ----------
    hf_model_id : str
        Hugging Face repo id (e.g. "org/model"). The '/' is replaced
        with '_' to form a per-model cache directory under ./models.

    Returns
    -------
    tuple
        (onnxruntime.InferenceSession, preprocessor_config dict,
        model_config dict). Either config dict may be empty if the
        corresponding JSON file could not be fetched.

    Raises
    ------
    Exception
        Propagates any error from downloading the ONNX model file
        itself — the model is required, unlike the config files.
    """
    model_specific_dir = os.path.join("./models", hf_model_id.replace("/", "_"))
    os.makedirs(model_specific_dir, exist_ok=True)

    # The model file is mandatory; no try/except, let failures propagate.
    # NOTE(review): local_dir_use_symlinks is deprecated/ignored in newer
    # huggingface_hub releases; kept for compatibility with older versions.
    onnx_model_path = hf_hub_download(
        repo_id=hf_model_id,
        filename="model_quantized.onnx",
        subfolder="onnx",
        local_dir=model_specific_dir,
        local_dir_use_symlinks=False,
    )

    # Config files are optional; each falls back to {} on failure.
    preprocessor_config = _download_json_config(
        hf_model_id, "preprocessor_config.json", model_specific_dir
    )
    model_config = _download_json_config(
        hf_model_id, "config.json", model_specific_dir
    )

    return (
        onnxruntime.InferenceSession(onnx_model_path),
        preprocessor_config,
        model_config,
    )


def get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_func):
    """Return the cached load_func result for hf_model_id, loading on miss.

    Parameters
    ----------
    hf_model_id : str
        Cache key (Hugging Face repo id).
    _onnx_model_cache : dict
        Mutable cache mapping model ids to load_func results; updated
        in place on a miss.
    load_func : callable
        One-argument loader (e.g. load_onnx_model_and_preprocessor)
        invoked only when the id is not yet cached.

    Returns
    -------
    The cached (or freshly loaded) value for hf_model_id.
    """
    if hf_model_id not in _onnx_model_cache:
        logger.info("Loading ONNX model and preprocessor for %s...", hf_model_id)
        _onnx_model_cache[hf_model_id] = load_func(hf_model_id)
    return _onnx_model_cache[hf_model_id]