|
from utils.onnx_helpers import postprocess_onnx_output |
|
|
|
from utils.onnx_helpers import infer_onnx_model |
|
|
|
from utils.onnx_helpers import preprocess_onnx_input |
|
""" |
|
Model loading and registration logic for OpenSight Deepfake Detection Playground. |
|
Handles ONNX, HuggingFace, and Gradio API model registration and metadata. |
|
""" |
|
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry |
|
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache |
|
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
|
|
# Hugging Face repository backing each detector, keyed by model id.
# Insertion order matters: register_all_models walks this mapping in order.
# Paths containing "ONNX" go through ONNXModelWrapper; model_4 and model_8
# have dedicated loading paths in register_all_models.
MODEL_PATHS = dict(
    model_1="LPX55/detection-model-1-ONNX",
    model_2="LPX55/detection-model-2-ONNX",
    model_3="LPX55/detection-model-3-ONNX",
    model_4="cmckinle/sdxl-flux-detector_v1.1",
    model_5="LPX55/detection-model-5-ONNX",
    model_6="LPX55/detection-model-6-ONNX",
    model_7="LPX55/detection-model-7-ONNX",
    model_8="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
)

# Label list registered alongside each model id.
# NOTE(review): presumably ordered to match each model's output indices —
# confirm against the upstream model configs (e.g. model_5 lists the
# "Realism" label before "Deepfake", unlike most others).
CLASS_NAMES = dict(
    model_1=["artificial", "real"],
    model_2=["AI Image", "Real Image"],
    model_3=["artificial", "human"],
    model_4=["AI", "Real"],
    model_5=["Realism", "Deepfake"],
    model_6=["ai_gen", "human"],
    model_7=["Fake", "Real"],
    model_8=["Fake", "Real"],
)

# Shared cache handed to get_onnx_model_from_cache by ONNXModelWrapper.load.
_onnx_model_cache = dict()
|
|
|
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Build a ModelEntry from the given pieces and store it in MODEL_REGISTRY.

    Args:
        model_id: Registry key; whatever is already stored under it is replaced.
        model: Inference callable/object for the model.
        preprocess: Callable that prepares raw input for ``model``.
        postprocess: Callable that turns ``model`` output into a result.
        class_names: Label list associated with the model.
        display_name: Human-readable name.
        contributor: Credited author of the model.
        model_path: Source repository/path of the model.
        architecture: Optional architecture label.
        dataset: Optional training-dataset label.
    """
    MODEL_REGISTRY[model_id] = ModelEntry(
        model,
        preprocess,
        postprocess,
        class_names,
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )
|
|
|
class ONNXModelWrapper:
    """Callable adapter around an ONNX model identified by a Hub model id.

    Loading is deferred: nothing is fetched until one of ``load``,
    ``__call__``, ``preprocess`` or ``postprocess`` is first used. Loading goes
    through get_onnx_model_from_cache with the module-level _onnx_model_cache.
    """

    def __init__(self, hf_model_id):
        self.hf_model_id = hf_model_id
        # Populated lazily by load().
        self._session = self._preprocessor_config = self._model_config = None

    def load(self):
        """Fetch the session and configs once; later calls are no-ops."""
        if self._session is not None:
            return
        loaded = get_onnx_model_from_cache(
            self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        self._session, self._preprocessor_config, self._model_config = loaded

    def __call__(self, image_np):
        """Run inference on ``image_np`` and return infer_onnx_model's result."""
        self.load()
        # NOTE(review): the loaded self._session is not passed here — only the
        # model id and model config are. Confirm infer_onnx_model resolves the
        # same cached session rather than loading its own copy.
        model_config = self._model_config
        return infer_onnx_model(self.hf_model_id, image_np, model_config)

    def preprocess(self, image: Image.Image):
        """Convert a PIL image into model input using the loaded preprocessor config."""
        self.load()
        config = self._preprocessor_config
        return preprocess_onnx_input(image, config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Convert raw ONNX output into a result using the loaded model config.

        ``class_names_from_registry`` is accepted for signature compatibility
        but is not used; only the model config reaches postprocess_onnx_output
        (presumably it carries the label mapping — confirm).
        """
        self.load()
        config = self._model_config
        return postprocess_onnx_output(onnx_output, config)
|
|
|
|
|
|
|
# Display metadata per model key: (contributor, architecture, dataset).
# Keys absent from this table fall back to _DEFAULT_MODEL_METADATA.
_MODEL_METADATA = {
    "model_1": ("haywoodsloan", "SwinV2", "DeepFakeDetection"),
    "model_2": ("Heem2", "ViT", "DeepFakeDetection"),
    "model_3": ("Organika", "VIT", "SDXL"),
    "model_4": ("cmckinle", "VIT", "SDXL, FLUX"),
    "model_5": ("prithivMLmods", "VIT", "TBA"),
    "model_6": ("ideepankarsharma2003", "SWINv1", "SDXL, Midjourney"),
    "model_7": ("date3k2", "VIT", "TBA"),
    "model_8": ("aiwithoutborders-xyz", "ViT", "DeepfakeDetection"),
}
_DEFAULT_MODEL_METADATA = ("Unknown", "Unknown", "TBA")


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Join the model number with the known architecture/dataset labels, plus suffix."""
    parts = [model_num]
    # "Unknown" / "TBA" are placeholders and are left out of the name.
    if architecture and architecture != "Unknown":
        parts.append(architecture)
    if dataset and dataset != "TBA":
        parts.append(dataset)
    return "-".join(parts) + suffix


def _build_hf_classifier(hf_model_path, device):
    """Load a transformers image classifier and return a gradient-free inference callable."""
    processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
    model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

    def custom_infer(image, processor_local=processor, model_local=model):
        inputs = processor_local(image, return_tensors="pt").to(device)
        with torch.no_grad():
            return model_local(**inputs)

    return custom_infer


def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every model listed in MODEL_PATHS via register_model_with_metadata.

    Dispatch per entry:
      * path contains "ONNX": wrapped in ONNXModelWrapper; display name gets "_ONNX".
      * key "model_8": served through the Gradio API helpers.
      * key "model_4": loaded eagerly with transformers onto ``device``.
      * anything else: skipped with a logged warning.

    Args:
        MODEL_PATHS: Mapping of model key -> model path, iterated in order.
        CLASS_NAMES: Mapping of model key -> label list (missing keys get []).
        device: Torch device used for the transformers-backed model.
        infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output:
            Accepted for backward compatibility but unused; ONNXModelWrapper
            calls the module-level helpers of the same names.
    """
    import logging

    logger = logging.getLogger(__name__)

    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        contributor, architecture, dataset = _MODEL_METADATA.get(
            model_key, _DEFAULT_MODEL_METADATA
        )
        class_names = CLASS_NAMES.get(model_key, [])

        if "ONNX" in hf_model_path:
            wrapper = ONNXModelWrapper(hf_model_path)
            model, preprocess, postprocess = wrapper, wrapper.preprocess, wrapper.postprocess
            suffix = "_ONNX"
        elif model_key == "model_8":
            model, preprocess, postprocess = infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
            suffix = ""
        elif model_key == "model_4":
            model = _build_hf_classifier(hf_model_path, device)
            preprocess, postprocess = preprocess_resize_256, postprocess_logits
            suffix = ""
        else:
            # Previously skipped silently; surface it so a misconfigured entry is noticed.
            logger.warning(
                "No loader known for %s (%s); skipping registration.",
                model_key,
                hf_model_path,
            )
            continue

        register_model_with_metadata(
            model_id=model_key,
            model=model,
            preprocess=preprocess,
            postprocess=postprocess,
            class_names=class_names,
            display_name=_build_display_name(model_num, architecture, dataset, suffix),
            contributor=contributor,
            model_path=hf_model_path,
            architecture=architecture,
            dataset=dataset,
        )
|
|