# LPX55
# fix: add model paths and class names for ONNX models in model_loader
# a4381af
from utils.onnx_helpers import postprocess_onnx_output
# Add missing import for infer_onnx_model
from utils.onnx_helpers import infer_onnx_model
# Add missing import for preprocess_onnx_input
from utils.onnx_helpers import preprocess_onnx_input
"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.
Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import numpy as np
from PIL import Image
# Hugging Face model path for each model key (copied from app_mcp.py).
MODEL_PATHS = {
    "model_1": "LPX55/detection-model-1-ONNX",
    "model_2": "LPX55/detection-model-2-ONNX",
    "model_3": "LPX55/detection-model-3-ONNX",
    "model_4": "cmckinle/sdxl-flux-detector_v1.1",
    "model_5": "LPX55/detection-model-5-ONNX",
    "model_6": "LPX55/detection-model-6-ONNX",
    "model_7": "LPX55/detection-model-7-ONNX",
    "model_8": "aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
}

# Class label list for each model key (copied from app_mcp.py).
CLASS_NAMES = {
    "model_1": ["artificial", "real"],
    "model_2": ["AI Image", "Real Image"],
    "model_3": ["artificial", "human"],
    "model_4": ["AI", "Real"],
    "model_5": ["Realism", "Deepfake"],
    "model_6": ["ai_gen", "human"],
    "model_7": ["Fake", "Real"],
    "model_8": ["Fake", "Real"],
}

# Shared cache handed to get_onnx_model_from_cache for ONNX sessions and
# preprocessors.
_onnx_model_cache = {}
def register_model_with_metadata(
    model_id,
    model,
    preprocess,
    postprocess,
    class_names,
    display_name,
    contributor,
    model_path,
    architecture=None,
    dataset=None,
):
    """Create a ``ModelEntry`` and store it in ``MODEL_REGISTRY`` under ``model_id``.

    ``model``, ``preprocess``, ``postprocess`` and ``class_names`` are passed to
    ``ModelEntry`` positionally; the remaining descriptive fields are passed
    by keyword.
    """
    descriptive_fields = {
        "display_name": display_name,
        "contributor": contributor,
        "model_path": model_path,
        "architecture": architecture,
        "dataset": dataset,
    }
    MODEL_REGISTRY[model_id] = ModelEntry(
        model, preprocess, postprocess, class_names, **descriptive_fields
    )
class ONNXModelWrapper:
    """Lazily-loaded handle on an ONNX model identified by its Hugging Face id.

    Construction only records the id. The session and both configs are
    obtained on the first ``load()``, which ``__call__``, ``preprocess`` and
    ``postprocess`` each invoke before doing their work.
    """

    def __init__(self, hf_model_id):
        self.hf_model_id = hf_model_id
        # All three remain None until load() fills them in.
        self._session = self._preprocessor_config = self._model_config = None

    def load(self):
        """Fetch session and configs via the shared cache unless a session is already held."""
        if self._session is not None:
            return
        loaded = get_onnx_model_from_cache(
            self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        self._session, self._preprocessor_config, self._model_config = loaded

    def __call__(self, image_np):
        """Run ``infer_onnx_model`` on ``image_np`` with this model's id and config."""
        self.load()
        result = infer_onnx_model(self.hf_model_id, image_np, self._model_config)
        return result

    def preprocess(self, image: Image.Image):
        """Return ``preprocess_onnx_input`` applied to ``image`` with the loaded preprocessor config."""
        self.load()
        prepared = preprocess_onnx_input(image, self._preprocessor_config)
        return prepared

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Return ``postprocess_onnx_output`` applied to ``onnx_output`` with the loaded model config.

        ``class_names_from_registry`` is accepted but not used here.
        """
        self.load()
        processed = postprocess_onnx_output(onnx_output, self._model_config)
        return processed
# Descriptive metadata for the ONNX models, keyed by model key:
# (contributor, architecture, dataset). "TBA" is the placeholder for a dataset
# that has not been recorded; it is still passed to the registry but is left
# out of the display name.
_ONNX_MODEL_METADATA = {
    "model_1": ("haywoodsloan", "SwinV2", "DeepFakeDetection"),
    "model_2": ("Heem2", "ViT", "DeepFakeDetection"),
    "model_3": ("Organika", "VIT", "SDXL"),
    "model_5": ("prithivMLmods", "VIT", "TBA"),
    "model_6": ("ideepankarsharma2003", "SWINv1", "SDXL, Midjourney"),
    "model_7": ("date3k2", "VIT", "TBA"),
}

# Placeholder metadata for an ONNX model that has no row in the table above.
_UNKNOWN_METADATA = ("Unknown", "Unknown", "TBA")


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Join the model number with any known architecture/dataset using "-", then append ``suffix``."""
    parts = [model_num]
    # The "Unknown" / "TBA" placeholders carry no information, so skip them.
    if architecture and architecture != "Unknown":
        parts.append(architecture)
    if dataset and dataset != "TBA":
        parts.append(dataset)
    return "-".join(parts) + suffix


# The main registration function
def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every model in ``MODEL_PATHS`` via ``register_model_with_metadata``.

    Each entry is dispatched by the first rule that matches:

    * path contains ``"ONNX"``: wrapped in an ``ONNXModelWrapper``; the display
      name gets an ``"_ONNX"`` suffix.
    * key is ``"model_8"``: served through the Gradio API helper callables.
    * key is ``"model_4"``: a transformers feature extractor and image
      classification model are loaded right away and the model is moved to
      ``device``.
    * anything else: skipped, with a warning logged.

    Args:
        MODEL_PATHS: Mapping of model key (e.g. ``"model_1"``) to model path.
        CLASS_NAMES: Mapping of model key to its class label list; a missing
            key registers an empty list.
        device: Passed to the feature extractor loader and used as the target
            of ``.to(device)`` for the ``"model_4"`` model and its inputs.
        infer_onnx_model: Not used by this function; kept in the signature so
            existing callers keep working.
        preprocess_onnx_input: Not used by this function (see above).
        postprocess_onnx_output: Not used by this function (see above).
    """
    import logging

    log = logging.getLogger(__name__)

    for model_key, hf_model_path in MODEL_PATHS.items():
        name_suffix = ""

        if "ONNX" in hf_model_path:
            contributor, architecture, dataset = _ONNX_MODEL_METADATA.get(
                model_key, _UNKNOWN_METADATA
            )
            wrapper = ONNXModelWrapper(hf_model_path)
            model = wrapper
            preprocess_func = wrapper.preprocess
            postprocess_func = wrapper.postprocess
            name_suffix = "_ONNX"
        elif model_key == "model_8":
            contributor, architecture, dataset = (
                "aiwithoutborders-xyz", "ViT", "DeepfakeDetection",
            )
            model = infer_gradio_api
            preprocess_func = preprocess_gradio_api
            postprocess_func = postprocess_gradio_api
        elif model_key == "model_4":
            contributor, architecture, dataset = "cmckinle", "VIT", "SDXL, FLUX"
            current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            hf_model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

            # Bind processor and model as defaults so the callable keeps the
            # objects created in this iteration.
            def custom_infer(image, processor_local=current_processor, model_local=hf_model):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model_local(**inputs)
                return outputs

            model = custom_infer
            preprocess_func = preprocess_resize_256
            postprocess_func = postprocess_logits
        else:
            # No rule matches: say so instead of dropping the model silently.
            log.warning(
                "Skipping %s (%s): no registration rule matches this model",
                model_key,
                hf_model_path,
            )
            continue

        model_num = model_key.replace("model_", "").upper()
        register_model_with_metadata(
            model_id=model_key,
            model=model,
            preprocess=preprocess_func,
            postprocess=postprocess_func,
            class_names=CLASS_NAMES.get(model_key, []),
            display_name=_build_display_name(model_num, architecture, dataset, name_suffix),
            contributor=contributor,
            model_path=hf_model_path,
            architecture=architecture,
            dataset=dataset,
        )