LPX55 committed
Commit 80c3f0c · 1 Parent(s): 4cb6734

refactor: ONNX model loading and caching functions

Files changed (2):
  1. app.py (+1, -36)
  2. utils/onnx_model_loader.py (+33, -0)
app.py CHANGED
@@ -17,6 +17,7 @@ import torch
 
 from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
 from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output
+from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
 from forensics.gradient import gradient_processing
 from forensics.minmax import minmax_process
 from forensics.ela import ELA
@@ -97,46 +98,10 @@ def register_model_with_metadata(model_id, model, preprocess, postprocess, class
     MODEL_REGISTRY[model_id] = entry
 
 
-def load_onnx_model_and_preprocessor(hf_model_id):
-    # model_dir = snapshot_download(repo_id=hf_model_id, local_dir_use_symlinks=False)
-
-    # Create a unique local directory for each ONNX model
-    model_specific_dir = os.path.join("./models", hf_model_id.replace('/', '_'))
-    os.makedirs(model_specific_dir, exist_ok=True)
-
-    # Use hf_hub_download to get specific files into the model-specific directory
-    onnx_model_path = hf_hub_download(repo_id=hf_model_id, filename="model_quantized.onnx", subfolder="onnx", local_dir=model_specific_dir, local_dir_use_symlinks=False)
-
-    # Load preprocessor config
-    preprocessor_config = {}
-    try:
-        preprocessor_config_path = hf_hub_download(repo_id=hf_model_id, filename="preprocessor_config.json", local_dir=model_specific_dir, local_dir_use_symlinks=False)
-        with open(preprocessor_config_path, 'r') as f:
-            preprocessor_config = json.load(f)
-    except Exception as e:
-        logger.warning(f"Could not download or load preprocessor_config.json for {hf_model_id}: {e}")
-
-    # Load model config for class names if available
-    model_config = {}
-    try:
-        model_config_path = hf_hub_download(repo_id=hf_model_id, filename="config.json", local_dir=model_specific_dir, local_dir_use_symlinks=False)
-        with open(model_config_path, 'r') as f:
-            model_config = json.load(f)
-    except Exception as e:
-        logger.warning(f"Could not download or load config.json for {hf_model_id}: {e}")
-
-    return onnxruntime.InferenceSession(onnx_model_path), preprocessor_config, model_config
-
 
 # Cache for ONNX sessions and preprocessors
 _onnx_model_cache = {}
 
-def get_onnx_model_from_cache(hf_model_id):
-    if hf_model_id not in _onnx_model_cache:
-        logger.info(f"Loading ONNX model and preprocessor for {hf_model_id}...")
-        _onnx_model_cache[hf_model_id] = load_onnx_model_and_preprocessor(hf_model_id)
-    return _onnx_model_cache[hf_model_id]
-
 
 def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
     try:
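Note that the extracted get_onnx_model_from_cache takes the cache dict and the loader function as explicit arguments instead of closing over app.py's globals, so the call site inside infer_onnx_model (not shown in the hunks above) presumably changes along these lines. A minimal sketch, assuming the returned (session, preprocessor_config, model_config) triple is unpacked as before:

    # Sketch of the assumed call-site update in app.py (not shown in the hunk above):
    # the module-level cache dict and the loader are now passed in explicitly.
    session, preprocessor_config, model_config = get_onnx_model_from_cache(
        hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
    )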
 
utils/onnx_model_loader.py ADDED
@@ -0,0 +1,33 @@
+import os
+import json
+import onnxruntime
+from huggingface_hub import hf_hub_download
+import logging
+
+def load_onnx_model_and_preprocessor(hf_model_id):
+    model_specific_dir = os.path.join("./models", hf_model_id.replace('/', '_'))
+    os.makedirs(model_specific_dir, exist_ok=True)
+    onnx_model_path = hf_hub_download(repo_id=hf_model_id, filename="model_quantized.onnx", subfolder="onnx", local_dir=model_specific_dir, local_dir_use_symlinks=False)
+    preprocessor_config = {}
+    try:
+        preprocessor_config_path = hf_hub_download(repo_id=hf_model_id, filename="preprocessor_config.json", local_dir=model_specific_dir, local_dir_use_symlinks=False)
+        with open(preprocessor_config_path, 'r') as f:
+            preprocessor_config = json.load(f)
+    except Exception as e:
+        logging.getLogger(__name__).warning(f"Could not download or load preprocessor_config.json for {hf_model_id}: {e}")
+    model_config = {}
+    try:
+        model_config_path = hf_hub_download(repo_id=hf_model_id, filename="config.json", local_dir=model_specific_dir, local_dir_use_symlinks=False)
+        with open(model_config_path, 'r') as f:
+            model_config = json.load(f)
+    except Exception as e:
+        logging.getLogger(__name__).warning(f"Could not download or load config.json for {hf_model_id}: {e}")
+    return onnxruntime.InferenceSession(onnx_model_path), preprocessor_config, model_config
+
+def get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_func):
+    import logging
+    logger = logging.getLogger(__name__)
+    if hf_model_id not in _onnx_model_cache:
+        logger.info(f"Loading ONNX model and preprocessor for {hf_model_id}...")
+        _onnx_model_cache[hf_model_id] = load_func(hf_model_id)
+    return _onnx_model_cache[hf_model_id]
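For reference, a minimal standalone usage sketch of the new module; the repo id "org/model-id" and the caller-owned cache dict are hypothetical:

    # Hypothetical usage of utils/onnx_model_loader.py; "org/model-id" is a placeholder.
    from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache

    cache = {}  # caller-owned cache, keyed by Hugging Face repo id
    session, pre_cfg, model_cfg = get_onnx_model_from_cache(
        "org/model-id", cache, load_onnx_model_and_preprocessor
    )
    # A second call with the same id returns the cached triple without re-downloading.
    session, pre_cfg, model_cfg = get_onnx_model_from_cache(
        "org/model-id", cache, load_onnx_model_and_preprocessor
    )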