LPX55 committed on
Commit ceff40b · 1 Parent(s): 3232315

fix: add missing import for infer_onnx_model in model_loader and update its definition in onnx_helpers

Files changed (3):
  1. app.py +1 -29
  2. utils/model_loader.py +1 -1
  3. utils/onnx_helpers.py +22 -0
app.py CHANGED
@@ -16,7 +16,7 @@ import ast
import torch

from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
- from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output
+ from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output, infer_onnx_model
from utils.model_loader import register_all_models
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from forensics.gradient import gradient_processing
@@ -93,34 +93,6 @@ def register_model_with_metadata(model_id, model, preprocess, postprocess, class
_onnx_model_cache = {}


- def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
-     try:
-         ort_session, _, _ = get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor)
-
-         # Debug: Print expected input shape from ONNX model
-         for input_meta in ort_session.get_inputs():
-             logger.info(f"Debug: ONNX model expected input name: {input_meta.name}, shape: {input_meta.shape}, type: {input_meta.type}")
-
-         logger.info(f"Debug: preprocessed_image_np shape: {preprocessed_image_np.shape}")
-         ort_inputs = {ort_session.get_inputs()[0].name: preprocessed_image_np}
-         ort_outputs = ort_session.run(None, ort_inputs)
-
-         logits = ort_outputs[0]
-         logger.info(f"Debug: logits type: {type(logits)}, shape: {logits.shape}")
-         # If the model outputs a single logit (e.g., shape (1,)), use .item() to convert to scalar
-         # Otherwise, assume it's a batch of logits (e.g., shape (1, num_classes)) and take the first element (batch dim)
-         # The num_classes in config.json can be misleading; rely on actual output shape.
-
-         # Apply softmax to the logits to get probabilities for the classes
-         # The softmax function in utils/utils.py now ensures a list of floats
-         probabilities = softmax(logits[0])  # Assuming logits[0] is the relevant output for a single prediction
-
-         return {"logits": logits, "probabilities": probabilities}
-
-     except Exception as e:
-         logger.error(f"Error during ONNX inference for {hf_model_id}: {e}")
-         # Return a structure consistent with other model errors
-         return {"logits": np.array([]), "probabilities": np.array([])}


# Register all models (ONNX, HuggingFace, Gradio API)
register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
 
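For reference, the dict returned by infer_onnx_model is what the postprocessing path registered above consumes. A minimal sketch of calling it directly; the model id and the (1, 3, 224, 224) float32 input shape are assumptions for illustration, not values taken from this repo:

    import numpy as np
    from utils.onnx_helpers import infer_onnx_model

    # Hypothetical model id and input shape, for illustration only.
    dummy_input = np.random.rand(1, 3, 224, 224).astype(np.float32)
    result = infer_onnx_model("example-org/example-onnx-model", dummy_input, model_config={})
    print(result["logits"].shape)   # raw model output, e.g. (1, num_classes)
    print(result["probabilities"])  # per-class floats after softmax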
utils/model_loader.py CHANGED
@@ -1,5 +1,5 @@
# Add missing import for infer_onnx_model
- from app import infer_onnx_model
+ from utils.onnx_helpers import infer_onnx_model
# Add missing import for preprocess_onnx_input
from utils.onnx_helpers import preprocess_onnx_input
"""
utils/onnx_helpers.py CHANGED
@@ -1,3 +1,25 @@
+ # Shared ONNX inference function for use by app.py and model_loader.py
+ def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
+     from .onnx_model_loader import get_onnx_model_from_cache, load_onnx_model_and_preprocessor
+     from .utils import softmax
+     import numpy as np
+     import logging
+     logger = logging.getLogger(__name__)
+     _onnx_model_cache = {}
+     try:
+         ort_session, _, _ = get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor)
+         for input_meta in ort_session.get_inputs():
+             logger.info(f"Debug: ONNX model expected input name: {input_meta.name}, shape: {input_meta.shape}, type: {input_meta.type}")
+         logger.info(f"Debug: preprocessed_image_np shape: {preprocessed_image_np.shape}")
+         ort_inputs = {ort_session.get_inputs()[0].name: preprocessed_image_np}
+         ort_outputs = ort_session.run(None, ort_inputs)
+         logits = ort_outputs[0]
+         logger.info(f"Debug: logits type: {type(logits)}, shape: {logits.shape}")
+         probabilities = softmax(logits[0])
+         return {"logits": logits, "probabilities": probabilities}
+     except Exception as e:
+         logger.error(f"Error during ONNX inference for {hf_model_id}: {e}")
+         return {"logits": np.array([]), "probabilities": np.array([])}
import numpy as np
from torchvision import transforms
from PIL import Image
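One caveat with the moved function: _onnx_model_cache = {} now lives inside infer_onnx_model, so the cache is recreated empty on every call and each inference reloads the model. In app.py the dict sat at module scope (line 93 in the diff above), which is what made get_onnx_model_from_cache effective. A minimal sketch of restoring that behavior in onnx_helpers.py, assuming persistent caching is still the intent:

    import logging
    import numpy as np

    logger = logging.getLogger(__name__)
    _onnx_model_cache = {}  # module scope: survives across calls

    def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
        # Imported lazily to avoid import cycles, as in the committed version.
        from .onnx_model_loader import get_onnx_model_from_cache, load_onnx_model_and_preprocessor
        from .utils import softmax
        try:
            ort_session, _, _ = get_onnx_model_from_cache(
                hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
            )
            ort_inputs = {ort_session.get_inputs()[0].name: preprocessed_image_np}
            logits = ort_session.run(None, ort_inputs)[0]
            return {"logits": logits, "probabilities": softmax(logits[0])}
        except Exception as e:
            logger.error(f"Error during ONNX inference for {hf_model_id}: {e}")
            # Return a structure consistent with other model errors.
            return {"logits": np.array([]), "probabilities": np.array([])}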