LPX55
committed
Commit 4cb6734
Parent(s): 6555f50

feat: add Gradio API integration and ONNX preprocessing functions

Browse files
- app.py +2 -92
- utils/onnx_helpers.py +45 -0
- utils/utils.py +43 -0
app.py
CHANGED
@@ -15,7 +15,8 @@ import concurrent.futures
 import ast
 import torch
 
-from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar
+from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
+from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output
 from forensics.gradient import gradient_processing
 from forensics.minmax import minmax_process
 from forensics.ela import ELA
@@ -90,48 +91,6 @@ CLASS_NAMES = {
 }
 
 
-def infer_gradio_api(image_path):
-    client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
-    result_dict = client.predict(
-        input_image=handle_file(image_path),
-        api_name="/simple_predict"
-    )
-    logger.info(f"Debug: Raw result_dict from Gradio API (model_8): {result_dict}, type: {type(result_dict)}")
-    # result_dict is already a dictionary, no need for ast.literal_eval
-    fake_probability = result_dict.get('Fake Probability', 0.0)
-    logger.info(f"Debug: Parsed result_dict: {result_dict}, Extracted fake_probability: {fake_probability}")
-    return {"probabilities": np.array([fake_probability])}  # Return as a numpy array with one element
-
-# New preprocess function for Gradio API
-def preprocess_gradio_api(image: Image.Image):
-    # The Gradio API expects a file path, so we need to save the PIL Image to a temporary file.
-    temp_file_path = "./temp_gradio_input.png"
-    image.save(temp_file_path)
-    return temp_file_path
-
-# New postprocess function for Gradio API (adapting postprocess_binary_output)
-def postprocess_gradio_api(gradio_output, class_names):
-    # gradio_output is expected to be a dictionary like {"probabilities": np.array([fake_prob])}
-    probabilities_array = None
-    if isinstance(gradio_output, dict) and "probabilities" in gradio_output:
-        probabilities_array = gradio_output["probabilities"]
-    elif isinstance(gradio_output, np.ndarray):
-        probabilities_array = gradio_output
-    else:
-        logger.warning(f"Unexpected output type for Gradio API post-processing: {type(gradio_output)}. Expected dict with 'probabilities' or numpy.ndarray.")
-        return {class_names[0]: 0.0, class_names[1]: 1.0}
-
-    logger.info(f"Debug: Probabilities array entering postprocess_gradio_api: {probabilities_array}, type: {type(probabilities_array)}, shape: {probabilities_array.shape}")
-
-    if probabilities_array is None or probabilities_array.size == 0:
-        logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
-        return {class_names[0]: 0.0, class_names[1]: 1.0}
-
-    # It should always be a single element array for fake probability
-    fake_prob = float(probabilities_array.item())
-    real_prob = 1.0 - fake_prob
-
-    return {class_names[0]: fake_prob, class_names[1]: real_prob}
 
 def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
     entry = ModelEntry(model, preprocess, postprocess, class_names, display_name=display_name, contributor=contributor, model_path=model_path, architecture=architecture, dataset=dataset)
@@ -178,27 +137,6 @@ def get_onnx_model_from_cache(hf_model_id):
     _onnx_model_cache[hf_model_id] = load_onnx_model_and_preprocessor(hf_model_id)
     return _onnx_model_cache[hf_model_id]
 
-def preprocess_onnx_input(image: Image.Image, preprocessor_config: dict):
-    # Preprocess image for ONNX model based on preprocessor_config
-    if image.mode != 'RGB':
-        image = image.convert('RGB')
-
-    # Get image size and normalization values from preprocessor_config or use defaults
-    # Use 'size' for initial resize and 'crop_size' for center cropping
-    initial_resize_size = preprocessor_config.get('size', {'height': 224, 'width': 224})
-    crop_size = preprocessor_config.get('crop_size', initial_resize_size['height'])
-    mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
-    std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])
-
-    transform = transforms.Compose([
-        transforms.Resize((initial_resize_size['height'], initial_resize_size['width'])),
-        transforms.CenterCrop(crop_size),  # Apply center crop
-        transforms.ToTensor(),
-        transforms.Normalize(mean=mean, std=std),
-    ])
-    input_tensor = transform(image)
-    # ONNX expects numpy array with batch dimension (1, C, H, W)
-    return input_tensor.unsqueeze(0).cpu().numpy()
 
 def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
     try:
@@ -229,34 +167,6 @@ def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
         # Return a structure consistent with other model errors
         return {"logits": np.array([]), "probabilities": np.array([])}
 
-def postprocess_onnx_output(onnx_output, model_config):
-    # Get class names from model_config
-    # Prioritize id2label, then check num_classes, otherwise default
-    class_names_map = model_config.get('id2label')
-    if class_names_map:
-        class_names = [class_names_map[k] for k in sorted(class_names_map.keys())]
-    elif model_config.get('num_classes') == 1:  # Handle models that output a single value (e.g., probability of 'Fake')
-        class_names = ['Fake', 'Real']  # Assume first class is 'Fake' and second 'Real'
-    else:
-        class_names = {0: 'Fake', 1: 'Real'}  # Default to Fake/Real if not found or not 1 class
-        class_names = [class_names[i] for i in sorted(class_names.keys())]
-
-    probabilities = onnx_output.get("probabilities")
-
-    if probabilities is not None:
-        if model_config.get('num_classes') == 1 and len(probabilities) == 2:  # Special handling for single output models
-            # The single output is the probability of the 'Fake' class
-            fake_prob = float(probabilities[0])
-            real_prob = float(probabilities[1])
-            return {class_names[0]: fake_prob, class_names[1]: real_prob}
-        elif len(probabilities) == len(class_names):
-            return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
-        else:
-            logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
-            return {name: 0.0 for name in class_names}
-    else:
-        logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
-        return {name: 0.0 for name in class_names}
 
 # Register the ONNX quantized model
 # Dummy entry for ONNX model to be loaded dynamically
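Not part of the commit, but for context: a minimal sketch of how app.py can now consume the relocated Gradio helpers through its existing registry. It assumes register_model_with_metadata (whose signature appears in the diff above) is in scope; the model_id, display_name, contributor, and model_path strings are illustrative placeholders rather than values taken from this commit.

from utils.utils import infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api

# Hypothetical registration of the Gradio-API-backed entry; all string values
# below are placeholders, and register_model_with_metadata comes from app.py.
register_model_with_metadata(
    model_id="model_8",                  # name referenced in the debug logs above
    model=infer_gradio_api,              # inference delegated to the remote Space
    preprocess=preprocess_gradio_api,    # PIL.Image -> temporary file path
    postprocess=postprocess_gradio_api,  # {"probabilities": ...} -> {"Fake": p, "Real": 1 - p}
    class_names=["Fake", "Real"],
    display_name="OpenSight Community Forensics (remote)",
    contributor="aiwithoutborders-xyz",
    model_path="aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview",
)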
utils/onnx_helpers.py
ADDED
@@ -0,0 +1,45 @@
+import numpy as np
+from torchvision import transforms
+from PIL import Image
+import logging
+
+def preprocess_onnx_input(image, preprocessor_config):
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+    initial_resize_size = preprocessor_config.get('size', {'height': 224, 'width': 224})
+    crop_size = preprocessor_config.get('crop_size', initial_resize_size['height'])
+    mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
+    std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])
+    transform = transforms.Compose([
+        transforms.Resize((initial_resize_size['height'], initial_resize_size['width'])),
+        transforms.CenterCrop(crop_size),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=mean, std=std),
+    ])
+    input_tensor = transform(image)
+    return input_tensor.unsqueeze(0).cpu().numpy()
+
+def postprocess_onnx_output(onnx_output, model_config):
+    logger = logging.getLogger(__name__)
+    class_names_map = model_config.get('id2label')
+    if class_names_map:
+        class_names = [class_names_map[k] for k in sorted(class_names_map.keys())]
+    elif model_config.get('num_classes') == 1:
+        class_names = ['Fake', 'Real']
+    else:
+        class_names = {0: 'Fake', 1: 'Real'}
+        class_names = [class_names[i] for i in sorted(class_names.keys())]
+    probabilities = onnx_output.get("probabilities")
+    if probabilities is not None:
+        if model_config.get('num_classes') == 1 and len(probabilities) == 2:
+            fake_prob = float(probabilities[0])
+            real_prob = float(probabilities[1])
+            return {class_names[0]: fake_prob, class_names[1]: real_prob}
+        elif len(probabilities) == len(class_names):
+            return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
+        else:
+            logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
+            return {name: 0.0 for name in class_names}
+    else:
+        logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
+        return {name: 0.0 for name in class_names}
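The new module imports only numpy, torchvision, PIL, and logging, so it can be exercised on its own. Below is a minimal, hedged usage sketch showing how the two helpers could bracket an onnxruntime session; the model path, image path, and the two config dictionaries are illustrative placeholders (in app.py they presumably come from load_onnx_model_and_preprocessor), and the explicit softmax is just one plausible way to turn logits into the probabilities that postprocess_onnx_output expects.

import numpy as np
import onnxruntime as ort
from PIL import Image

from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output

# Illustrative placeholders; real values would come from the model repo's
# preprocessor_config.json / config.json.
preprocessor_config = {"size": {"height": 256, "width": 256}, "crop_size": 224}
model_config = {"id2label": {0: "Fake", 1: "Real"}}

session = ort.InferenceSession("model_quantized.onnx")        # placeholder local path
input_name = session.get_inputs()[0].name

image = Image.open("sample.jpg")                              # placeholder input image
input_np = preprocess_onnx_input(image, preprocessor_config)  # float32, shape (1, 3, 224, 224)
logits = session.run(None, {input_name: input_np})[0][0]      # first output, batch element 0

# Softmax the logits so postprocess_onnx_output receives probabilities.
probabilities = np.exp(logits - logits.max())
probabilities /= probabilities.sum()

scores = postprocess_onnx_output({"logits": logits, "probabilities": probabilities}, model_config)
print(scores)  # e.g. {"Fake": 0.97, "Real": 0.03}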
utils/utils.py
CHANGED
@@ -1,3 +1,46 @@
+def infer_gradio_api(image_path):
+    from gradio_client import Client, handle_file
+    import numpy as np
+    import logging
+    logger = logging.getLogger(__name__)
+    client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
+    result_dict = client.predict(
+        input_image=handle_file(image_path),
+        api_name="/simple_predict"
+    )
+    logger.info(f"Debug: Raw result_dict from Gradio API (model_8): {result_dict}, type: {type(result_dict)}")
+    fake_probability = result_dict.get('Fake Probability', 0.0)
+    logger.info(f"Debug: Parsed result_dict: {result_dict}, Extracted fake_probability: {fake_probability}")
+    return {"probabilities": np.array([fake_probability])}
+
+def preprocess_gradio_api(image):
+    temp_file_path = "./temp_gradio_input.png"
+    image.save(temp_file_path)
+    return temp_file_path
+
+def postprocess_gradio_api(gradio_output, class_names):
+    import numpy as np
+    import logging
+    logger = logging.getLogger(__name__)
+    probabilities_array = None
+    if isinstance(gradio_output, dict) and "probabilities" in gradio_output:
+        probabilities_array = gradio_output["probabilities"]
+    elif isinstance(gradio_output, np.ndarray):
+        probabilities_array = gradio_output
+    else:
+        logger.warning(f"Unexpected output type for Gradio API post-processing: {type(gradio_output)}. Expected dict with 'probabilities' or numpy.ndarray.")
+        return {class_names[0]: 0.0, class_names[1]: 1.0}
+
+    logger.info(f"Debug: Probabilities array entering postprocess_gradio_api: {probabilities_array}, type: {type(probabilities_array)}, shape: {getattr(probabilities_array, 'shape', None)}")
+
+    if probabilities_array is None or probabilities_array.size == 0:
+        logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
+        return {class_names[0]: 0.0, class_names[1]: 1.0}
+
+    fake_prob = float(probabilities_array.item())
+    real_prob = 1.0 - fake_prob
+    return {class_names[0]: fake_prob, class_names[1]: real_prob}
+
 def preprocess_resize_256(image):
     if image.mode != 'RGB':
         image = image.convert('RGB')
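For completeness, a short sketch of the full remote-inference path these three relocated functions form, as a caller outside app.py might exercise it. "sample.jpg" is a placeholder input, and the call requires network access to the aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview Space.

from PIL import Image

from utils.utils import infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api

image = Image.open("sample.jpg")                        # placeholder input
image_path = preprocess_gradio_api(image)               # writes ./temp_gradio_input.png
raw = infer_gradio_api(image_path)                      # {"probabilities": np.array([fake_prob])}
scores = postprocess_gradio_api(raw, ["Fake", "Real"])
print(scores)                                           # e.g. {"Fake": 0.12, "Real": 0.88}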