"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.

Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
import logging

import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

from utils.onnx_helpers import (
    infer_onnx_model,
    postprocess_onnx_output,
    preprocess_onnx_input,
)
from utils.onnx_model_loader import (
    get_onnx_model_from_cache,
    load_onnx_model_and_preprocessor,
)
from utils.registry import MODEL_REGISTRY, ModelEntry, register_model
from utils.utils import (
    infer_gradio_api,
    postprocess_gradio_api,
    postprocess_logits,
    preprocess_gradio_api,
    preprocess_resize_256,
)

logger = logging.getLogger(__name__)

# Model paths and class names (copied from app_mcp.py)
MODEL_PATHS = {
    "model_1": "LPX55/detection-model-1-ONNX",
    "model_2": "LPX55/detection-model-2-ONNX",
    "model_3": "LPX55/detection-model-3-ONNX",
    "model_4": "cmckinle/sdxl-flux-detector_v1.1",
    "model_5": "LPX55/detection-model-5-ONNX",
    "model_6": "LPX55/detection-model-6-ONNX",
    "model_7": "LPX55/detection-model-7-ONNX",
    "model_8": "aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
}

CLASS_NAMES = {
    "model_1": ['artificial', 'real'],
    "model_2": ['AI Image', 'Real Image'],
    "model_3": ['artificial', 'human'],
    "model_4": ['AI', 'Real'],
    "model_5": ['Realism', 'Deepfake'],
    "model_6": ['ai_gen', 'human'],
    "model_7": ['Fake', 'Real'],
    "model_8": ['Fake', 'Real'],
}

# Cache for ONNX sessions and preprocessors
_onnx_model_cache = {}

# Placeholder metadata values; these are never rendered into a display name.
_UNKNOWN = "Unknown"
_DATASET_TBA = "TBA"

# Per-model metadata for the ONNX-hosted models. Keys omitted from an entry
# fall back to the placeholder values above.
_ONNX_MODEL_METADATA = {
    "model_1": {
        "contributor": "haywoodsloan",
        "architecture": "SwinV2",
        "dataset": "DeepFakeDetection",
    },
    "model_2": {
        "contributor": "Heem2",
        "architecture": "ViT",
        "dataset": "DeepFakeDetection",
    },
    "model_3": {
        "contributor": "Organika",
        "architecture": "VIT",
        "dataset": "SDXL",
    },
    "model_5": {
        "contributor": "prithivMLmods",
        "architecture": "VIT",
    },
    "model_6": {
        "contributor": "ideepankarsharma2003",
        "architecture": "SWINv1",
        "dataset": "SDXL, Midjourney",
    },
    "model_7": {
        "contributor": "date3k2",
        "architecture": "VIT",
    },
}


def _build_display_name(model_key, architecture, dataset, suffix=""):
    """Return the display name ``<NUM>[-<architecture>][-<dataset>]<suffix>``.

    ``<NUM>`` is ``model_key`` with ``"model_"`` removed and upper-cased.
    Architecture and dataset are skipped when empty or still set to their
    placeholder values (``"Unknown"`` / ``"TBA"``).
    """
    parts = [model_key.replace("model_", "").upper()]
    if architecture and architecture != _UNKNOWN:
        parts.append(architecture)
    if dataset and dataset != _DATASET_TBA:
        parts.append(dataset)
    return "-".join(parts) + suffix


def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names,
                                 display_name, contributor, model_path,
                                 architecture=None, dataset=None):
    """Build a ``ModelEntry`` and store it in ``MODEL_REGISTRY`` under ``model_id``.

    An existing entry with the same ``model_id`` is overwritten.

    Args:
        model_id: Registry key for the model.
        model: Callable that runs inference.
        preprocess: Callable that prepares an input for ``model``.
        postprocess: Callable that converts the model output into results.
        class_names: Class labels associated with the model.
        display_name: Human-readable model name.
        contributor: Name of the model's contributor.
        model_path: Model identifier/path the model is loaded from.
        architecture: Optional architecture label.
        dataset: Optional training-dataset label.
    """
    entry = ModelEntry(
        model,
        preprocess,
        postprocess,
        class_names,
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )
    MODEL_REGISTRY[model_id] = entry


class ONNXModelWrapper:
    """Lazily-loaded wrapper bundling inference, pre- and post-processing
    for one ONNX model.

    Nothing is loaded at construction time; the session and configs are
    fetched through ``get_onnx_model_from_cache`` on first use.
    """

    def __init__(self, hf_model_id):
        self.hf_model_id = hf_model_id
        self._session = None
        self._preprocessor_config = None
        self._model_config = None

    def load(self):
        """Fetch the session and configs on first call; later calls are no-ops."""
        if self._session is None:
            self._session, self._preprocessor_config, self._model_config = get_onnx_model_from_cache(
                self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
            )

    def __call__(self, image_np):
        """Run inference on an already-preprocessed input."""
        self.load()
        # The session itself is not passed along; presumably infer_onnx_model
        # resolves it from hf_model_id -- TODO confirm against utils.onnx_helpers.
        return infer_onnx_model(self.hf_model_id, image_np, self._model_config)

    def preprocess(self, image: Image.Image):
        """Preprocess ``image`` using this model's preprocessor config."""
        self.load()
        return preprocess_onnx_input(image, self._preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Postprocess ``onnx_output`` using this model's config.

        ``class_names_from_registry`` is accepted for signature compatibility
        with the registry's postprocess call but is not used here.
        """
        self.load()
        return postprocess_onnx_output(onnx_output, self._model_config)


# The main registration function
def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model,
                        preprocess_onnx_input, postprocess_onnx_output):
    """Register every model listed in ``MODEL_PATHS`` into ``MODEL_REGISTRY``.

    Dispatch per entry, in this order:
      * path contains ``"ONNX"``: registered through a lazy ``ONNXModelWrapper``;
      * key ``"model_8"``: registered with the Gradio API helpers;
      * key ``"model_4"``: the HuggingFace model is loaded immediately and
        moved to ``device``;
      * anything else: skipped with a logged warning.

    Args:
        MODEL_PATHS: Mapping of model key to model path/identifier.
        CLASS_NAMES: Mapping of model key to class labels; a missing key
            yields an empty list.
        device: Torch device used for the HuggingFace model and its inputs.
        infer_onnx_model: Unused; kept for backward compatibility.
            ``ONNXModelWrapper`` uses the module-level import instead.
        preprocess_onnx_input: Unused; kept for backward compatibility.
        postprocess_onnx_output: Unused; kept for backward compatibility.
    """
    for model_key, hf_model_path in MODEL_PATHS.items():
        current_class_names = CLASS_NAMES.get(model_key, [])

        if "ONNX" in hf_model_path:
            metadata = _ONNX_MODEL_METADATA.get(model_key, {})
            contributor = metadata.get("contributor", _UNKNOWN)
            architecture = metadata.get("architecture", _UNKNOWN)
            dataset = metadata.get("dataset", _DATASET_TBA)

            onnx_wrapper = ONNXModelWrapper(hf_model_path)
            register_model_with_metadata(
                model_id=model_key,
                model=onnx_wrapper,
                preprocess=onnx_wrapper.preprocess,
                postprocess=onnx_wrapper.postprocess,
                class_names=current_class_names,
                display_name=_build_display_name(
                    model_key, architecture, dataset, suffix="_ONNX"
                ),
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset,
            )
        elif model_key == "model_8":
            architecture = "ViT"
            dataset = "DeepfakeDetection"
            register_model_with_metadata(
                model_id=model_key,
                model=infer_gradio_api,
                preprocess=preprocess_gradio_api,
                postprocess=postprocess_gradio_api,
                class_names=current_class_names,
                display_name=_build_display_name(model_key, architecture, dataset),
                contributor="aiwithoutborders-xyz",
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset,
            )
        elif model_key == "model_4":
            architecture = "VIT"
            dataset = "SDXL, FLUX"

            hf_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            hf_model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

            # Bind processor and model as default arguments so the closure
            # keeps this iteration's objects rather than later rebindings.
            def custom_infer(image, processor_local=hf_processor, model_local=hf_model):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model_local(**inputs)
                return outputs

            register_model_with_metadata(
                model_id=model_key,
                model=custom_infer,
                preprocess=preprocess_resize_256,
                postprocess=postprocess_logits,
                class_names=current_class_names,
                display_name=_build_display_name(model_key, architecture, dataset),
                contributor="cmckinle",
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset,
            )
        else:
            # Fallback for any unhandled models: skip registration, but make
            # the omission visible instead of dropping the model silently.
            logger.warning(
                "No registration handler for %s (%s); model skipped.",
                model_key,
                hf_model_path,
            )