LPX committed · Commit ac9c2b2 · 1 Parent(s): 932e7b4

feat: enhance model registration with metadata

- Introduced a new function `register_model_with_metadata` to include display name, contributor, and model path in model entries.
- Updated model registration calls for all models to utilize the new function, enhancing the metadata associated with each model.
- Modified `ModelEntry` class in `registry.py` to accommodate additional metadata fields.
Files changed:
- app_mcp.py +63 -55
- forensics/registry.py +6 -2
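For orientation before the diffs, here is a minimal, self-contained sketch of the registration pattern this commit introduces. It mirrors the code shown in the diffs below but is not the app itself: `dummy_clf`, `demo_model`, `"org/demo-model"`, and the toy preprocess/postprocess lambdas are placeholders, and the simplified `ModelEntry`/`MODEL_REGISTRY` stand in for the real ones in forensics/registry.py.

```python
from typing import Any, Callable, Dict, List, Optional

# Simplified stand-ins for forensics/registry.py (see the registry.py diff below).
class ModelEntry:
    def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
                 display_name: Optional[str] = None, contributor: Optional[str] = None, model_path: Optional[str] = None):
        self.model = model
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.class_names = class_names
        self.display_name = display_name
        self.contributor = contributor
        self.model_path = model_path

MODEL_REGISTRY: Dict[str, ModelEntry] = {}

def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names,
                                 display_name, contributor, model_path):
    # Build the entry, attach the extra metadata, and store it under its id.
    entry = ModelEntry(model, preprocess, postprocess, class_names)
    entry.display_name = display_name
    entry.contributor = contributor
    entry.model_path = model_path
    MODEL_REGISTRY[model_id] = entry

# Placeholder "model": any callable that maps a preprocessed input to a raw prediction.
dummy_clf = lambda img: [{"label": "AI", "score": 0.9}, {"label": "REAL", "score": 0.1}]

register_model_with_metadata(
    "demo_model", dummy_clf,
    preprocess=lambda img: img,
    postprocess=lambda result, names: {r["label"]: r["score"] for r in result},
    class_names=["AI", "REAL"],
    display_name="Demo", contributor="example", model_path="org/demo-model",  # placeholder metadata
)

print(MODEL_REGISTRY["demo_model"].display_name)  # -> "Demo"
```

Keeping the metadata on the registry entry (rather than in a separate lookup table) is what lets `infer` later emit one flat dict per model that already contains the display name, contributor, and model path alongside the scores.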
app_mcp.py CHANGED

@@ -16,7 +16,7 @@ from utils.utils import softmax, augment_image, convert_pil_to_bytes
 from utils.gradient import gradient_processing
 from utils.minmax import preprocess as minmax_preprocess
 from utils.ela import genELA as ELA
-from forensics.registry import register_model, MODEL_REGISTRY
+from forensics.registry import register_model, MODEL_REGISTRY, ModelEntry
 
 
 # Configure logging

@@ -107,25 +107,30 @@ def postprocess_logits(outputs, class_names):
     probabilities = softmax(logits)
     return {class_names[i]: probabilities[i] for i in range(len(class_names))}
 
+# Expand ModelEntry to include metadata
+# (Assume ModelEntry is updated in registry.py to accept display_name, contributor, model_path)
+# If not, we will update registry.py accordingly after this.
+
+def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path):
+    entry = ModelEntry(model, preprocess, postprocess, class_names)
+    entry.display_name = display_name
+    entry.contributor = contributor
+    entry.model_path = model_path
+    MODEL_REGISTRY[model_id] = entry
+
 # Load and register models (example for two models)
 image_processor_1 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_1"], use_fast=True)
 model_1 = Swinv2ForImageClassification.from_pretrained(MODEL_PATHS["model_1"]).to(device)
 clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
-register_model(
-    "model_1",
-    clf_1,
-    preprocess_resize_256,
-    postprocess_pipeline,
-    CLASS_NAMES["model_1"]
+register_model_with_metadata(
+    "model_1", clf_1, preprocess_resize_256, postprocess_pipeline, CLASS_NAMES["model_1"],
+    display_name="SwinV2 Based", contributor="haywoodsloan", model_path=MODEL_PATHS["model_1"]
 )
 
 clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
-register_model(
-    "model_2",
-    clf_2,
-    preprocess_resize_224,
-    postprocess_pipeline,
-    CLASS_NAMES["model_2"]
+register_model_with_metadata(
+    "model_2", clf_2, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_2"],
+    display_name="ViT Based", contributor="Heem2", model_path=MODEL_PATHS["model_2"]
 )
 
 # Register remaining models

@@ -144,12 +149,9 @@ def model3_infer(image):
     with torch.no_grad():
         outputs = model_3(**inputs)
     return outputs
-register_model(
-    "model_3",
-    model3_infer,
-    preprocess_256,
-    postprocess_logits_model3,
-    CLASS_NAMES["model_3"]
+register_model_with_metadata(
+    "model_3", model3_infer, preprocess_256, postprocess_logits_model3, CLASS_NAMES["model_3"],
+    display_name="SDXL Dataset", contributor="Organika", model_path=MODEL_PATHS["model_3"]
 )
 
 feature_extractor_4 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_4"], device=device)

@@ -163,52 +165,37 @@ def postprocess_logits_model4(outputs, class_names):
     logits = outputs.logits.cpu().numpy()[0]
     probabilities = softmax(logits)
     return {class_names[i]: probabilities[i] for i in range(len(class_names))}
-register_model(
-    "model_4",
-    model4_infer,
-    preprocess_256,
-    postprocess_logits_model4,
-    CLASS_NAMES["model_4"]
+register_model_with_metadata(
+    "model_4", model4_infer, preprocess_256, postprocess_logits_model4, CLASS_NAMES["model_4"],
+    display_name="SDXL + FLUX", contributor="cmckinle", model_path=MODEL_PATHS["model_4"]
 )
 
 clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
-register_model(
-    "model_5",
-    clf_5,
-    preprocess_resize_224,
-    postprocess_pipeline,
-    CLASS_NAMES["model_5"]
+register_model_with_metadata(
+    "model_5", clf_5, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_5"],
+    display_name="Vit Based", contributor="prithivMLmods", model_path=MODEL_PATHS["model_5"]
 )
 
 clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
-register_model(
-    "model_5b",
-    clf_5b,
-    preprocess_resize_224,
-    postprocess_pipeline,
-    CLASS_NAMES["model_5b"]
+register_model_with_metadata(
+    "model_5b", clf_5b, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_5b"],
+    display_name="Vit Based, Newer Dataset", contributor="prithivMLmods", model_path=MODEL_PATHS["model_5b"]
 )
 
 image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
 model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
 clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
-register_model(
-    "model_6",
-    clf_6,
-    preprocess_resize_224,
-    postprocess_pipeline,
-    CLASS_NAMES["model_6"]
+register_model_with_metadata(
+    "model_6", clf_6, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_6"],
+    display_name="Swin, Midj + SDXL", contributor="ideepankarsharma2003", model_path=MODEL_PATHS["model_6"]
 )
 
 image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
 model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]).to(device)
 clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
-register_model(
-    "model_7",
-    clf_7,
-    preprocess_resize_224,
-    postprocess_pipeline,
-    CLASS_NAMES["model_7"]
+register_model_with_metadata(
+    "model_7", clf_7, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_7"],
+    display_name="ViT", contributor="temp", model_path=MODEL_PATHS["model_7"]
 )
 
 # Generic inference function

@@ -218,11 +205,28 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
     img = entry.preprocess(image)
     try:
         result = entry.model(img)
-        …
-        # …
-        …
+        scores = entry.postprocess(result, entry.class_names)
+        # Flatten output for Dataframe: include metadata and both class scores
+        ai_score = scores.get(entry.class_names[0], 0.0)
+        real_score = scores.get(entry.class_names[1], 0.0)
+        label = "AI" if ai_score >= confidence_threshold else ("REAL" if real_score >= confidence_threshold else "UNCERTAIN")
+        return {
+            "Model": entry.display_name,
+            "Contributor": entry.contributor,
+            "HF Model Path": entry.model_path,
+            "AI Score": ai_score,
+            "Real Score": real_score,
+            "Label": label
+        }
     except Exception as e:
-        return { …
+        return {
+            "Model": entry.display_name,
+            "Contributor": entry.contributor,
+            "HF Model Path": entry.model_path,
+            "AI Score": None,
+            "Real Score": None,
+            "Label": f"Error: {str(e)}"
+        }
 
 # Update predict_image to use all registered models in order
 

@@ -287,8 +291,12 @@ with gr.Blocks(css="#post-gallery { overflow: hidden !important;} .grid-wrap{ ov
 
 
     with gr.Column(scale=2):
-        # Use Gradio-native Dataframe to display results
-        results_table = gr.Dataframe(…
+        # Use Gradio-native Dataframe to display results with headers
+        results_table = gr.Dataframe(
+            label="Model Predictions",
+            headers=["Model", "Contributor", "HF Model Path", "AI Score", "Real Score", "Label"],
+            datatype=["str", "str", "str", "number", "number", "str"]
+        )
         forensics_gallery = gr.Gallery(label="Post Processed Images", visible=True, columns=[4], rows=[2], container=False, height="auto", object_fit="contain", elem_id="post-gallery")
 
         outputs = [image_output, forensics_gallery, results_table]
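The diff does not show `predict_image`, so the following is only a hedged sketch of how the flat dicts returned by the new `infer` could be collected into rows matching the `gr.Dataframe` headers above. `FAKE_REGISTRY`, `fake_infer`, `collect_rows`, and the placeholder model path are illustrations, not code from app_mcp.py; the real app presumably iterates over `MODEL_REGISTRY` and calls the actual `infer`.

```python
from typing import Dict, List

# Stand-in for the registered entries: metadata only, no real models.
# Display name and contributor values are taken from the diff; the model path is a placeholder.
FAKE_REGISTRY = {
    "model_1": {"display_name": "SwinV2 Based", "contributor": "haywoodsloan", "model_path": "org/model-1"},
}

HEADERS = ["Model", "Contributor", "HF Model Path", "AI Score", "Real Score", "Label"]

def fake_infer(image, model_id: str, confidence_threshold: float = 0.75) -> Dict:
    # Stand-in for app_mcp.py's infer(): one flat dict per model, same keys as HEADERS.
    meta = FAKE_REGISTRY[model_id]
    ai_score, real_score = 0.91, 0.09  # pretend scores
    label = "AI" if ai_score >= confidence_threshold else ("REAL" if real_score >= confidence_threshold else "UNCERTAIN")
    return {
        "Model": meta["display_name"],
        "Contributor": meta["contributor"],
        "HF Model Path": meta["model_path"],
        "AI Score": ai_score,
        "Real Score": real_score,
        "Label": label,
    }

def collect_rows(image) -> List[List]:
    # One list per model, column order matching the Dataframe headers;
    # gr.Dataframe accepts a list of lists as its value.
    rows = []
    for model_id in FAKE_REGISTRY:
        result = fake_infer(image, model_id)
        rows.append([result[h] for h in HEADERS])
    return rows

print(collect_rows(None))
```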
forensics/registry.py CHANGED

@@ -1,11 +1,15 @@
-from typing import Callable, Dict, Any, List
+from typing import Callable, Dict, Any, List, Optional
 
 class ModelEntry:
-    def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str]):
+    def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
+                 display_name: Optional[str] = None, contributor: Optional[str] = None, model_path: Optional[str] = None):
         self.model = model
         self.preprocess = preprocess
         self.postprocess = postprocess
         self.class_names = class_names
+        self.display_name = display_name
+        self.contributor = contributor
+        self.model_path = model_path
 
 MODEL_REGISTRY: Dict[str, ModelEntry] = {}
 
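A short illustration of why the new fields are `Optional`: existing call sites that build a `ModelEntry` without metadata keep working (the fields default to `None`), while callers following the new `register_model_with_metadata` pattern can pass the extra values. The import assumes the repo root is on `sys.path` so `forensics.registry` is importable; the toy callables and metadata values are placeholders.

```python
# Assumes the repo root is on sys.path so forensics/registry.py is importable.
from forensics.registry import ModelEntry, MODEL_REGISTRY

identity = lambda x: x                                 # placeholder preprocess
as_scores = lambda out, names: dict(zip(names, out))   # placeholder postprocess

# Old-style construction still works: metadata fields simply default to None.
legacy = ModelEntry(model=identity, preprocess=identity, postprocess=as_scores, class_names=["AI", "REAL"])
assert legacy.display_name is None and legacy.contributor is None and legacy.model_path is None

# New-style construction carries the metadata that app_mcp.py surfaces in the results table.
tagged = ModelEntry(identity, identity, as_scores, ["AI", "REAL"],
                    display_name="Demo", contributor="example", model_path="org/demo-model")
MODEL_REGISTRY["demo"] = tagged
print(MODEL_REGISTRY["demo"].contributor)  # -> "example"
```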