import typing
import warnings
import sys
from copy import deepcopy

from dataclasses import MISSING, is_dataclass, fields as dc_fields
from datetime import datetime
from decimal import Decimal
from uuid import UUID
from enum import Enum

from typing_inspect import is_union_type

from marshmallow import fields, Schema, post_load
from marshmallow.exceptions import ValidationError

from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
                                   _ExtendedEncoder, _user_overrides_or_exts)
from dataclasses_json.utils import (_is_collection, _is_optional,
                                    _issubclass_safe, _timestamp_to_dt_aware,
                                    _is_new_type, _get_type_origin,
                                    _handle_undefined_parameters_safe,
                                    CatchAllVar)


class _TimestampField(fields.Field):
    """Marshmallow field that encodes ``datetime`` values as POSIX
    timestamps (floats) and decodes them back to aware datetimes."""

    def _serialize(self, value, attr, obj, **kwargs):
        if value is not None:
            return value.timestamp()
        else:
            if not self.required:
                return None
            else:
                raise ValidationError(self.default_error_messages["required"])

    def _deserialize(self, value, attr, data, **kwargs):
        if value is not None:
            return _timestamp_to_dt_aware(value)
        else:
            if not self.required:
                return None
            else:
                raise ValidationError(self.default_error_messages["required"])


class _IsoField(fields.Field):
    """Marshmallow field that encodes ``datetime`` values as ISO 8601
    strings and decodes them with ``datetime.fromisoformat``."""

    def _serialize(self, value, attr, obj, **kwargs):
        if value is not None:
            return value.isoformat()
        else:
            if not self.required:
                return None
            else:
                raise ValidationError(self.default_error_messages["required"])

    def _deserialize(self, value, attr, data, **kwargs):
        if value is not None:
            return datetime.fromisoformat(value)
        else:
            if not self.required:
                return None
            else:
                raise ValidationError(self.default_error_messages["required"])
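

# A minimal round-trip sketch (not part of the library) contrasting the two
# datetime encodings above; the attr/obj arguments are unused placeholders:
#
#     from datetime import timezone
#     dt = datetime(2024, 1, 1, tzinfo=timezone.utc)
#     _TimestampField()._serialize(dt, 'at', None)  # -> 1704067200.0
#     _IsoField()._serialize(dt, 'at', None)        # -> '2024-01-01T00:00:00+00:00'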


class _UnionField(fields.Field):
    """Dispatch (de)serialization across the member types of a ``Union``.

    ``desc`` maps each candidate type to the marshmallow field or schema
    built for it; dataclass members are tagged with a ``__type`` key on
    serialization so they can be told apart again on deserialization.
    """

    def __init__(self, desc, cls, field, *args, **kwargs):
        self.desc = desc
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if self.allow_none and value is None:
            return None
        for type_, schema_ in self.desc.items():
            if _issubclass_safe(type(value), type_):
                if is_dataclass(value):
                    res = schema_._serialize(value, attr, obj, **kwargs)
                    res['__type'] = str(type_.__name__)
                    return res
                break
            elif isinstance(value, _get_type_origin(type_)):
                return schema_._serialize(value, attr, obj, **kwargs)
        else:
            warnings.warn(
                f'The type "{type(value).__name__}" (value: "{value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value cannot be serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        tmp_value = deepcopy(value)
        if isinstance(tmp_value, dict) and '__type' in tmp_value:
            dc_name = tmp_value['__type']
            for type_, schema_ in self.desc.items():
                if is_dataclass(type_) and type_.__name__ == dc_name:
                    del tmp_value['__type']
                    return schema_._deserialize(tmp_value, attr, data,
                                                **kwargs)
        elif isinstance(tmp_value, dict):
            warnings.warn(
                f'Attempting to deserialize "dict" (value: "{tmp_value}") '
                f'that does not have a "__type" type specifier field into '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Deserialization may fail, or deserialization to the wrong '
                f'type may occur.')
            return super()._deserialize(tmp_value, attr, data, **kwargs)
        else:
            for type_, schema_ in self.desc.items():
                if isinstance(tmp_value, _get_type_origin(type_)):
                    return schema_._deserialize(tmp_value, attr, data,
                                                **kwargs)
            warnings.warn(
                f'The type "{type(tmp_value).__name__}" '
                f'(value: "{tmp_value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, '
                f'field: {self.field.name}). '
                f'Value cannot be deserialized properly.')
        return super()._deserialize(tmp_value, attr, data, **kwargs)
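

# Behavior sketch (hypothetical ``Cat``/``Dog`` dataclass_json classes): a
# field annotated ``Union[Cat, Dog]`` serializes a ``Cat`` instance as
# ``{'name': ..., '__type': 'Cat'}``, and ``_deserialize`` uses that tag to
# select the ``Cat`` schema on the way back. An untagged dict triggers the
# warning above and falls back to the default field behavior.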


class _TupleVarLen(fields.List):
    """Variable-length homogeneous tuples, i.e. ``Tuple[X, ...]``,
    deserialized via ``fields.List`` and cast back to ``tuple``."""

    def _deserialize(self, value, attr, data, **kwargs):
        optional_list = super()._deserialize(value, attr, data, **kwargs)
        return None if optional_list is None else tuple(optional_list)
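

# Sketch: for an annotation like ``Tuple[int, ...]``, ``build_type`` (below)
# produces roughly ``_TupleVarLen(fields.Int())``, so ``[1, 2, 3]`` loads as
# ``(1, 2, 3)``; fixed-length tuples go through ``fields.Tuple`` instead.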


TYPES = {
    typing.Mapping: fields.Mapping,
    typing.MutableMapping: fields.Mapping,
    typing.List: fields.List,
    typing.Dict: fields.Dict,
    typing.Tuple: fields.Tuple,
    typing.Callable: fields.Function,
    typing.Any: fields.Raw,
    dict: fields.Dict,
    list: fields.List,
    tuple: fields.Tuple,
    str: fields.Str,
    int: fields.Int,
    float: fields.Float,
    bool: fields.Bool,
    datetime: _TimestampField,
    UUID: fields.UUID,
    Decimal: fields.Decimal,
    CatchAllVar: fields.Dict,
}
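
# Lookup sketch: ``build_type`` (below) resolves ``typing.List[int]`` by
# mapping its origin (``list``) through TYPES to ``fields.List`` and its
# ``int`` argument to ``fields.Int``, yielding roughly
# ``fields.List(fields.Int())``.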


A = typing.TypeVar('A')
JsonData = typing.Union[str, bytes, bytearray]
TEncoded = typing.Dict[str, typing.Any]
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]


if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
    class SchemaF(Schema, typing.Generic[A]):
        """Lift Schema into a type constructor"""

        def __init__(self, *args, **kwargs):
            """
            Raises NotImplementedError because this class is a typing
            helper only and should never be instantiated or subclassed.
            """
            super().__init__(*args, **kwargs)
            raise NotImplementedError()

        @typing.overload
        def dump(self, obj: typing.List[A],
                 many: typing.Optional[bool] = None) -> typing.List[TEncoded]:
            pass

        @typing.overload
        def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
            pass

        def dump(self, obj: TOneOrMulti,
                 many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
            pass

        @typing.overload
        def dumps(self, obj: typing.List[A],
                  many: typing.Optional[bool] = None, *args, **kwargs) -> str:
            pass

        @typing.overload
        def dumps(self, obj: A, many: typing.Optional[bool] = None,
                  *args, **kwargs) -> str:
            pass

        def dumps(self, obj: TOneOrMulti,
                  many: typing.Optional[bool] = None, *args, **kwargs) -> str:
            pass

        @typing.overload
        def load(self, data: typing.List[TEncoded],
                 many: bool = True, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> typing.List[A]:
            pass

        @typing.overload
        def load(self, data: TEncoded,
                 many: None = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> A:
            pass

        def load(self, data: TOneOrMultiEncoded,
                 many: typing.Optional[bool] = None,
                 partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> TOneOrMulti:
            pass

        @typing.overload
        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = True,
                  partial: typing.Optional[bool] = None,
                  unknown: typing.Optional[str] = None,
                  **kwargs) -> typing.List[A]:
            pass

        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = None,
                  partial: typing.Optional[bool] = None,
                  unknown: typing.Optional[str] = None,
                  **kwargs) -> TOneOrMulti:
            pass

    SchemaType = SchemaF[A]
else:
    SchemaType = Schema
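
# SchemaType is purely a typing construct: under a checker it lets the
# ``load``/``dump`` overloads above carry precise element types, while at
# runtime on Python < 3.7 it degrades to plain ``Schema``.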


def build_type(type_, options, mixin, field, cls):
    def inner(type_, options):
        # Unwrap typing.NewType chains down to their underlying supertype.
        while _is_new_type(type_):
            type_ = type_.__supertype__

        if is_dataclass(type_):
            if _issubclass_safe(type_, mixin):
                options['field_many'] = bool(
                    _is_supported_generic(field.type) and
                    _is_collection(field.type))
                return fields.Nested(type_.schema(), **options)
            else:
                warnings.warn(f"Nested dataclass field {field.name} of type "
                              f"{field.type} detected in "
                              f"{cls.__name__} that is not an instance of "
                              f"dataclass_json. Did you mean to recursively "
                              f"serialize this field? If so, make sure to "
                              f"augment {type_} with either the "
                              f"`dataclass_json` decorator or mixin.")
                return fields.Field(**options)

        origin = getattr(type_, '__origin__', type_)
        args = [inner(a, {}) for a in getattr(type_, '__args__', [])
                if a is not type(None)]

        if type_ == Ellipsis:
            return type_

        if _is_optional(type_):
            options["allow_none"] = True
        if origin is tuple:
            if len(args) == 2 and args[1] == Ellipsis:
                return _TupleVarLen(args[0], **options)
            else:
                return fields.Tuple(args, **options)
        if origin in TYPES:
            return TYPES[origin](*args, **options)

        if _issubclass_safe(origin, Enum):
            return fields.Enum(enum=origin, by_value=True, *args, **options)

        if is_union_type(type_):
            union_types = [a for a in getattr(type_, '__args__', [])
                           if a is not type(None)]
            union_desc = dict(zip(union_types, args))
            return _UnionField(union_desc, cls, field, **options)

        warnings.warn(
            f"Unknown type {type_} at {cls.__name__}.{field.name}: "
            f"{field.type}. "
            f"It's advised to pass the correct marshmallow type to "
            f"`mm_field`.")
        return fields.Field(**options)

    return inner(type_, options)
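

# Worked sketch of the resolution above: ``schema()`` (below) first unwraps
# ``Optional[List[int]]`` to ``List[int]`` with ``allow_none=True`` in
# ``options``; ``inner`` then maps the ``list`` origin through TYPES and the
# ``int`` argument to ``fields.Int``, yielding roughly
# ``fields.List(fields.Int(), allow_none=True)``. A multi-member union
# instead falls through to ``_UnionField``.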


def schema(cls, mixin, infer_missing):
    schema = {}
    overrides = _user_overrides_or_exts(cls)
    for field in dc_fields(cls):
        metadata = overrides[field.name]
        if metadata.mm_field is not None:
            # An explicit marshmallow field in the metadata wins outright.
            schema[field.name] = metadata.mm_field
        else:
            type_ = field.type
            options: typing.Dict[str, typing.Any] = {}
            missing_key = 'missing' if infer_missing else 'default'
            if field.default is not MISSING:
                options[missing_key] = field.default
            elif field.default_factory is not MISSING:
                options[missing_key] = field.default_factory()
            else:
                options['required'] = True

            if options.get(missing_key, ...) is None:
                options['allow_none'] = True

            if _is_optional(type_):
                options.setdefault(missing_key, None)
                options['allow_none'] = True
                if len(type_.__args__) == 2:
                    # Optional[X] is unwrapped to X only when there is a
                    # single non-None member; larger unions keep their
                    # Union type and are handled by _UnionField.
                    type_ = [tp for tp in type_.__args__
                             if tp is not type(None)][0]

            if metadata.letter_case is not None:
                options['data_key'] = metadata.letter_case(field.name)

            t = build_type(type_, options, mixin, field, cls)
            if field.metadata.get('dataclasses_json', {}).get('decoder'):
                # A user-supplied decoder takes over deserialization, so
                # the marshmallow field becomes a pass-through here.
                t._deserialize = lambda v, *_a, **_kw: v

            # CatchAll fields are filled in by the undefined-parameter
            # handling rather than by a dedicated marshmallow field.
            if field.type != typing.Optional[CatchAllVar]:
                schema[field.name] = t

    return schema
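

# Mapping sketch (hypothetical dataclass; ``SomeMixin`` stands in for the
# dataclass_json mixin argument):
#
#     @dataclass
#     class Person:
#         name: str
#         age: typing.Optional[int] = None
#
#     schema(Person, SomeMixin, infer_missing=False)
#     # -> {'name': fields.Str(required=True),
#     #     'age': fields.Int(default=None, allow_none=True)}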


def build_schema(cls: typing.Type[A],
                 mixin,
                 infer_missing,
                 partial) -> typing.Type["SchemaType[A]"]:
    Meta = type('Meta',
                (),
                {'fields': tuple(field.name for field in dc_fields(cls)
                                 if field.name != 'dataclass_json_config'
                                 and field.type !=
                                 typing.Optional[CatchAllVar])})

    @post_load
    def make_instance(self, kvs, **kwargs):
        return _decode_dataclass(cls, kvs, partial)

    def dumps(self, *args, **kwargs):
        if 'cls' not in kwargs:
            kwargs['cls'] = _ExtendedEncoder
        return Schema.dumps(self, *args, **kwargs)

    def dump(self, obj, *, many=None):
        many = self.many if many is None else bool(many)
        dumped = Schema.dump(self, obj, many=many)
        # Meta.fields excludes the CatchAll field, so merge any caught
        # undefined parameters back into the dumped output here.
        if many:
            for i, _obj in enumerate(obj):
                dumped[i].update(
                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
                                                      usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
                                                            usage="dump"))
        return dumped

    schema_ = schema(cls, mixin, infer_missing)
    DataClassSchema: typing.Type["SchemaType[A]"] = type(
        f'{cls.__name__.capitalize()}Schema',
        (Schema,),
        {'Meta': Meta,
         f'make_{cls.__name__.lower()}': make_instance,
         'dumps': dumps,
         'dump': dump,
         **schema_})

    return DataClassSchema
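

# Usage sketch (hypothetical ``Person`` dataclass using the dataclass_json
# mixin; this is roughly the path ``DataClassJsonMixin.schema()`` takes):
#
#     PersonSchema = build_schema(Person, DataClassJsonMixin,
#                                 infer_missing=False, partial=False)
#     PersonSchema().load({'name': 'Ada'})     # -> Person(name='Ada')
#     PersonSchema().dump(Person(name='Ada'))  # -> {'name': 'Ada'}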