Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
"""onnx shape inference. Shape inference is not guaranteed to be
complete.
"""
from __future__ import annotations | |
import os | |
from typing import Sequence | |
import onnx | |
import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812 | |
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto | |
def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Apply shape inference to the provided ModelProto.

    Inferred shapes are added to the value_info field of the graph.

    If the inferred values conflict with values already provided in the
    graph, that means that the provided values are invalid (or there is a
    bug in shape inference), and the result is unspecified.

    Arguments:
        model: ModelProto, or an already-serialized ModelProto as bytes.
        check_type: Checks the type-equality for input and output.
        strict_mode: Stricter shape inference, it will throw errors if any;
            Otherwise, simply stop if any error.
        data_prop: Enables data propagation for limited operators to perform shape computation.

    Returns:
        (ModelProto) model with inferred shape information

    Raises:
        TypeError: If `model` is neither a ModelProto nor bytes. A str gets a
            dedicated message pointing at infer_shapes_path.
    """
    # A str is handled first so that a path passed by mistake gets the hint
    # about infer_shapes_path; a str can never be a ModelProto or bytes, so the
    # order of the checks does not change which inputs are accepted.
    if isinstance(model, str):
        raise TypeError(
            "infer_shapes only accepts ModelProto or bytes, "
            "you can use infer_shapes_path for the model path (String)."
        )
    if not isinstance(model, (ModelProto, bytes)):
        raise TypeError(
            f"infer_shapes only accepts ModelProto or bytes, incorrect type: {type(model)}"
        )
    # bytes are passed through untouched; a ModelProto is serialized first.
    model_str = model if isinstance(model, bytes) else model.SerializeToString()
    inferred_model_str = C.infer_shapes(model_str, check_type, strict_mode, data_prop)
    return onnx.load_from_string(inferred_model_str)
def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Take model path for shape_inference.

    This function is the same as :func:`infer_shapes` but supports >2GB models.
    The function outputs the inferred model to the `output_path`. The original model path
    is used if not specified.

    Arguments:
        model_path: Path of the model file, as a string or PathLike.
        output_path: Path the inferred model is written to, as a string or
            PathLike. An empty string (the default) means `model_path`.
        check_type: Forwarded unchanged to the underlying implementation.
        strict_mode: Forwarded unchanged to the underlying implementation.
        data_prop: Forwarded unchanged to the underlying implementation.

    Raises:
        TypeError: If `model_path` is a ModelProto, or if either path cannot
            be converted with os.fspath.
    """
    # A ModelProto is the most likely misuse, so it gets a dedicated hint.
    if isinstance(model_path, ModelProto):
        raise TypeError(
            "infer_shapes_path only accepts model Path (String), "
            "you can use infer_shapes for the ModelProto."
        )
    try:
        model_path = os.fspath(model_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts model path as a string or PathLike, "
            f"incorrect model path type: {type(model_path)}"
        ) from exp
    try:
        output_path = os.fspath(output_path)
    except TypeError as exp:
        raise TypeError(
            "infer_shapes_path only accepts output path as a string or PathLike, "
            f"incorrect output path type: {type(output_path)}"
        ) from exp

    # An empty output path means "overwrite the input model in place".
    if output_path == "":
        output_path = model_path
    C.infer_shapes_path(model_path, output_path, check_type, strict_mode, data_prop)
def infer_node_outputs(
    schema: onnx.defs.OpSchema,
    node: onnx.NodeProto,
    input_types: dict[str, onnx.TypeProto],
    input_data: dict[str, onnx.TensorProto] | None = None,
    input_sparse_data: dict[str, onnx.SparseTensorProto] | None = None,
    opset_imports: list[onnx.OperatorSetIdProto] | None = None,
    ir_version: int = onnx.IR_VERSION,
) -> dict[str, onnx.TypeProto]:
    """Run the schema's type-and-shape inference for a single node.

    All protobuf arguments are serialized before being handed to
    ``schema._infer_node_outputs``; the serialized results are parsed back
    into TypeProto objects.

    Arguments:
        schema: Operator schema used to perform the inference.
        node: Node whose outputs are inferred.
        input_types: Known types by name. Every name in ``node.input`` must be
            present; the remaining entries are passed along as well because
            they also serve as outer-scope value types.
        input_data: Optional known tensor values by name; only entries named
            in ``node.input`` are forwarded.
        input_sparse_data: Optional known sparse tensor values by name; only
            entries named in ``node.input`` are forwarded.
        opset_imports: Optional opset ids, forwarded as a domain -> version
            mapping (empty when omitted).
        ir_version: IR version forwarded to the schema.

    Returns:
        Mapping from the keys returned by the schema (presumably output
        names) to their inferred TypeProto. Empty if the schema has no
        type-and-shape inference function.

    Raises:
        KeyError: If a name in ``node.input`` is missing from ``input_types``.
    """
    if not schema.has_type_and_shape_inference_function:  # type: ignore
        return {}

    known_data = {} if input_data is None else input_data
    known_sparse_data = {} if input_sparse_data is None else input_sparse_data
    domain_versions = (
        {}
        if opset_imports is None
        else {opset.domain: opset.version for opset in opset_imports}
    )

    # The node's own inputs come first; a missing entry raises KeyError here.
    serialized_types = {}
    for name in node.input:
        serialized_types[name] = input_types[name].SerializeToString()
    # input_types will also be used as outer_scope_value_types, so the
    # remaining entries are added too rather than filtered by node.input.
    for name, type_proto in input_types.items():
        if name not in serialized_types:
            serialized_types[name] = type_proto.SerializeToString()

    serialized_data = {}
    serialized_sparse_data = {}
    for name in node.input:
        if name in known_data:
            serialized_data[name] = known_data[name].SerializeToString()
    for name in node.input:
        if name in known_sparse_data:
            serialized_sparse_data[name] = known_sparse_data[
                name
            ].SerializeToString()

    raw_outputs = schema._infer_node_outputs(
        node.SerializeToString(),
        serialized_types,
        serialized_data,
        serialized_sparse_data,
        domain_versions,
        ir_version,
    )  # type: ignore[call-arg]

    inferred = {}
    for name, blob in raw_outputs.items():
        inferred[name] = onnx.TypeProto.FromString(blob)
    return inferred
def infer_function_output_types(
    function: FunctionProto,
    input_types: Sequence[TypeProto],
    attributes: Sequence[AttributeProto],
) -> list[TypeProto]:
    """Infer the output types of a function body.

    Type-and-shape inference is applied to ``function`` using the given input
    types and input attribute values. Every argument is serialized before the
    call, and each serialized result is parsed back into a TypeProto.
    """
    serialized_inputs = [type_proto.SerializeToString() for type_proto in input_types]
    serialized_attributes = [attr.SerializeToString() for attr in attributes]
    serialized_outputs = C.infer_function_output_types(
        function.SerializeToString(),
        serialized_inputs,
        serialized_attributes,
    )

    output_types = []
    for blob in serialized_outputs:
        parsed = onnx.TypeProto()
        parsed.ParseFromString(blob)
        output_types.append(parsed)
    return output_types
# Re-export the extension module's InferenceError under this module's
# namespace so callers can refer to it without importing the extension module.
InferenceError = C.InferenceError