LPX55
commited on
Commit
·
ceff40b
1
Parent(s):
3232315
fix: add missing import for infer_onnx_model in model_loader and update its definition in onnx_helpers
Browse files- app.py +1 -29
- utils/model_loader.py +1 -1
- utils/onnx_helpers.py +22 -0
app.py
CHANGED
@@ -16,7 +16,7 @@ import ast
|
|
16 |
import torch
|
17 |
|
18 |
from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
|
19 |
-
from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output
|
20 |
from utils.model_loader import register_all_models
|
21 |
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
|
22 |
from forensics.gradient import gradient_processing
|
@@ -93,34 +93,6 @@ def register_model_with_metadata(model_id, model, preprocess, postprocess, class
|
|
93 |
_onnx_model_cache = {}
|
94 |
|
95 |
|
96 |
-
def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
|
97 |
-
try:
|
98 |
-
ort_session, _, _ = get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor)
|
99 |
-
|
100 |
-
# Debug: Print expected input shape from ONNX model
|
101 |
-
for input_meta in ort_session.get_inputs():
|
102 |
-
logger.info(f"Debug: ONNX model expected input name: {input_meta.name}, shape: {input_meta.shape}, type: {input_meta.type}")
|
103 |
-
|
104 |
-
logger.info(f"Debug: preprocessed_image_np shape: {preprocessed_image_np.shape}")
|
105 |
-
ort_inputs = {ort_session.get_inputs()[0].name: preprocessed_image_np}
|
106 |
-
ort_outputs = ort_session.run(None, ort_inputs)
|
107 |
-
|
108 |
-
logits = ort_outputs[0]
|
109 |
-
logger.info(f"Debug: logits type: {type(logits)}, shape: {logits.shape}")
|
110 |
-
# If the model outputs a single logit (e.g., shape (1,)), use .item() to convert to scalar
|
111 |
-
# Otherwise, assume it's a batch of logits (e.g., shape (1, num_classes)) and take the first element (batch dim)
|
112 |
-
# The num_classes in config.json can be misleading; rely on actual output shape.
|
113 |
-
|
114 |
-
# Apply softmax to the logits to get probabilities for the classes
|
115 |
-
# The softmax function in utils/utils.py now ensures a list of floats
|
116 |
-
probabilities = softmax(logits[0]) # Assuming logits[0] is the relevant output for a single prediction
|
117 |
-
|
118 |
-
return {"logits": logits, "probabilities": probabilities}
|
119 |
-
|
120 |
-
except Exception as e:
|
121 |
-
logger.error(f"Error during ONNX inference for {hf_model_id}: {e}")
|
122 |
-
# Return a structure consistent with other model errors
|
123 |
-
return {"logits": np.array([]), "probabilities": np.array([])}
|
124 |
|
125 |
# Register all models (ONNX, HuggingFace, Gradio API)
|
126 |
register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
|
|
|
16 |
import torch
|
17 |
|
18 |
from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
|
19 |
+
from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output, infer_onnx_model
|
20 |
from utils.model_loader import register_all_models
|
21 |
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
|
22 |
from forensics.gradient import gradient_processing
|
|
|
93 |
_onnx_model_cache = {}
|
94 |
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
# Register all models (ONNX, HuggingFace, Gradio API)
|
98 |
register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
|
utils/model_loader.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
# Add missing import for infer_onnx_model
|
2 |
-
from
|
3 |
# Add missing import for preprocess_onnx_input
|
4 |
from utils.onnx_helpers import preprocess_onnx_input
|
5 |
"""
|
|
|
1 |
# Add missing import for infer_onnx_model
|
2 |
+
from utils.onnx_helpers import infer_onnx_model
|
3 |
# Add missing import for preprocess_onnx_input
|
4 |
from utils.onnx_helpers import preprocess_onnx_input
|
5 |
"""
|
utils/onnx_helpers.py
CHANGED
@@ -1,3 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
from torchvision import transforms
|
3 |
from PIL import Image
|
|
|
1 |
+
# Shared ONNX inference function for use by app.py and model_loader.py
|
2 |
+
def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
|
3 |
+
from .onnx_model_loader import get_onnx_model_from_cache, load_onnx_model_and_preprocessor
|
4 |
+
from .utils import softmax
|
5 |
+
import numpy as np
|
6 |
+
import logging
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
_onnx_model_cache = {}
|
9 |
+
try:
|
10 |
+
ort_session, _, _ = get_onnx_model_from_cache(hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor)
|
11 |
+
for input_meta in ort_session.get_inputs():
|
12 |
+
logger.info(f"Debug: ONNX model expected input name: {input_meta.name}, shape: {input_meta.shape}, type: {input_meta.type}")
|
13 |
+
logger.info(f"Debug: preprocessed_image_np shape: {preprocessed_image_np.shape}")
|
14 |
+
ort_inputs = {ort_session.get_inputs()[0].name: preprocessed_image_np}
|
15 |
+
ort_outputs = ort_session.run(None, ort_inputs)
|
16 |
+
logits = ort_outputs[0]
|
17 |
+
logger.info(f"Debug: logits type: {type(logits)}, shape: {logits.shape}")
|
18 |
+
probabilities = softmax(logits[0])
|
19 |
+
return {"logits": logits, "probabilities": probabilities}
|
20 |
+
except Exception as e:
|
21 |
+
logger.error(f"Error during ONNX inference for {hf_model_id}: {e}")
|
22 |
+
return {"logits": np.array([]), "probabilities": np.array([])}
|
23 |
import numpy as np
|
24 |
from torchvision import transforms
|
25 |
from PIL import Image
|