# refactor: ONNX model loading and caching functions (commit 80c3f0c, LPX55)
import os
import json
import onnxruntime
from huggingface_hub import hf_hub_download
import logging
def load_onnx_model_and_preprocessor(hf_model_id):
    """Download an ONNX model and its configs from the Hugging Face Hub.

    Files are cached under ./models/<repo-id-with-slashes-replaced>/ so
    repeated calls reuse the local copies managed by ``hf_hub_download``.

    Parameters
    ----------
    hf_model_id : str
        Hub repository id, e.g. ``"org/model-name"``.

    Returns
    -------
    tuple
        ``(onnxruntime.InferenceSession, preprocessor_config, model_config)``.
        Either config dict is ``{}`` when its JSON file is missing or
        unreadable (a warning is logged; the model download itself is
        mandatory and any failure there propagates).
    """
    logger = logging.getLogger(__name__)

    # '/' is not a valid path component on disk, so flatten the repo id.
    model_specific_dir = os.path.join("./models", hf_model_id.replace('/', '_'))
    os.makedirs(model_specific_dir, exist_ok=True)

    # NOTE: local_dir_use_symlinks is deprecated/ignored in recent
    # huggingface_hub; local_dir already yields real files.
    onnx_model_path = hf_hub_download(
        repo_id=hf_model_id,
        filename="model_quantized.onnx",
        subfolder="onnx",
        local_dir=model_specific_dir,
    )

    def _fetch_optional_json(filename):
        # Best-effort download + parse of an optional JSON config file;
        # returns {} on any failure so callers never see an exception here.
        try:
            path = hf_hub_download(
                repo_id=hf_model_id,
                filename=filename,
                local_dir=model_specific_dir,
            )
            with open(path, 'r') as f:
                return json.load(f)
        except Exception as e:
            logger.warning(
                "Could not download or load %s for %s: %s",
                filename, hf_model_id, e,
            )
            return {}

    preprocessor_config = _fetch_optional_json("preprocessor_config.json")
    model_config = _fetch_optional_json("config.json")

    return onnxruntime.InferenceSession(onnx_model_path), preprocessor_config, model_config
def get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_func):
    """Return the loaded model bundle for *hf_model_id*, memoizing in a dict.

    Parameters
    ----------
    hf_model_id : str
        Hub repository id used as the cache key.
    _onnx_model_cache : dict
        Mutable cache mapping model ids to whatever ``load_func`` returns.
        Mutated in place on a cache miss.
    load_func : callable
        Called as ``load_func(hf_model_id)`` on a miss; typically
        ``load_onnx_model_and_preprocessor``.

    Returns
    -------
    The cached value (whatever ``load_func`` returns — presumably a
    (session, preprocessor_config, model_config) tuple; depends on caller).
    """
    # `logging` is imported at module level; no need for a local import.
    logger = logging.getLogger(__name__)
    if hf_model_id not in _onnx_model_cache:
        # Lazy %-formatting: only rendered if INFO logging is enabled.
        logger.info("Loading ONNX model and preprocessor for %s...", hf_model_id)
        _onnx_model_cache[hf_model_id] = load_func(hf_model_id)
    return _onnx_model_cache[hf_model_id]