Spaces:
Sleeping
Sleeping
File size: 8,069 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import warnings
# Public API of this module: only the default serializer registry is exported.
__all__ = [
    "registry",
]
import typing
from typing import Any, Collection, Optional, Protocol, TypeVar
import google.protobuf.json_format
import google.protobuf.message
import google.protobuf.text_format
import onnx
# Type variable for protobuf messages. It is bound to Message so that
# deserialize_proto is typed as returning the same message type it was given.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)
# Encoding used by the text-based serializers (textproto, JSON, onnxtxt) when
# converting between str and bytes.
_ENCODING = "utf-8"
class ProtoSerializer(Protocol):
    """Interface for converting in-memory Protocol Buffers messages to and from a serialized form.

    An implementation declares the format it handles together with the file
    extensions associated with that format, and provides a matching pair of
    conversion methods.
    """

    # Name of the format handled by the serializer, e.g. "protobuf".
    supported_format: str
    # File extensions associated with the format, e.g. frozenset({".onnx", ".pb"}).
    # Every extension must include the leading dot.
    file_extensions: Collection[str]

    # NOTE: The methods are deliberately named serialize_proto/deserialize_proto
    # instead of the more generic serialize/deserialize. This leaves room for
    # future protocols that serialize/deserialize the ONNX in-memory IR, so that
    # one class can implement both protocols.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Serialize an in-memory proto to a serialized data type."""
        ...

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse a serialized data type into an in-memory proto."""
        ...
class _Registry:
    """Registry that maps format names and file extensions to serializers."""

    def __init__(self) -> None:
        # A mapping from format name to the serializer handling that format
        self._serializers: dict[str, ProtoSerializer] = {}
        # A mapping from file extension to format
        self._extension_to_format: dict[str, str] = {}

    def register(self, serializer: ProtoSerializer) -> None:
        """Register a serializer under its format and its file extensions.

        A later registration replaces any earlier one for the same format
        name or the same file extension.

        Args:
            serializer: The serializer to register. Its ``supported_format``
                and ``file_extensions`` attributes are used as lookup keys.
        """
        fmt = serializer.supported_format
        self._serializers[fmt] = serializer
        self._extension_to_format.update(
            dict.fromkeys(serializer.file_extensions, fmt)
        )

    def get(self, fmt: str) -> ProtoSerializer:
        """Get a serializer for a format.

        Args:
            fmt: The format to get a serializer for.

        Returns:
            ProtoSerializer: The serializer for the format.

        Raises:
            ValueError: If the format is not supported.
        """
        try:
            return self._serializers[fmt]
        except KeyError:
            # Render the format names as a plain list; interpolating the keys
            # view directly would show up as "dict_keys([...])" in the message.
            supported = list(self._serializers)
            raise ValueError(
                f"Unsupported format: '{fmt}'. Supported formats are: {supported}"
            ) from None

    def get_format_from_file_extension(self, file_extension: str) -> str | None:
        """Get the corresponding format from a file extension.

        Args:
            file_extension: The file extension to get a format for. The lookup
                is an exact match against the registered extensions.

        Returns:
            The format for the file extension, or None if not found.
        """
        return self._extension_to_format.get(file_extension)
class _ProtobufSerializer(ProtoSerializer):
    """Serializer for the binary protobuf wire format."""

    supported_format = "protobuf"
    file_extensions = frozenset({".onnx", ".pb"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Return the bytes produced by the proto's ``SerializeToString``.

        Raises:
            TypeError: If ``proto`` has no callable ``SerializeToString``.
            ValueError: If serialization fails. When the proto's byte size is
                at or above ``onnx.checker.MAXIMUM_PROTOBUF`` the error is
                replaced by one pointing at the 2 GB limit.
        """
        serialize = getattr(proto, "SerializeToString", None)
        if not callable(serialize):
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            return serialize()  # type: ignore
        except ValueError as err:
            # Only translate the error when the size limit is the likely
            # cause; any other ValueError propagates untouched.
            if proto.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from err

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse ``serialized`` into ``proto`` and return ``proto``.

        Raises:
            TypeError: If ``serialized`` is not a ``bytes`` instance.
            google.protobuf.message.DecodeError: If ``ParseFromString``
                reports a consumed byte count different from the input length.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        # ParseFromString may return None; in that case no length check is made.
        consumed = typing.cast(Optional[int], proto.ParseFromString(serialized))
        total = len(serialized)
        if consumed is None or consumed == total:
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {total}"
        )
class _TextProtoSerializer(ProtoSerializer):
    """Serializer for the protobuf text format."""

    supported_format = "textproto"
    file_extensions = frozenset({".textproto", ".prototxt", ".pbtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as text-format protobuf and encode it to bytes."""
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse text-format protobuf into ``proto``.

        Args:
            serialized: The text to parse; bytes are decoded with ``_ENCODING`` first.
            proto: The message handed to ``text_format.Parse``.

        Returns:
            The result of ``google.protobuf.text_format.Parse``.

        Raises:
            TypeError: If ``serialized`` is neither bytes nor str.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(text, proto)
class _JsonSerializer(ProtoSerializer):
    """Serializer for the protobuf JSON mapping."""

    supported_format = "json"
    file_extensions = frozenset({".json", ".onnxjson"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as JSON, keeping proto field names, and encode it to bytes."""
        as_json = google.protobuf.json_format.MessageToJson(
            proto, preserving_proto_field_name=True
        )
        return as_json.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON into ``proto``.

        Args:
            serialized: The JSON document; bytes are decoded with ``_ENCODING`` first.
            proto: The message handed to ``json_format.Parse``.

        Returns:
            The result of ``google.protobuf.json_format.Parse``.

        Raises:
            TypeError: If ``serialized`` is neither bytes nor str.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.json_format.Parse(text, proto)
class _TextualSerializer(ProtoSerializer):
    """Serializer for the ONNX textual representation."""

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``onnx.printer.to_text`` and encode it to bytes."""
        return onnx.printer.to_text(proto).encode(_ENCODING)  # type: ignore[arg-type]

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse ONNX text using the parser that matches the type of ``proto``.

        A warning about the experimental status of the format is emitted on
        every call, before the input is validated.

        Args:
            serialized: The text to parse; bytes are decoded with ``_ENCODING`` first.
            proto: Selects the parser: model, graph, function or node.

        Returns:
            The message produced by the selected ``onnx.parser`` function.

        Raises:
            TypeError: If ``serialized`` is neither bytes nor str.
            ValueError: If ``proto`` is not one of the supported message types.
        """
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        text = (
            serialized.decode(_ENCODING)
            if isinstance(serialized, bytes)
            else serialized
        )

        # Checked in order; the first matching message type picks the parser.
        parsers = (
            (onnx.ModelProto, onnx.parser.parse_model),
            (onnx.GraphProto, onnx.parser.parse_graph),
            (onnx.FunctionProto, onnx.parser.parse_function),
            (onnx.NodeProto, onnx.parser.parse_node),
        )
        for proto_type, parse in parsers:
            if isinstance(proto, proto_type):
                # NOTE(review): the parser's result is returned as-is and
                # `proto` itself is never written to here - confirm callers
                # rely on the return value only.
                return parse(text)  # type: ignore[return-value]
        raise ValueError(f"Unsupported proto type: {type(proto)}")
# Default registry, populated with one instance of each built-in serializer.
# The order below is the registration order.
registry = _Registry()
for _serializer_cls in (
    _ProtobufSerializer,
    _TextProtoSerializer,
    _JsonSerializer,
    _TextualSerializer,
):
    registry.register(_serializer_cls())
# Do not leave the loop variable behind in the module namespace.
del _serializer_cls
|