Spaces:
Sleeping
Sleeping
File size: 5,951 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
"""onnx shape inference. Shape inference is not guaranteed to be
complete.
"""
from __future__ import annotations
import os
from typing import Sequence
import onnx
import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto
def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Arguments:
        model: ModelProto, or the serialized bytes of one.
        check_type: Checks the type-equality for input and output.
        strict_mode: Stricter shape inference, it will throw errors if any;
            Otherwise, simply stop if any error.
        data_prop: Enables data propagation for limited operators to perform
            shape computation.

    Returns:
        (ModelProto) model with inferred shape information

    Raises:
        TypeError: If ``model`` is neither a ModelProto nor bytes. When a
            path (string or PathLike) is given, the message points the
            caller at :func:`infer_shapes_path`.
    """
    if isinstance(model, (ModelProto, bytes)):
        # The native entry point works on serialized bytes, so a ModelProto
        # is serialized first while bytes are forwarded untouched.
        model_str = model if isinstance(model, bytes) else model.SerializeToString()
        inferred_model_str = C.infer_shapes(
            model_str, check_type, strict_mode, data_prop
        )
        return onnx.load_from_string(inferred_model_str)
    if isinstance(model, (str, os.PathLike)):
        # A path was most likely meant for infer_shapes_path; say so explicitly.
        raise TypeError(
            "infer_shapes only accepts ModelProto or bytes, "
            "you can use infer_shapes_path for the model path (String)."
        )
    raise TypeError(
        f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
    )
def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Take model path for shape_inference.

    This function is the same as :func:`infer_shapes` but supports >2GB models.
    The function outputs the inferred model to the `output_path`. The original model path
    is used if not specified.

    Arguments:
        model_path: Path of the model to run shape inference on.
        output_path: Path the inferred model is written to. An empty string
            (the default) means the model at ``model_path`` is overwritten.
        check_type: Forwarded unchanged to the native shape inference call.
        strict_mode: Forwarded unchanged to the native shape inference call.
        data_prop: Forwarded unchanged to the native shape inference call.

    Raises:
        TypeError: If ``model_path`` is a ModelProto, or if either path is not
            a string or PathLike object.
    """
    if isinstance(model_path, ModelProto):
        # A ModelProto was most likely meant for infer_shapes; say so explicitly.
        raise TypeError(
            "infer_shapes_path only accepts model Path (String), "
            "you can use infer_shapes for the ModelProto."
        )
    try:
        model_path = os.fspath(model_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts model path as a string or PathLike, "
            f"incorrect model path type: {type(model_path)}"
        ) from exp
    try:
        output_path = os.fspath(output_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts output path as a string or PathLike, "
            f"incorrect output path type: {type(output_path)}"
        ) from exp

    # No explicit destination: write the result back over the input model.
    if output_path == "":
        output_path = model_path
    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """Infer the output types of a single node using its operator schema.

    Arguments:
        schema: Operator schema whose inference function is invoked.
        node: Node whose outputs are inferred.
        input_types: Types keyed by value name. Every entry is forwarded,
            not only the ones named in ``node.input``.
        input_data: Optional tensors keyed by value name; only entries named
            in ``node.input`` are forwarded.
        input_sparse_data: Optional sparse tensors keyed by value name; only
            entries named in ``node.input`` are forwarded.
        opset_imports: Optional opset entries, forwarded as a mapping from
            domain to version.
        ir_version: IR version forwarded to the schema's inference call.

    Returns:
        Mapping from name to the inferred TypeProto, or an empty dict when
        the schema has no type-and-shape inference function.

    Raises:
        KeyError: If a name in ``node.input`` is missing from ``input_types``.
    """
    if not schema.has_type_and_shape_inference_function:  # type: ignore
        return {}

    dense_data = {} if input_data is None else input_data
    sparse_data = {} if input_sparse_data is None else input_sparse_data
    domain_versions = (
        {}
        if opset_imports is None
        else {entry.domain: entry.version for entry in opset_imports}
    )

    # Plain indexing on purpose: a node input without a known type surfaces
    # as a KeyError to the caller.
    serialized_types = {
        name: input_types[name].SerializeToString() for name in node.input
    }
    # The remaining entries are kept too, because input_types doubles as the
    # outer-scope value types and must not be filtered down to node inputs.
    for name, type_proto in input_types.items():
        if name not in serialized_types:
            serialized_types[name] = type_proto.SerializeToString()

    serialized_data = {
        name: dense_data[name].SerializeToString()
        for name in node.input
        if name in dense_data
    }
    serialized_sparse_data = {
        name: sparse_data[name].SerializeToString()
        for name in node.input
        if name in sparse_data
    }

    raw_outputs = schema._infer_node_outputs(
        node.SerializeToString(),
        serialized_types,
        serialized_data,
        serialized_sparse_data,
        domain_versions,
        ir_version,
    )  # type: ignore[call-arg]

    return {
        name: onnx.TypeProto.FromString(blob) for name, blob in raw_outputs.items()
    }
def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """Run type-and-shape inference over a function body.

    The function, its input types and its attribute values are serialized and
    handed to the native inference routine; each serialized result is parsed
    back into a TypeProto.

    Arguments:
        function: Function whose body is analysed.
        input_types: Types of the function inputs.
        attributes: Attribute values supplied to the function.

    Returns:
        One TypeProto per entry returned by the native routine, in the same
        order.
    """
    serialized_function = function.SerializeToString()
    serialized_inputs = [type_proto.SerializeToString() for type_proto in input_types]
    serialized_attributes = [attr.SerializeToString() for attr in attributes]

    raw_results = C.infer_function_output_types(
        serialized_function,
        serialized_inputs,
        serialized_attributes,
    )

    output_types: list[TypeProto] = []
    for raw in raw_results:
        parsed = onnx.TypeProto()
        parsed.ParseFromString(raw)
        output_types.append(parsed)
    return output_types
# Re-export `InferenceError` from the `onnx.onnx_cpp2py_export.shape_inference`
# bindings (imported above as `C`) so it is reachable from this module.
InferenceError = C.InferenceError
|