LPX commited on
Commit
ac9c2b2
·
1 Parent(s): 932e7b4

feat: enhance model registration with metadata

Browse files

- Introduced a new function `register_model_with_metadata` to include display name, contributor, and model path in model entries.
- Updated model registration calls for all models to utilize the new function, enhancing the metadata associated with each model.
- Modified `ModelEntry` class in `registry.py` to accommodate additional metadata fields.

Files changed (2) hide show
  1. app_mcp.py +63 -55
  2. forensics/registry.py +6 -2
app_mcp.py CHANGED
@@ -16,7 +16,7 @@ from utils.utils import softmax, augment_image, convert_pil_to_bytes
16
  from utils.gradient import gradient_processing
17
  from utils.minmax import preprocess as minmax_preprocess
18
  from utils.ela import genELA as ELA
19
- from forensics.registry import register_model, MODEL_REGISTRY
20
 
21
 
22
  # Configure logging
@@ -107,25 +107,30 @@ def postprocess_logits(outputs, class_names):
107
  probabilities = softmax(logits)
108
  return {class_names[i]: probabilities[i] for i in range(len(class_names))}
109
 
 
 
 
 
 
 
 
 
 
 
 
110
  # Load and register models (example for two models)
111
  image_processor_1 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_1"], use_fast=True)
112
  model_1 = Swinv2ForImageClassification.from_pretrained(MODEL_PATHS["model_1"]).to(device)
113
  clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
114
- register_model(
115
- "model_1",
116
- clf_1,
117
- preprocess_resize_256,
118
- postprocess_pipeline,
119
- CLASS_NAMES["model_1"]
120
  )
121
 
122
  clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
123
- register_model(
124
- "model_2",
125
- clf_2,
126
- preprocess_resize_224,
127
- postprocess_pipeline,
128
- CLASS_NAMES["model_2"]
129
  )
130
 
131
  # Register remaining models
@@ -144,12 +149,9 @@ def model3_infer(image):
144
  with torch.no_grad():
145
  outputs = model_3(**inputs)
146
  return outputs
147
- register_model(
148
- "model_3",
149
- model3_infer,
150
- preprocess_256,
151
- postprocess_logits_model3,
152
- CLASS_NAMES["model_3"]
153
  )
154
 
155
  feature_extractor_4 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_4"], device=device)
@@ -163,52 +165,37 @@ def postprocess_logits_model4(outputs, class_names):
163
  logits = outputs.logits.cpu().numpy()[0]
164
  probabilities = softmax(logits)
165
  return {class_names[i]: probabilities[i] for i in range(len(class_names))}
166
- register_model(
167
- "model_4",
168
- model4_infer,
169
- preprocess_256,
170
- postprocess_logits_model4,
171
- CLASS_NAMES["model_4"]
172
  )
173
 
174
  clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
175
- register_model(
176
- "model_5",
177
- clf_5,
178
- preprocess_resize_224,
179
- postprocess_pipeline,
180
- CLASS_NAMES["model_5"]
181
  )
182
 
183
  clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
184
- register_model(
185
- "model_5b",
186
- clf_5b,
187
- preprocess_resize_224,
188
- postprocess_pipeline,
189
- CLASS_NAMES["model_5b"]
190
  )
191
 
192
  image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
193
  model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
194
  clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
195
- register_model(
196
- "model_6",
197
- clf_6,
198
- preprocess_resize_224,
199
- postprocess_pipeline,
200
- CLASS_NAMES["model_6"]
201
  )
202
 
203
  image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
204
  model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]).to(device)
205
  clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
206
- register_model(
207
- "model_7",
208
- clf_7,
209
- preprocess_resize_224,
210
- postprocess_pipeline,
211
- CLASS_NAMES["model_7"]
212
  )
213
 
214
  # Generic inference function
@@ -218,11 +205,28 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
218
  img = entry.preprocess(image)
219
  try:
220
  result = entry.model(img)
221
- result = entry.postprocess(result, entry.class_names)
222
- # Add confidence threshold logic if needed
223
- return result
 
 
 
 
 
 
 
 
 
 
224
  except Exception as e:
225
- return {"error": str(e)}
 
 
 
 
 
 
 
226
 
227
  # Update predict_image to use all registered models in order
228
 
@@ -287,8 +291,12 @@ with gr.Blocks(css="#post-gallery { overflow: hidden !important;} .grid-wrap{ ov
287
 
288
 
289
  with gr.Column(scale=2):
290
- # Use Gradio-native Dataframe to display results
291
- results_table = gr.Dataframe(label="Model Predictions", headers=None, datatype="auto")
 
 
 
 
292
  forensics_gallery = gr.Gallery(label="Post Processed Images", visible=True, columns=[4], rows=[2], container=False, height="auto", object_fit="contain", elem_id="post-gallery")
293
 
294
  outputs = [image_output, forensics_gallery, results_table]
 
16
  from utils.gradient import gradient_processing
17
  from utils.minmax import preprocess as minmax_preprocess
18
  from utils.ela import genELA as ELA
19
+ from forensics.registry import register_model, MODEL_REGISTRY, ModelEntry
20
 
21
 
22
  # Configure logging
 
107
  probabilities = softmax(logits)
108
  return {class_names[i]: probabilities[i] for i in range(len(class_names))}
109
 
110
+ # Expand ModelEntry to include metadata
111
+ # (Assume ModelEntry is updated in registry.py to accept display_name, contributor, model_path)
112
+ # If not, we will update registry.py accordingly after this.
113
+
114
+ def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path):
115
+ entry = ModelEntry(model, preprocess, postprocess, class_names)
116
+ entry.display_name = display_name
117
+ entry.contributor = contributor
118
+ entry.model_path = model_path
119
+ MODEL_REGISTRY[model_id] = entry
120
+
121
  # Load and register models (example for two models)
122
  image_processor_1 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_1"], use_fast=True)
123
  model_1 = Swinv2ForImageClassification.from_pretrained(MODEL_PATHS["model_1"]).to(device)
124
  clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
125
+ register_model_with_metadata(
126
+ "model_1", clf_1, preprocess_resize_256, postprocess_pipeline, CLASS_NAMES["model_1"],
127
+ display_name="SwinV2 Based", contributor="haywoodsloan", model_path=MODEL_PATHS["model_1"]
 
 
 
128
  )
129
 
130
  clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
131
+ register_model_with_metadata(
132
+ "model_2", clf_2, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_2"],
133
+ display_name="ViT Based", contributor="Heem2", model_path=MODEL_PATHS["model_2"]
 
 
 
134
  )
135
 
136
  # Register remaining models
 
149
  with torch.no_grad():
150
  outputs = model_3(**inputs)
151
  return outputs
152
+ register_model_with_metadata(
153
+ "model_3", model3_infer, preprocess_256, postprocess_logits_model3, CLASS_NAMES["model_3"],
154
+ display_name="SDXL Dataset", contributor="Organika", model_path=MODEL_PATHS["model_3"]
 
 
 
155
  )
156
 
157
  feature_extractor_4 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_4"], device=device)
 
165
  logits = outputs.logits.cpu().numpy()[0]
166
  probabilities = softmax(logits)
167
  return {class_names[i]: probabilities[i] for i in range(len(class_names))}
168
+ register_model_with_metadata(
169
+ "model_4", model4_infer, preprocess_256, postprocess_logits_model4, CLASS_NAMES["model_4"],
170
+ display_name="SDXL + FLUX", contributor="cmckinle", model_path=MODEL_PATHS["model_4"]
 
 
 
171
  )
172
 
173
  clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
174
+ register_model_with_metadata(
175
+ "model_5", clf_5, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_5"],
176
+ display_name="Vit Based", contributor="prithivMLmods", model_path=MODEL_PATHS["model_5"]
 
 
 
177
  )
178
 
179
  clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
180
+ register_model_with_metadata(
181
+ "model_5b", clf_5b, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_5b"],
182
+ display_name="Vit Based, Newer Dataset", contributor="prithivMLmods", model_path=MODEL_PATHS["model_5b"]
 
 
 
183
  )
184
 
185
  image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
186
  model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
187
  clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
188
+ register_model_with_metadata(
189
+ "model_6", clf_6, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_6"],
190
+ display_name="Swin, Midj + SDXL", contributor="ideepankarsharma2003", model_path=MODEL_PATHS["model_6"]
 
 
 
191
  )
192
 
193
  image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
194
  model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]).to(device)
195
  clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
196
+ register_model_with_metadata(
197
+ "model_7", clf_7, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_7"],
198
+ display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
 
 
 
199
  )
200
 
201
  # Generic inference function
 
205
  img = entry.preprocess(image)
206
  try:
207
  result = entry.model(img)
208
+ scores = entry.postprocess(result, entry.class_names)
209
+ # Flatten output for Dataframe: include metadata and both class scores
210
+ ai_score = scores.get(entry.class_names[0], 0.0)
211
+ real_score = scores.get(entry.class_names[1], 0.0)
212
+ label = "AI" if ai_score >= confidence_threshold else ("REAL" if real_score >= confidence_threshold else "UNCERTAIN")
213
+ return {
214
+ "Model": entry.display_name,
215
+ "Contributor": entry.contributor,
216
+ "HF Model Path": entry.model_path,
217
+ "AI Score": ai_score,
218
+ "Real Score": real_score,
219
+ "Label": label
220
+ }
221
  except Exception as e:
222
+ return {
223
+ "Model": entry.display_name,
224
+ "Contributor": entry.contributor,
225
+ "HF Model Path": entry.model_path,
226
+ "AI Score": None,
227
+ "Real Score": None,
228
+ "Label": f"Error: {str(e)}"
229
+ }
230
 
231
  # Update predict_image to use all registered models in order
232
 
 
291
 
292
 
293
  with gr.Column(scale=2):
294
+ # Use Gradio-native Dataframe to display results with headers
295
+ results_table = gr.Dataframe(
296
+ label="Model Predictions",
297
+ headers=["Model", "Contributor", "HF Model Path", "AI Score", "Real Score", "Label"],
298
+ datatype=["str", "str", "str", "number", "number", "str"]
299
+ )
300
  forensics_gallery = gr.Gallery(label="Post Processed Images", visible=True, columns=[4], rows=[2], container=False, height="auto", object_fit="contain", elem_id="post-gallery")
301
 
302
  outputs = [image_output, forensics_gallery, results_table]
forensics/registry.py CHANGED
@@ -1,11 +1,15 @@
1
- from typing import Callable, Dict, Any, List
2
 
3
  class ModelEntry:
4
- def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str]):
 
5
  self.model = model
6
  self.preprocess = preprocess
7
  self.postprocess = postprocess
8
  self.class_names = class_names
 
 
 
9
 
10
  MODEL_REGISTRY: Dict[str, ModelEntry] = {}
11
 
 
1
+ from typing import Callable, Dict, Any, List, Optional
2
 
3
  class ModelEntry:
4
+ def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
5
+ display_name: Optional[str] = None, contributor: Optional[str] = None, model_path: Optional[str] = None):
6
  self.model = model
7
  self.preprocess = preprocess
8
  self.postprocess = postprocess
9
  self.class_names = class_names
10
+ self.display_name = display_name
11
+ self.contributor = contributor
12
+ self.model_path = model_path
13
 
14
  MODEL_REGISTRY: Dict[str, ModelEntry] = {}
15