LPX55 commited on
Commit
4cb6734
·
1 Parent(s): 6555f50

feat: add Gradio API integration and ONNX preprocessing functions

Browse files
Files changed (3) hide show
  1. app.py +2 -92
  2. utils/onnx_helpers.py +45 -0
  3. utils/utils.py +43 -0
app.py CHANGED
@@ -15,7 +15,8 @@ import concurrent.futures
15
  import ast
16
  import torch
17
 
18
- from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar
 
19
  from forensics.gradient import gradient_processing
20
  from forensics.minmax import minmax_process
21
  from forensics.ela import ELA
@@ -90,48 +91,6 @@ CLASS_NAMES = {
90
  }
91
 
92
 
93
- def infer_gradio_api(image_path):
94
- client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
95
- result_dict = client.predict(
96
- input_image=handle_file(image_path),
97
- api_name="/simple_predict"
98
- )
99
- logger.info(f"Debug: Raw result_dict from Gradio API (model_8): {result_dict}, type: {type(result_dict)}")
100
- # result_dict is already a dictionary, no need for ast.literal_eval
101
- fake_probability = result_dict.get('Fake Probability', 0.0)
102
- logger.info(f"Debug: Parsed result_dict: {result_dict}, Extracted fake_probability: {fake_probability}")
103
- return {"probabilities": np.array([fake_probability])} # Return as a numpy array with one element
104
-
105
- # New preprocess function for Gradio API
106
- def preprocess_gradio_api(image: Image.Image):
107
- # The Gradio API expects a file path, so we need to save the PIL Image to a temporary file.
108
- temp_file_path = "./temp_gradio_input.png"
109
- image.save(temp_file_path)
110
- return temp_file_path
111
-
112
- # New postprocess function for Gradio API (adapting postprocess_binary_output)
113
- def postprocess_gradio_api(gradio_output, class_names):
114
- # gradio_output is expected to be a dictionary like {"probabilities": np.array([fake_prob])}
115
- probabilities_array = None
116
- if isinstance(gradio_output, dict) and "probabilities" in gradio_output:
117
- probabilities_array = gradio_output["probabilities"]
118
- elif isinstance(gradio_output, np.ndarray):
119
- probabilities_array = gradio_output
120
- else:
121
- logger.warning(f"Unexpected output type for Gradio API post-processing: {type(gradio_output)}. Expected dict with 'probabilities' or numpy.ndarray.")
122
- return {class_names[0]: 0.0, class_names[1]: 1.0}
123
-
124
- logger.info(f"Debug: Probabilities array entering postprocess_gradio_api: {probabilities_array}, type: {type(probabilities_array)}, shape: {probabilities_array.shape}")
125
-
126
- if probabilities_array is None or probabilities_array.size == 0:
127
- logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
128
- return {class_names[0]: 0.0, class_names[1]: 1.0}
129
-
130
- # It should always be a single element array for fake probability
131
- fake_prob = float(probabilities_array.item())
132
- real_prob = 1.0 - fake_prob
133
-
134
- return {class_names[0]: fake_prob, class_names[1]: real_prob}
135
 
136
  def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
137
  entry = ModelEntry(model, preprocess, postprocess, class_names, display_name=display_name, contributor=contributor, model_path=model_path, architecture=architecture, dataset=dataset)
@@ -178,27 +137,6 @@ def get_onnx_model_from_cache(hf_model_id):
178
  _onnx_model_cache[hf_model_id] = load_onnx_model_and_preprocessor(hf_model_id)
179
  return _onnx_model_cache[hf_model_id]
180
 
181
- def preprocess_onnx_input(image: Image.Image, preprocessor_config: dict):
182
- # Preprocess image for ONNX model based on preprocessor_config
183
- if image.mode != 'RGB':
184
- image = image.convert('RGB')
185
-
186
- # Get image size and normalization values from preprocessor_config or use defaults
187
- # Use 'size' for initial resize and 'crop_size' for center cropping
188
- initial_resize_size = preprocessor_config.get('size', {'height': 224, 'width': 224})
189
- crop_size = preprocessor_config.get('crop_size', initial_resize_size['height'])
190
- mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
191
- std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])
192
-
193
- transform = transforms.Compose([
194
- transforms.Resize((initial_resize_size['height'], initial_resize_size['width'])),
195
- transforms.CenterCrop(crop_size), # Apply center crop
196
- transforms.ToTensor(),
197
- transforms.Normalize(mean=mean, std=std),
198
- ])
199
- input_tensor = transform(image)
200
- # ONNX expects numpy array with batch dimension (1, C, H, W)
201
- return input_tensor.unsqueeze(0).cpu().numpy()
202
 
203
  def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
204
  try:
@@ -229,34 +167,6 @@ def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
229
  # Return a structure consistent with other model errors
230
  return {"logits": np.array([]), "probabilities": np.array([])}
231
 
232
- def postprocess_onnx_output(onnx_output, model_config):
233
- # Get class names from model_config
234
- # Prioritize id2label, then check num_classes, otherwise default
235
- class_names_map = model_config.get('id2label')
236
- if class_names_map:
237
- class_names = [class_names_map[k] for k in sorted(class_names_map.keys())]
238
- elif model_config.get('num_classes') == 1: # Handle models that output a single value (e.g., probability of 'Fake')
239
- class_names = ['Fake', 'Real'] # Assume first class is 'Fake' and second 'Real'
240
- else:
241
- class_names = {0: 'Fake', 1: 'Real'} # Default to Fake/Real if not found or not 1 class
242
- class_names = [class_names[i] for i in sorted(class_names.keys())]
243
-
244
- probabilities = onnx_output.get("probabilities")
245
-
246
- if probabilities is not None:
247
- if model_config.get('num_classes') == 1 and len(probabilities) == 2: # Special handling for single output models
248
- # The single output is the probability of the 'Fake' class
249
- fake_prob = float(probabilities[0])
250
- real_prob = float(probabilities[1])
251
- return {class_names[0]: fake_prob, class_names[1]: real_prob}
252
- elif len(probabilities) == len(class_names):
253
- return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
254
- else:
255
- logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
256
- return {name: 0.0 for name in class_names}
257
- else:
258
- logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
259
- return {name: 0.0 for name in class_names}
260
 
261
  # Register the ONNX quantized model
262
  # Dummy entry for ONNX model to be loaded dynamically
 
15
  import ast
16
  import torch
17
 
18
+ from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
19
+ from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output
20
  from forensics.gradient import gradient_processing
21
  from forensics.minmax import minmax_process
22
  from forensics.ela import ELA
 
91
  }
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
96
  entry = ModelEntry(model, preprocess, postprocess, class_names, display_name=display_name, contributor=contributor, model_path=model_path, architecture=architecture, dataset=dataset)
 
137
  _onnx_model_cache[hf_model_id] = load_onnx_model_and_preprocessor(hf_model_id)
138
  return _onnx_model_cache[hf_model_id]
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
142
  try:
 
167
  # Return a structure consistent with other model errors
168
  return {"logits": np.array([]), "probabilities": np.array([])}
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  # Register the ONNX quantized model
172
  # Dummy entry for ONNX model to be loaded dynamically
utils/onnx_helpers.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import logging
5
+
6
+ def preprocess_onnx_input(image, preprocessor_config):
7
+ if image.mode != 'RGB':
8
+ image = image.convert('RGB')
9
+ initial_resize_size = preprocessor_config.get('size', {'height': 224, 'width': 224})
10
+ crop_size = preprocessor_config.get('crop_size', initial_resize_size['height'])
11
+ mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
12
+ std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])
13
+ transform = transforms.Compose([
14
+ transforms.Resize((initial_resize_size['height'], initial_resize_size['width'])),
15
+ transforms.CenterCrop(crop_size),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=mean, std=std),
18
+ ])
19
+ input_tensor = transform(image)
20
+ return input_tensor.unsqueeze(0).cpu().numpy()
21
+
22
+ def postprocess_onnx_output(onnx_output, model_config):
23
+ logger = logging.getLogger(__name__)
24
+ class_names_map = model_config.get('id2label')
25
+ if class_names_map:
26
+ class_names = [class_names_map[k] for k in sorted(class_names_map.keys())]
27
+ elif model_config.get('num_classes') == 1:
28
+ class_names = ['Fake', 'Real']
29
+ else:
30
+ class_names = {0: 'Fake', 1: 'Real'}
31
+ class_names = [class_names[i] for i in sorted(class_names.keys())]
32
+ probabilities = onnx_output.get("probabilities")
33
+ if probabilities is not None:
34
+ if model_config.get('num_classes') == 1 and len(probabilities) == 2:
35
+ fake_prob = float(probabilities[0])
36
+ real_prob = float(probabilities[1])
37
+ return {class_names[0]: fake_prob, class_names[1]: real_prob}
38
+ elif len(probabilities) == len(class_names):
39
+ return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
40
+ else:
41
+ logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
42
+ return {name: 0.0 for name in class_names}
43
+ else:
44
+ logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
45
+ return {name: 0.0 for name in class_names}
utils/utils.py CHANGED
@@ -1,3 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def preprocess_resize_256(image):
2
  if image.mode != 'RGB':
3
  image = image.convert('RGB')
 
1
+ def infer_gradio_api(image_path):
2
+ from gradio_client import Client, handle_file
3
+ import numpy as np
4
+ import logging
5
+ logger = logging.getLogger(__name__)
6
+ client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
7
+ result_dict = client.predict(
8
+ input_image=handle_file(image_path),
9
+ api_name="/simple_predict"
10
+ )
11
+ logger.info(f"Debug: Raw result_dict from Gradio API (model_8): {result_dict}, type: {type(result_dict)}")
12
+ fake_probability = result_dict.get('Fake Probability', 0.0)
13
+ logger.info(f"Debug: Parsed result_dict: {result_dict}, Extracted fake_probability: {fake_probability}")
14
+ return {"probabilities": np.array([fake_probability])}
15
+
16
+ def preprocess_gradio_api(image):
17
+ temp_file_path = "./temp_gradio_input.png"
18
+ image.save(temp_file_path)
19
+ return temp_file_path
20
+
21
+ def postprocess_gradio_api(gradio_output, class_names):
22
+ import numpy as np
23
+ import logging
24
+ logger = logging.getLogger(__name__)
25
+ probabilities_array = None
26
+ if isinstance(gradio_output, dict) and "probabilities" in gradio_output:
27
+ probabilities_array = gradio_output["probabilities"]
28
+ elif isinstance(gradio_output, np.ndarray):
29
+ probabilities_array = gradio_output
30
+ else:
31
+ logger.warning(f"Unexpected output type for Gradio API post-processing: {type(gradio_output)}. Expected dict with 'probabilities' or numpy.ndarray.")
32
+ return {class_names[0]: 0.0, class_names[1]: 1.0}
33
+
34
+ logger.info(f"Debug: Probabilities array entering postprocess_gradio_api: {probabilities_array}, type: {type(probabilities_array)}, shape: {getattr(probabilities_array, 'shape', None)}")
35
+
36
+ if probabilities_array is None or probabilities_array.size == 0:
37
+ logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
38
+ return {class_names[0]: 0.0, class_names[1]: 1.0}
39
+
40
+ fake_prob = float(probabilities_array.item())
41
+ real_prob = 1.0 - fake_prob
42
+ return {class_names[0]: fake_prob, class_names[1]: real_prob}
43
+
44
  def preprocess_resize_256(image):
45
  if image.mode != 'RGB':
46
  image = image.convert('RGB')