import json
import logging
import os

import onnxruntime
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)
|
def load_onnx_model_and_preprocessor(hf_model_id):
    """Download a quantized ONNX model and its configs from the Hugging Face Hub.

    Returns (onnxruntime.InferenceSession, preprocessor_config, model_config);
    either config dict is empty if the corresponding file is absent from the repo.
    """
    # Give each repo its own local directory, e.g. "org/model" -> "./models/org_model".
    model_specific_dir = os.path.join("./models", hf_model_id.replace("/", "_"))
    os.makedirs(model_specific_dir, exist_ok=True)

    # The quantized model is expected under the repo's "onnx/" subfolder.
    onnx_model_path = hf_hub_download(
        repo_id=hf_model_id,
        filename="model_quantized.onnx",
        subfolder="onnx",
        local_dir=model_specific_dir,
    )

    # preprocessor_config.json is optional; fall back to an empty dict if absent.
    preprocessor_config = {}
    try:
        preprocessor_config_path = hf_hub_download(
            repo_id=hf_model_id,
            filename="preprocessor_config.json",
            local_dir=model_specific_dir,
        )
        with open(preprocessor_config_path, "r") as f:
            preprocessor_config = json.load(f)
    except Exception as e:
        logger.warning(f"Could not download or load preprocessor_config.json for {hf_model_id}: {e}")

    # config.json is likewise optional.
    model_config = {}
    try:
        model_config_path = hf_hub_download(
            repo_id=hf_model_id,
            filename="config.json",
            local_dir=model_specific_dir,
        )
        with open(model_config_path, "r") as f:
            model_config = json.load(f)
    except Exception as e:
        logger.warning(f"Could not download or load config.json for {hf_model_id}: {e}")

    return onnxruntime.InferenceSession(onnx_model_path), preprocessor_config, model_config


def get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_func):
    """Return the cached (session, preprocessor_config, model_config) tuple for
    hf_model_id, loading it via load_func on first use."""
    if hf_model_id not in _onnx_model_cache:
        logger.info(f"Loading ONNX model and preprocessor for {hf_model_id}...")
        _onnx_model_cache[hf_model_id] = load_func(hf_model_id)
    return _onnx_model_cache[hf_model_id]
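
# Usage sketch (illustrative): ties the loader and cache together. The repo id
# below is a placeholder, not something referenced by this module; any Hub repo
# that ships an onnx/model_quantized.onnx file would work the same way.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    _onnx_model_cache = {}
    session, preprocessor_config, model_config = get_onnx_model_from_cache(
        "some-org/some-quantized-model",  # placeholder repo id
        _onnx_model_cache,
        load_onnx_model_and_preprocessor,
    )
    # The first call downloads and builds the session; later calls hit the cache.
    print([inp.name for inp in session.get_inputs()])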