import json
import logging
import os

import onnxruntime
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

def load_onnx_model_and_preprocessor(hf_model_id):
    """Download a quantized ONNX model and its configs, and build an inference session."""
    # Give each model its own local directory, e.g. "org/name" -> "./models/org_name".
    model_specific_dir = os.path.join("./models", hf_model_id.replace("/", "_"))
    os.makedirs(model_specific_dir, exist_ok=True)
    # The quantized model is expected at "onnx/model_quantized.onnx" in the repo.
    # (local_dir_use_symlinks is deprecated in recent huggingface_hub releases and omitted here.)
    onnx_model_path = hf_hub_download(
        repo_id=hf_model_id, filename="model_quantized.onnx",
        subfolder="onnx", local_dir=model_specific_dir,
    )
    # Both config files are optional: a missing or unreadable one is logged, not fatal.
    preprocessor_config = {}
    try:
        preprocessor_config_path = hf_hub_download(
            repo_id=hf_model_id, filename="preprocessor_config.json",
            local_dir=model_specific_dir,
        )
        with open(preprocessor_config_path, "r") as f:
            preprocessor_config = json.load(f)
    except Exception as e:
        logger.warning(f"Could not download or load preprocessor_config.json for {hf_model_id}: {e}")
    model_config = {}
    try:
        model_config_path = hf_hub_download(
            repo_id=hf_model_id, filename="config.json",
            local_dir=model_specific_dir,
        )
        with open(model_config_path, "r") as f:
            model_config = json.load(f)
    except Exception as e:
        logger.warning(f"Could not download or load config.json for {hf_model_id}: {e}")

    return onnxruntime.InferenceSession(onnx_model_path), preprocessor_config, model_config

def get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_func):
    """Return the cached (session, preprocessor_config, model_config) tuple, loading it on first use."""
    if hf_model_id not in _onnx_model_cache:
        logger.info(f"Loading ONNX model and preprocessor for {hf_model_id}...")
        _onnx_model_cache[hf_model_id] = load_func(hf_model_id)
    return _onnx_model_cache[hf_model_id]
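
# Usage sketch: a minimal, self-contained demo of the caching behavior. The
# repo id below is hypothetical, and the stub loader stands in for
# load_onnx_model_and_preprocessor so the demo runs without network access;
# swap in the real loader and a real repo id to download an actual model.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    def _stub_loader(hf_model_id):
        # Fake result in place of (InferenceSession, preprocessor_config, model_config).
        return (f"session-for-{hf_model_id}", {}, {})

    cache = {}
    first = get_onnx_model_from_cache("some-org/some-model", cache, _stub_loader)
    second = get_onnx_model_from_cache("some-org/some-model", cache, _stub_loader)
    assert first is second  # the loader ran once; the cached tuple is reused
    # Real usage (requires network access and a repo shipping onnx/model_quantized.onnx):
    # session, pre_cfg, cfg = get_onnx_model_from_cache(
    #     "some-org/some-model", cache, load_onnx_model_and_preprocessor)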