# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

"""onnx shape inference. Shape inference is not guaranteed to be

complete.



"""

from __future__ import annotations

import os
from typing import Sequence

import onnx
import onnx.onnx_cpp2py_export.shape_inference as C  # noqa: N812
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto


def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Apply shape inference to the provided ModelProto.



    Inferred shapes are added to the value_info field of the graph.



    If the inferred values conflict with values already provided in the

    graph, that means that the provided values are invalid (or there is a

    bug in shape inference), and the result is unspecified.



    Arguments:

        model: ModelProto.

        check_type: Checks the type-equality for input and output.

        strict_mode: Stricter shape inference, it will throw errors if any;

            Otherwise, simply stop if any error.

        data_prop: Enables data propagation for limited operators to perform shape computation.



    Returns:

        (ModelProto) model with inferred shape information
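
    Example:
        A minimal sketch (this toy two-node graph is illustrative, not part
        of the original documentation)::

            import onnx
            from onnx import TensorProto, helper

            add = helper.make_node("Add", ["X", "Y"], ["T"])
            relu = helper.make_node("Relu", ["T"], ["Z"])
            graph = helper.make_graph(
                [add, relu],
                "g",
                [
                    helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3]),
                    helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]),
                ],
                [helper.make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            )
            model = helper.make_model(graph)
            inferred = onnx.shape_inference.infer_shapes(model)
            # The intermediate value T now appears in graph.value_info with
            # the inferred type float32[2, 3].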

    """
    if isinstance(model, (ModelProto, bytes)):
        model_str = model if isinstance(model, bytes) else model.SerializeToString()
        inferred_model_str = C.infer_shapes(
            model_str, check_type, strict_mode, data_prop
        )
        return onnx.load_from_string(inferred_model_str)
    if isinstance(model, str):
        raise TypeError(
            "infer_shapes only accepts ModelProto or bytes; "
            "you can use infer_shapes_path for a model path (string)."
        )

    raise TypeError(
        f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
    )


def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Take model path for shape_inference.



    This function is the same as :func:`infer_shape` but supports >2GB models.

    The function outputs the inferred model to the `output_path`. The original model path

    is used if not specified.
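
    Example:
        A minimal sketch (the file names are illustrative)::

            onnx.shape_inference.infer_shapes_path(
                "model.onnx", "model_inferred.onnx"
            )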

    """
    if isinstance(model_path, ModelProto):
        raise TypeError(
            "infer_shapes_path only accepts a model path (string or PathLike); "
            "you can use infer_shapes for a ModelProto."
        )
    try:
        model_path = os.fspath(model_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts model path as a string or PathLike, "
            f"incorrect model path type: {type(model_path)}"
        ) from exp
    try:
        output_path = os.fspath(output_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts output path as a string or PathLike, "
            f"incorrect output path type: {type(output_path)}"
        ) from exp

    if output_path == "":
        output_path = model_path
    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)


def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
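    """Infer the output types of a single node from the given input types.

    Returns an empty mapping if the schema has no type-and-shape-inference
    function. `input_types` is also passed through as outer-scope value types,
    so it may contain entries beyond the node's own inputs.

    Example:
        A minimal sketch (the Relu node below is illustrative)::

            import onnx
            from onnx import TensorProto, helper

            node = helper.make_node("Relu", ["x"], ["y"])
            schema = onnx.defs.get_schema("Relu")
            x_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [4])
            out = onnx.shape_inference.infer_node_outputs(
                schema, node, {"x": x_type}
            )
            # out["y"] is a TypeProto describing a float32 tensor of shape [4]
    """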
    if not schema.has_type_and_shape_inference_function:  # type: ignore
        return {}
    if input_data is None:
        input_data = {}
    if input_sparse_data is None:
        input_sparse_data = {}
    if opset_imports is None:
        passed_opset_imports = {}
    else:
        passed_opset_imports = {opset.domain: opset.version for opset in opset_imports}

    # This raises KeyError if one of the node's inputs is missing from input_types.
    passed_input_types = {
        key: input_types[key].SerializeToString() for key in node.input
    }
    # input_types will also be used as outer_scope_value_types so do not filter by node's input here
    for key in input_types:
        if key not in passed_input_types:
            passed_input_types[key] = input_types[key].SerializeToString()
    passed_input_data = {
        key: input_data[key].SerializeToString()
        for key in node.input
        if key in input_data
    }
    passed_sparse_input_data = {
        key: input_sparse_data[key].SerializeToString()
        for key in node.input
        if key in input_sparse_data
    }

    outputs = schema._infer_node_outputs(
        node.SerializeToString(),
        passed_input_types,
        passed_input_data,
        passed_sparse_input_data,
        passed_opset_imports,
        ir_version,
    )  # type: ignore[call-arg]
    return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}


def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """Apply type-and-shape-inference to given function body, with given input types

    and given input attribute values.
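
    Example:
        A minimal sketch (the one-node function below is illustrative)::

            import onnx
            from onnx import TensorProto, helper

            fn = helper.make_function(
                domain="custom",
                fname="MyRelu",
                inputs=["x"],
                outputs=["y"],
                nodes=[helper.make_node("Relu", ["x"], ["y"])],
                opset_imports=[helper.make_opsetid("", 18)],
            )
            x_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [3])
            out_types = onnx.shape_inference.infer_function_output_types(
                fn, [x_type], []
            )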

    """
    result = C.infer_function_output_types(
        function.SerializeToString(),
        [x.SerializeToString() for x in input_types],
        [x.SerializeToString() for x in attributes],
    )

    def to_type_proto(x) -> TypeProto:
        type_proto = onnx.TypeProto()
        type_proto.ParseFromString(x)
        return type_proto

    return [to_type_proto(x) for x in result]


# Re-export the C++ backend's inference error type so callers can catch it.
InferenceError = C.InferenceError