# Shared ONNX inference utilities for use by app.py and model_loader.py
import logging

import numpy as np
from PIL import Image
from torchvision import transforms

logger = logging.getLogger(__name__)

# Module-level cache so loaded ONNX sessions persist across calls instead of
# being reloaded on every inference request.
_onnx_model_cache = {}


def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
    """Run the cached ONNX session for hf_model_id on a preprocessed image batch."""
    # Deferred imports: keeps this module importable without the loader and
    # avoids potential circular imports.
    from .onnx_model_loader import get_onnx_model_from_cache, load_onnx_model_and_preprocessor
    from .utils import softmax
    try:
        ort_session, _, _ = get_onnx_model_from_cache(
            hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        for input_meta in ort_session.get_inputs():
            logger.info(
                f"Debug: ONNX model expected input name: {input_meta.name}, "
                f"shape: {input_meta.shape}, type: {input_meta.type}"
            )
        logger.info(f"Debug: preprocessed_image_np shape: {preprocessed_image_np.shape}")
        ort_inputs = {ort_session.get_inputs()[0].name: preprocessed_image_np}
        ort_outputs = ort_session.run(None, ort_inputs)
        logits = ort_outputs[0]
        logger.info(f"Debug: logits type: {type(logits)}, shape: {logits.shape}")
        probabilities = softmax(logits[0])
        return {"logits": logits, "probabilities": probabilities}
    except Exception as e:
        logger.error(f"Error during ONNX inference for {hf_model_id}: {e}")
        return {"logits": np.array([]), "probabilities": np.array([])}


def preprocess_onnx_input(image: Image.Image, preprocessor_config: dict):
    """Convert a PIL image into a normalized NCHW float32 numpy array for ONNX Runtime."""
    if image.mode != 'RGB':
        image = image.convert('RGB')
    initial_resize_size = preprocessor_config.get('size', {'height': 224, 'width': 224})
    crop_size = preprocessor_config.get('crop_size', initial_resize_size['height'])
    # Some preprocessor configs express crop_size as a dict rather than an int.
    if isinstance(crop_size, dict):
        crop_size = (crop_size['height'], crop_size['width'])
    mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
    std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize((initial_resize_size['height'], initial_resize_size['width'])),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    input_tensor = transform(image)
    # Add the batch dimension expected by the ONNX session input.
    return input_tensor.unsqueeze(0).cpu().numpy()

def postprocess_onnx_output(onnx_output, model_config):
    """Map the probabilities returned by infer_onnx_model onto class labels."""
    class_names_map = model_config.get('id2label')
    if class_names_map:
        # id2label keys are often strings when the config is loaded from JSON,
        # so sort them numerically to keep label order aligned with the logits.
        class_names = [class_names_map[k] for k in sorted(class_names_map.keys(), key=lambda k: int(k))]
    else:
        # Fall back to binary Fake/Real labels when the config provides no mapping.
        class_names = ['Fake', 'Real']
    probabilities = onnx_output.get("probabilities")
    if probabilities is None:
        logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
        return {name: 0.0 for name in class_names}
    if model_config.get('num_classes') == 1 and len(probabilities) == 2:
        # Single-logit models expanded to two probabilities: index 0 = Fake, index 1 = Real.
        return {class_names[0]: float(probabilities[0]), class_names[1]: float(probabilities[1])}
    if len(probabilities) == len(class_names):
        return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
    logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
    return {name: 0.0 for name in class_names}
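

# Example usage (a minimal sketch, not part of the application flow): wires the
# three helpers together for a single image. The model id, image path, and the
# config dicts below are placeholders; real values would come from the model's
# config.json and preprocessor_config.json. Because infer_onnx_model uses
# relative imports, this would need to be run as a module (python -m ...).
if __name__ == "__main__":
    example_model_id = "org/example-onnx-image-classifier"  # hypothetical model id
    example_model_config = {"id2label": {0: "Fake", 1: "Real"}}
    example_preprocessor_config = {
        "size": {"height": 224, "width": 224},
        "image_mean": [0.485, 0.456, 0.406],
        "image_std": [0.229, 0.224, 0.225],
    }

    img = Image.open("example.jpg")  # placeholder image path
    batch = preprocess_onnx_input(img, example_preprocessor_config)
    inference_result = infer_onnx_model(example_model_id, batch, example_model_config)
    print(postprocess_onnx_output(inference_result, example_model_config))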