Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
from __future__ import annotations | |
import warnings | |
__all__ = [ | |
"registry", | |
] | |
import typing | |
from typing import Any, Collection, Optional, Protocol, TypeVar | |
import google.protobuf.json_format | |
import google.protobuf.message | |
import google.protobuf.text_format | |
import onnx | |
# Type variable bound to protobuf messages; serializer signatures use it so
# that the proto passed in and the proto returned share the same type.
_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)

# Text encoding applied when serialized text is converted to and from bytes.
_ENCODING = "utf-8"
class ProtoSerializer(Protocol):
    """Structural interface for converting in-memory protobuf messages to and from a serialized form.

    An implementation declares the format it handles together with the file
    extensions associated with that format, and provides the
    ``serialize_proto`` / ``deserialize_proto`` pair.
    """

    # Name of the format handled by the serializer, e.g. "protobuf".
    supported_format: str

    # File extensions associated with the format,
    # e.g. frozenset({".onnx", ".pb"}). Every entry must include the leading dot.
    file_extensions: Collection[str]

    # NOTE: The methods are deliberately named serialize_proto/deserialize_proto
    # rather than the more generic serialize/deserialize. This leaves room for
    # future protocols that serialize/deserialize the ONNX in-memory IR, so that
    # a single class can implement both protocols.

    def serialize_proto(self, proto: _Proto) -> Any:
        """Serialize an in-memory proto to a serialized data type."""
        ...

    def deserialize_proto(self, serialized: Any, proto: _Proto) -> _Proto:
        """Parse a serialized data type into an in-memory proto."""
        ...
class _Registry: | |
def __init__(self) -> None: | |
self._serializers: dict[str, ProtoSerializer] = {} | |
# A mapping from file extension to format | |
self._extension_to_format: dict[str, str] = {} | |
def register(self, serializer: ProtoSerializer) -> None: | |
self._serializers[serializer.supported_format] = serializer | |
self._extension_to_format.update( | |
{ext: serializer.supported_format for ext in serializer.file_extensions} | |
) | |
def get(self, fmt: str) -> ProtoSerializer: | |
"""Get a serializer for a format. | |
Args: | |
fmt: The format to get a serializer for. | |
Returns: | |
ProtoSerializer: The serializer for the format. | |
Raises: | |
ValueError: If the format is not supported. | |
""" | |
try: | |
return self._serializers[fmt] | |
except KeyError: | |
raise ValueError( | |
f"Unsupported format: '{fmt}'. Supported formats are: {self._serializers.keys()}" | |
) from None | |
def get_format_from_file_extension(self, file_extension: str) -> str | None: | |
"""Get the corresponding format from a file extension. | |
Args: | |
file_extension: The file extension to get a format for. | |
Returns: | |
The format for the file extension, or None if not found. | |
""" | |
return self._extension_to_format.get(file_extension) | |
class _ProtobufSerializer(ProtoSerializer):
    """Serializer for the binary protobuf representation."""

    supported_format = "protobuf"
    file_extensions = frozenset({".onnx", ".pb"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Serialize a proto to bytes via its ``SerializeToString`` method.

        Args:
            proto: The message to serialize.

        Returns:
            The serialized bytes.

        Raises:
            TypeError: If ``proto`` has no callable ``SerializeToString``.
            ValueError: If serialization fails. When the proto's byte size is at
                least ``onnx.checker.MAXIMUM_PROTOBUF``, the error is replaced by
                one that points at external data storage.
        """
        if not hasattr(proto, "SerializeToString") or not callable(
            proto.SerializeToString
        ):
            raise TypeError(
                f"No SerializeToString method is detected.\ntype is {type(proto)}"
            )
        try:
            return proto.SerializeToString()  # type: ignore
        except ValueError as e:
            # Only reinterpret the failure when the size limit explains it;
            # any other ValueError propagates unchanged.
            if proto.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
                raise
            raise ValueError(
                "The proto size is larger than the 2 GB limit. "
                "Please use save_as_external_data to save tensors separately from the model file."
            ) from e

    def deserialize_proto(self, serialized: bytes, proto: _Proto) -> _Proto:
        """Parse serialized bytes into ``proto`` and return it.

        Args:
            serialized: The bytes to parse.
            proto: The message that receives the parsed content.

        Returns:
            The same ``proto`` object, populated from ``serialized``.

        Raises:
            TypeError: If ``serialized`` is not bytes.
            google.protobuf.message.DecodeError: If the parser reports a consumed
                byte count that differs from the input length.
        """
        if not isinstance(serialized, bytes):
            raise TypeError(
                f"Parameter 'serialized' must be bytes, but got type: {type(serialized)}"
            )
        # ParseFromString may return None; the byte count is only checked when
        # one is actually reported.
        consumed = typing.cast(Optional[int], proto.ParseFromString(serialized))
        if consumed is None or consumed == len(serialized):
            return proto
        raise google.protobuf.message.DecodeError(
            f"Protobuf decoding consumed too few bytes: {consumed} out of {len(serialized)}"
        )
class _TextProtoSerializer(ProtoSerializer):
    """Serializer for the protobuf text format."""

    supported_format = "textproto"
    file_extensions = frozenset({".textproto", ".prototxt", ".pbtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` in protobuf text format and return it as encoded bytes."""
        return google.protobuf.text_format.MessageToString(proto).encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse protobuf text format into ``proto``.

        Args:
            serialized: The text to parse; bytes are decoded first.
            proto: The message that receives the parsed content.

        Returns:
            The result of ``google.protobuf.text_format.Parse``.

        Raises:
            TypeError: If ``serialized`` is neither bytes nor str.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.text_format.Parse(text, proto)
class _JsonSerializer(ProtoSerializer):
    """Serializer for the protobuf JSON mapping."""

    supported_format = "json"
    file_extensions = frozenset({".json", ".onnxjson"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` as JSON, keeping the proto field names, and return encoded bytes."""
        as_json = google.protobuf.json_format.MessageToJson(
            proto,
            preserving_proto_field_name=True,
        )
        return as_json.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse JSON into ``proto``.

        Args:
            serialized: The JSON document to parse; bytes are decoded first.
            proto: The message that receives the parsed content.

        Returns:
            The result of ``google.protobuf.json_format.Parse``.

        Raises:
            TypeError: If ``serialized`` is neither bytes nor str.
        """
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        elif isinstance(serialized, str):
            text = serialized
        else:
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        return google.protobuf.json_format.Parse(text, proto)
class _TextualSerializer(ProtoSerializer):
    """Serialize and deserialize the ONNX textual representation."""

    supported_format = "onnxtxt"
    file_extensions = frozenset({".onnxtxt"})

    def serialize_proto(self, proto: _Proto) -> bytes:
        """Render ``proto`` with ``onnx.printer.to_text`` and return it as encoded bytes."""
        text = onnx.printer.to_text(proto)  # type: ignore[arg-type]
        return text.encode(_ENCODING)

    def deserialize_proto(self, serialized: bytes | str, proto: _Proto) -> _Proto:
        """Parse the ONNX textual representation into ``proto``.

        The ONNX parser builds a new message, so the parsed content is copied
        into ``proto`` and ``proto`` itself is returned. This matches the other
        serializers in this module, which populate the message they are given.

        Args:
            serialized: The text to parse; bytes are decoded first.
            proto: The message that receives the parsed content. Its type
                selects the parser: ModelProto, GraphProto, FunctionProto or
                NodeProto.

        Returns:
            The same ``proto`` object, populated from ``serialized``.

        Raises:
            TypeError: If ``serialized`` is neither bytes nor str.
            ValueError: If the type of ``proto`` is not one of the supported ones.
        """
        warnings.warn(
            "The onnxtxt format is experimental. Please report any errors to the ONNX GitHub repository.",
            stacklevel=2,
        )
        if not isinstance(serialized, (bytes, str)):
            raise TypeError(
                f"Parameter 'serialized' must be bytes or str, but got type: {type(serialized)}"
            )
        if isinstance(serialized, bytes):
            text = serialized.decode(_ENCODING)
        else:
            text = serialized

        if isinstance(proto, onnx.ModelProto):
            parsed = onnx.parser.parse_model(text)
        elif isinstance(proto, onnx.GraphProto):
            parsed = onnx.parser.parse_graph(text)
        elif isinstance(proto, onnx.FunctionProto):
            parsed = onnx.parser.parse_function(text)
        elif isinstance(proto, onnx.NodeProto):
            parsed = onnx.parser.parse_node(text)
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")

        # Copy the freshly parsed message into the caller's proto so callers
        # that keep a reference to `proto` see the result, not an empty message.
        proto.CopyFrom(parsed)
        return proto
# Module-level registry, pre-populated with the built-in serializers.
registry = _Registry()
for _serializer_type in (
    _ProtobufSerializer,
    _TextProtoSerializer,
    _JsonSerializer,
    _TextualSerializer,
):
    registry.register(_serializer_type())
# Drop the loop variable so the module namespace only gains `registry`.
del _serializer_type