Spaces:
Runtime error
Runtime error
"""This module includes classes and functions designed specifically for use with the mypy plugin.""" | |
from __future__ import annotations | |
import sys | |
from configparser import ConfigParser | |
from typing import Any, Callable, Iterator | |
from mypy.errorcodes import ErrorCode | |
from mypy.expandtype import expand_type, expand_type_by_instance | |
from mypy.nodes import ( | |
ARG_NAMED, | |
ARG_NAMED_OPT, | |
ARG_OPT, | |
ARG_POS, | |
ARG_STAR2, | |
MDEF, | |
Argument, | |
AssignmentStmt, | |
Block, | |
CallExpr, | |
ClassDef, | |
Context, | |
Decorator, | |
DictExpr, | |
EllipsisExpr, | |
Expression, | |
FuncDef, | |
IfStmt, | |
JsonDict, | |
MemberExpr, | |
NameExpr, | |
PassStmt, | |
PlaceholderNode, | |
RefExpr, | |
Statement, | |
StrExpr, | |
SymbolTableNode, | |
TempNode, | |
TypeAlias, | |
TypeInfo, | |
Var, | |
) | |
from mypy.options import Options | |
from mypy.plugin import ( | |
CheckerPluginInterface, | |
ClassDefContext, | |
FunctionContext, | |
MethodContext, | |
Plugin, | |
ReportConfigContext, | |
SemanticAnalyzerPluginInterface, | |
) | |
from mypy.plugins import dataclasses | |
from mypy.plugins.common import ( | |
deserialize_and_fixup_type, | |
) | |
from mypy.semanal import set_callable_name | |
from mypy.server.trigger import make_wildcard_trigger | |
from mypy.state import state | |
from mypy.typeops import map_type_from_supertype | |
from mypy.types import ( | |
AnyType, | |
CallableType, | |
Instance, | |
NoneType, | |
Overloaded, | |
Type, | |
TypeOfAny, | |
TypeType, | |
TypeVarType, | |
UnionType, | |
get_proper_type, | |
) | |
from mypy.typevars import fill_typevars | |
from mypy.util import get_unique_redefinition_name | |
from mypy.version import __version__ as mypy_version | |
from pydantic._internal import _fields | |
from pydantic.version import parse_mypy_version | |
try: | |
from mypy.types import TypeVarDef # type: ignore[attr-defined] | |
except ImportError: # pragma: no cover | |
# Backward-compatible with TypeVarDef from Mypy 0.930. | |
from mypy.types import TypeVarType as TypeVarDef | |
CONFIGFILE_KEY = 'pydantic-mypy' | |
METADATA_KEY = 'pydantic-mypy-metadata' | |
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' | |
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings' | |
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass' | |
FIELD_FULLNAME = 'pydantic.fields.Field' | |
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' | |
MODEL_VALIDATOR_FULLNAME = 'pydantic.functional_validators.model_validator' | |
DECORATOR_FULLNAMES = { | |
'pydantic.functional_validators.field_validator', | |
'pydantic.functional_validators.model_validator', | |
'pydantic.functional_serializers.serializer', | |
'pydantic.functional_serializers.model_serializer', | |
'pydantic.deprecated.class_validators.validator', | |
'pydantic.deprecated.class_validators.root_validator', | |
} | |
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) | |
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' | |
# Increment version if plugin changes and mypy caches should be invalidated | |
__version__ = 2 | |
def plugin(version: str) -> type[Plugin]: | |
"""`version` is the mypy version string. | |
We might want to use this to print a warning if the mypy version being used is | |
newer, or especially older, than we expect (or need). | |
Args: | |
version: The mypy version string. | |
Return: | |
The Pydantic mypy plugin type. | |
""" | |
return PydanticPlugin | |
class PydanticPlugin(Plugin): | |
"""The Pydantic mypy plugin.""" | |
def __init__(self, options: Options) -> None: | |
self.plugin_config = PydanticPluginConfig(options) | |
self._plugin_data = self.plugin_config.to_data() | |
super().__init__(options) | |
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None: | |
"""Update Pydantic model class.""" | |
sym = self.lookup_fully_qualified(fullname) | |
if sym and isinstance(sym.node, TypeInfo): # pragma: no branch | |
# No branching may occur if the mypy cache has not been cleared | |
if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro): | |
return self._pydantic_model_class_maker_callback | |
return None | |
def get_metaclass_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: | |
"""Update Pydantic `ModelMetaclass` definition.""" | |
if fullname == MODEL_METACLASS_FULLNAME: | |
return self._pydantic_model_metaclass_marker_callback | |
return None | |
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: | |
"""Adjust the return type of the `Field` function.""" | |
sym = self.lookup_fully_qualified(fullname) | |
if sym and sym.fullname == FIELD_FULLNAME: | |
return self._pydantic_field_callback | |
return None | |
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None: | |
"""Adjust return type of `from_orm` method call.""" | |
if fullname.endswith('.from_orm'): | |
return from_attributes_callback | |
return None | |
def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: | |
"""Mark pydantic.dataclasses as dataclass. | |
Mypy version 1.1.1 added support for `@dataclass_transform` decorator. | |
""" | |
if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1): | |
return dataclasses.dataclass_class_maker_callback # type: ignore[return-value] | |
return None | |
def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]: | |
"""Return all plugin config data. | |
Used by mypy to determine if cache needs to be discarded. | |
""" | |
return self._plugin_data | |
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool: | |
transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config) | |
return transformer.transform() | |
def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None: | |
"""Reset dataclass_transform_spec attribute of ModelMetaclass. | |
Let the plugin handle it. This behavior can be disabled | |
if 'debug_dataclass_transform' is set to True', for testing purposes. | |
""" | |
if self.plugin_config.debug_dataclass_transform: | |
return | |
info_metaclass = ctx.cls.info.declared_metaclass | |
assert info_metaclass, "callback not passed from 'get_metaclass_hook'" | |
if getattr(info_metaclass.type, 'dataclass_transform_spec', None): | |
info_metaclass.type.dataclass_transform_spec = None | |
def _pydantic_field_callback(self, ctx: FunctionContext) -> Type: | |
"""Extract the type of the `default` argument from the Field function, and use it as the return type. | |
In particular: | |
* Check whether the default and default_factory argument is specified. | |
* Output an error if both are specified. | |
* Retrieve the type of the argument which is specified, and use it as return type for the function. | |
""" | |
default_any_type = ctx.default_return_type | |
assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()' | |
assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()' | |
default_args = ctx.args[0] | |
default_factory_args = ctx.args[1] | |
if default_args and default_factory_args: | |
error_default_and_default_factory_specified(ctx.api, ctx.context) | |
return default_any_type | |
if default_args: | |
default_type = ctx.arg_types[0][0] | |
default_arg = default_args[0] | |
# Fallback to default Any type if the field is required | |
if not isinstance(default_arg, EllipsisExpr): | |
return default_type | |
elif default_factory_args: | |
default_factory_type = ctx.arg_types[1][0] | |
# Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter | |
# Pydantic calls the default factory without any argument, so we retrieve the first item | |
if isinstance(default_factory_type, Overloaded): | |
default_factory_type = default_factory_type.items[0] | |
if isinstance(default_factory_type, CallableType): | |
ret_type = default_factory_type.ret_type | |
# mypy doesn't think `ret_type` has `args`, you'd think mypy should know, | |
# add this check in case it varies by version | |
args = getattr(ret_type, 'args', None) | |
if args: | |
if all(isinstance(arg, TypeVarType) for arg in args): | |
# Looks like the default factory is a type like `list` or `dict`, replace all args with `Any` | |
ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined] | |
return ret_type | |
return default_any_type | |
class PydanticPluginConfig: | |
"""A Pydantic mypy plugin config holder. | |
Attributes: | |
init_forbid_extra: Whether to add a `**kwargs` at the end of the generated `__init__` signature. | |
init_typed: Whether to annotate fields in the generated `__init__`. | |
warn_required_dynamic_aliases: Whether to raise required dynamic aliases error. | |
debug_dataclass_transform: Whether to not reset `dataclass_transform_spec` attribute | |
of `ModelMetaclass` for testing purposes. | |
""" | |
__slots__ = ( | |
'init_forbid_extra', | |
'init_typed', | |
'warn_required_dynamic_aliases', | |
'debug_dataclass_transform', | |
) | |
init_forbid_extra: bool | |
init_typed: bool | |
warn_required_dynamic_aliases: bool | |
debug_dataclass_transform: bool # undocumented | |
def __init__(self, options: Options) -> None: | |
if options.config_file is None: # pragma: no cover | |
return | |
toml_config = parse_toml(options.config_file) | |
if toml_config is not None: | |
config = toml_config.get('tool', {}).get('pydantic-mypy', {}) | |
for key in self.__slots__: | |
setting = config.get(key, False) | |
if not isinstance(setting, bool): | |
raise ValueError(f'Configuration value must be a boolean for key: {key}') | |
setattr(self, key, setting) | |
else: | |
plugin_config = ConfigParser() | |
plugin_config.read(options.config_file) | |
for key in self.__slots__: | |
setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False) | |
setattr(self, key, setting) | |
def to_data(self) -> dict[str, Any]: | |
"""Returns a dict of config names to their values.""" | |
return {key: getattr(self, key) for key in self.__slots__} | |
def from_attributes_callback(ctx: MethodContext) -> Type: | |
"""Raise an error if from_attributes is not enabled.""" | |
model_type: Instance | |
ctx_type = ctx.type | |
if isinstance(ctx_type, TypeType): | |
ctx_type = ctx_type.item | |
if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance): | |
model_type = ctx_type.ret_type # called on the class | |
elif isinstance(ctx_type, Instance): | |
model_type = ctx_type # called on an instance (unusual, but still valid) | |
else: # pragma: no cover | |
detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})' | |
error_unexpected_behavior(detail, ctx.api, ctx.context) | |
return ctx.default_return_type | |
pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) | |
if pydantic_metadata is None: | |
return ctx.default_return_type | |
from_attributes = pydantic_metadata.get('config', {}).get('from_attributes') | |
if from_attributes is not True: | |
error_from_attributes(model_type.type.name, ctx.api, ctx.context) | |
return ctx.default_return_type | |
class PydanticModelField: | |
"""Based on mypy.plugins.dataclasses.DataclassAttribute.""" | |
def __init__( | |
self, | |
name: str, | |
alias: str | None, | |
has_dynamic_alias: bool, | |
has_default: bool, | |
line: int, | |
column: int, | |
type: Type | None, | |
info: TypeInfo, | |
): | |
self.name = name | |
self.alias = alias | |
self.has_dynamic_alias = has_dynamic_alias | |
self.has_default = has_default | |
self.line = line | |
self.column = column | |
self.type = type | |
self.info = info | |
def to_argument(self, current_info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument: | |
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument.""" | |
return Argument( | |
variable=self.to_var(current_info, use_alias), | |
type_annotation=self.expand_type(current_info) if typed else AnyType(TypeOfAny.explicit), | |
initializer=None, | |
kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED, | |
) | |
def expand_type(self, current_info: TypeInfo) -> Type | None: | |
"""Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type.""" | |
if self.type is not None and self.info.self_type is not None: | |
# In general, it is not safe to call `expand_type()` during semantic analyzis, | |
# however this plugin is called very late, so all types should be fully ready. | |
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because | |
# we serialize attributes). | |
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)}) | |
return self.type | |
def to_var(self, current_info: TypeInfo, use_alias: bool) -> Var: | |
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_var.""" | |
if use_alias and self.alias is not None: | |
name = self.alias | |
else: | |
name = self.name | |
return Var(name, self.expand_type(current_info)) | |
def serialize(self) -> JsonDict: | |
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" | |
assert self.type | |
return { | |
'name': self.name, | |
'alias': self.alias, | |
'has_dynamic_alias': self.has_dynamic_alias, | |
'has_default': self.has_default, | |
'line': self.line, | |
'column': self.column, | |
'type': self.type.serialize(), | |
} | |
def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface) -> PydanticModelField: | |
"""Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize.""" | |
data = data.copy() | |
typ = deserialize_and_fixup_type(data.pop('type'), api) | |
return cls(type=typ, info=info, **data) | |
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: | |
"""Expands type vars in the context of a subtype when an attribute is inherited | |
from a generic super type. | |
""" | |
if self.type is not None: | |
self.type = map_type_from_supertype(self.type, sub_type, self.info) | |
class PydanticModelTransformer: | |
"""Transform the BaseModel subclass according to the plugin settings. | |
Attributes: | |
tracked_config_fields: A set of field configs that the plugin has to track their value. | |
""" | |
tracked_config_fields: set[str] = { | |
'extra', | |
'frozen', | |
'from_attributes', | |
'populate_by_name', | |
'alias_generator', | |
} | |
def __init__( | |
self, | |
cls: ClassDef, | |
reason: Expression | Statement, | |
api: SemanticAnalyzerPluginInterface, | |
plugin_config: PydanticPluginConfig, | |
) -> None: | |
self._cls = cls | |
self._reason = reason | |
self._api = api | |
self.plugin_config = plugin_config | |
def transform(self) -> bool: | |
"""Configures the BaseModel subclass according to the plugin settings. | |
In particular: | |
* determines the model config and fields, | |
* adds a fields-aware signature for the initializer and construct methods | |
* freezes the class if frozen = True | |
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses | |
""" | |
info = self._cls.info | |
config = self.collect_config() | |
fields = self.collect_fields(config) | |
if fields is None: | |
# Some definitions are not ready. We need another pass. | |
return False | |
for field in fields: | |
if field.type is None: | |
return False | |
is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1]) | |
self.add_initializer(fields, config, is_settings) | |
self.add_model_construct_method(fields, config, is_settings) | |
self.set_frozen(fields, frozen=config.frozen is True) | |
self.adjust_decorator_signatures() | |
info.metadata[METADATA_KEY] = { | |
'fields': {field.name: field.serialize() for field in fields}, | |
'config': config.get_values_dict(), | |
} | |
return True | |
def adjust_decorator_signatures(self) -> None: | |
"""When we decorate a function `f` with `pydantic.validator(...)`, `pydantic.field_validator` | |
or `pydantic.serializer(...)`, mypy sees `f` as a regular method taking a `self` instance, | |
even though pydantic internally wraps `f` with `classmethod` if necessary. | |
Teach mypy this by marking any function whose outermost decorator is a `validator()`, | |
`field_validator()` or `serializer()` call as a `classmethod`. | |
""" | |
for name, sym in self._cls.info.names.items(): | |
if isinstance(sym.node, Decorator): | |
first_dec = sym.node.original_decorators[0] | |
if ( | |
isinstance(first_dec, CallExpr) | |
and isinstance(first_dec.callee, NameExpr) | |
and first_dec.callee.fullname in DECORATOR_FULLNAMES | |
# @model_validator(mode="after") is an exception, it expects a regular method | |
and not ( | |
first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME | |
and any( | |
first_dec.arg_names[i] == 'mode' and isinstance(arg, StrExpr) and arg.value == 'after' | |
for i, arg in enumerate(first_dec.args) | |
) | |
) | |
): | |
# TODO: Only do this if the first argument of the decorated function is `cls` | |
sym.node.func.is_class = True | |
def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity) | |
"""Collects the values of the config attributes that are used by the plugin, accounting for parent classes.""" | |
cls = self._cls | |
config = ModelConfigData() | |
has_config_kwargs = False | |
has_config_from_namespace = False | |
# Handle `class MyModel(BaseModel, <name>=<expr>, ...):` | |
for name, expr in cls.keywords.items(): | |
config_data = self.get_config_update(name, expr) | |
if config_data: | |
has_config_kwargs = True | |
config.update(config_data) | |
# Handle `model_config` | |
stmt: Statement | None = None | |
for stmt in cls.defs.body: | |
if not isinstance(stmt, (AssignmentStmt, ClassDef)): | |
continue | |
if isinstance(stmt, AssignmentStmt): | |
lhs = stmt.lvalues[0] | |
if not isinstance(lhs, NameExpr) or lhs.name != 'model_config': | |
continue | |
if isinstance(stmt.rvalue, CallExpr): # calls to `dict` or `ConfigDict` | |
for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args): | |
if arg_name is None: | |
continue | |
config.update(self.get_config_update(arg_name, arg)) | |
elif isinstance(stmt.rvalue, DictExpr): # dict literals | |
for key_expr, value_expr in stmt.rvalue.items: | |
if not isinstance(key_expr, StrExpr): | |
continue | |
config.update(self.get_config_update(key_expr.value, value_expr)) | |
elif isinstance(stmt, ClassDef): | |
if stmt.name != 'Config': # 'deprecated' Config-class | |
continue | |
for substmt in stmt.defs.body: | |
if not isinstance(substmt, AssignmentStmt): | |
continue | |
lhs = substmt.lvalues[0] | |
if not isinstance(lhs, NameExpr): | |
continue | |
config.update(self.get_config_update(lhs.name, substmt.rvalue)) | |
if has_config_kwargs: | |
self._api.fail( | |
'Specifying config in two places is ambiguous, use either Config attribute or class kwargs', | |
cls, | |
) | |
break | |
has_config_from_namespace = True | |
if has_config_kwargs or has_config_from_namespace: | |
if ( | |
stmt | |
and config.has_alias_generator | |
and not config.populate_by_name | |
and self.plugin_config.warn_required_dynamic_aliases | |
): | |
error_required_dynamic_aliases(self._api, stmt) | |
for info in cls.info.mro[1:]: # 0 is the current class | |
if METADATA_KEY not in info.metadata: | |
continue | |
# Each class depends on the set of fields in its ancestors | |
self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) | |
for name, value in info.metadata[METADATA_KEY]['config'].items(): | |
config.setdefault(name, value) | |
return config | |
def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None: | |
"""Collects the fields for the model, accounting for parent classes.""" | |
cls = self._cls | |
# First, collect fields belonging to any class in the MRO, ignoring duplicates. | |
# | |
# We iterate through the MRO in reverse because attrs defined in the parent must appear | |
# earlier in the attributes list than attrs defined in the child. See: | |
# https://docs.python.org/3/library/dataclasses.html#inheritance | |
# | |
# However, we also want fields defined in the subtype to override ones defined | |
# in the parent. We can implement this via a dict without disrupting the attr order | |
# because dicts preserve insertion order in Python 3.7+. | |
found_fields: dict[str, PydanticModelField] = {} | |
for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object | |
# if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata: | |
# # We haven't processed the base class yet. Need another pass. | |
# return None | |
if METADATA_KEY not in info.metadata: | |
continue | |
# Each class depends on the set of attributes in its dataclass ancestors. | |
self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) | |
for name, data in info.metadata[METADATA_KEY]['fields'].items(): | |
field = PydanticModelField.deserialize(info, data, self._api) | |
# (The following comment comes directly from the dataclasses plugin) | |
# TODO: We shouldn't be performing type operations during the main | |
# semantic analysis pass, since some TypeInfo attributes might | |
# still be in flux. This should be performed in a later phase. | |
with state.strict_optional_set(self._api.options.strict_optional): | |
field.expand_typevar_from_subtype(cls.info) | |
found_fields[name] = field | |
sym_node = cls.info.names.get(name) | |
if sym_node and sym_node.node and not isinstance(sym_node.node, Var): | |
self._api.fail( | |
'BaseModel field may only be overridden by another field', | |
sym_node.node, | |
) | |
# Second, collect fields belonging to the current class. | |
current_field_names: set[str] = set() | |
for stmt in self._get_assignment_statements_from_block(cls.defs): | |
maybe_field = self.collect_field_from_stmt(stmt, model_config) | |
if maybe_field is not None: | |
lhs = stmt.lvalues[0] | |
current_field_names.add(lhs.name) | |
found_fields[lhs.name] = maybe_field | |
return list(found_fields.values()) | |
def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]: | |
for body in stmt.body: | |
if not body.is_unreachable: | |
yield from self._get_assignment_statements_from_block(body) | |
if stmt.else_body is not None and not stmt.else_body.is_unreachable: | |
yield from self._get_assignment_statements_from_block(stmt.else_body) | |
def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]: | |
for stmt in block.body: | |
if isinstance(stmt, AssignmentStmt): | |
yield stmt | |
elif isinstance(stmt, IfStmt): | |
yield from self._get_assignment_statements_from_if_statement(stmt) | |
def collect_field_from_stmt( # noqa C901 | |
self, stmt: AssignmentStmt, model_config: ModelConfigData | |
) -> PydanticModelField | None: | |
"""Get pydantic model field from statement. | |
Args: | |
stmt: The statement. | |
model_config: Configuration settings for the model. | |
Returns: | |
A pydantic model field if it could find the field in statement. Otherwise, `None`. | |
""" | |
cls = self._cls | |
lhs = stmt.lvalues[0] | |
if not isinstance(lhs, NameExpr) or not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config': | |
return None | |
if not stmt.new_syntax: | |
if ( | |
isinstance(stmt.rvalue, CallExpr) | |
and isinstance(stmt.rvalue.callee, CallExpr) | |
and isinstance(stmt.rvalue.callee.callee, NameExpr) | |
and stmt.rvalue.callee.callee.fullname in DECORATOR_FULLNAMES | |
): | |
# This is a (possibly-reused) validator or serializer, not a field | |
# In particular, it looks something like: my_validator = validator('my_field')(f) | |
# Eventually, we may want to attempt to respect model_config['ignored_types'] | |
return None | |
# The assignment does not have an annotation, and it's not anything else we recognize | |
error_untyped_fields(self._api, stmt) | |
return None | |
lhs = stmt.lvalues[0] | |
if not isinstance(lhs, NameExpr): | |
return None | |
if not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config': | |
return None | |
sym = cls.info.names.get(lhs.name) | |
if sym is None: # pragma: no cover | |
# This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) | |
# This is the same logic used in the dataclasses plugin | |
return None | |
node = sym.node | |
if isinstance(node, PlaceholderNode): # pragma: no cover | |
# See the PlaceholderNode docstring for more detail about how this can occur | |
# Basically, it is an edge case when dealing with complex import logic | |
# The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does.. | |
return None | |
if isinstance(node, TypeAlias): | |
self._api.fail( | |
'Type aliases inside BaseModel definitions are not supported at runtime', | |
node, | |
) | |
# Skip processing this node. This doesn't match the runtime behaviour, | |
# but the only alternative would be to modify the SymbolTable, | |
# and it's a little hairy to do that in a plugin. | |
return None | |
if not isinstance(node, Var): # pragma: no cover | |
# Don't know if this edge case still happens with the `is_valid_field` check above | |
# but better safe than sorry | |
# The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does.. | |
return None | |
# x: ClassVar[int] is not a field | |
if node.is_classvar: | |
return None | |
# x: InitVar[int] is not supported in BaseModel | |
node_type = get_proper_type(node.type) | |
if isinstance(node_type, Instance) and node_type.type.fullname == 'dataclasses.InitVar': | |
self._api.fail( | |
'InitVar is not supported in BaseModel', | |
node, | |
) | |
has_default = self.get_has_default(stmt) | |
if sym.type is None and node.is_final and node.is_inferred: | |
# This follows the logic from the dataclasses plugin. The following comment is taken verbatim: | |
# | |
# This is a special case, assignment like x: Final = 42 is classified | |
# annotated above, but mypy strips the `Final` turning it into x = 42. | |
# We do not support inferred types in dataclasses, so we can try inferring | |
# type for simple literals, and otherwise require an explicit type | |
# argument for Final[...]. | |
typ = self._api.analyze_simple_literal_type(stmt.rvalue, is_final=True) | |
if typ: | |
node.type = typ | |
else: | |
self._api.fail( | |
'Need type argument for Final[...] with non-literal default in BaseModel', | |
stmt, | |
) | |
node.type = AnyType(TypeOfAny.from_error) | |
alias, has_dynamic_alias = self.get_alias_info(stmt) | |
if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases: | |
error_required_dynamic_aliases(self._api, stmt) | |
init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt) | |
return PydanticModelField( | |
name=lhs.name, | |
has_dynamic_alias=has_dynamic_alias, | |
has_default=has_default, | |
alias=alias, | |
line=stmt.line, | |
column=stmt.column, | |
type=init_type, | |
info=cls.info, | |
) | |
def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, context: Context) -> Type | None: | |
"""Infer __init__ argument type for an attribute. | |
In particular, possibly use the signature of __set__. | |
""" | |
default = sym.type | |
if sym.implicit: | |
return default | |
t = get_proper_type(sym.type) | |
# Perform a simple-minded inference from the signature of __set__, if present. | |
# We can't use mypy.checkmember here, since this plugin runs before type checking. | |
# We only support some basic scanerios here, which is hopefully sufficient for | |
# the vast majority of use cases. | |
if not isinstance(t, Instance): | |
return default | |
setter = t.type.get('__set__') | |
if setter: | |
if isinstance(setter.node, FuncDef): | |
super_info = t.type.get_containing_type_info('__set__') | |
assert super_info | |
if setter.type: | |
setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info)) | |
else: | |
return AnyType(TypeOfAny.unannotated) | |
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [ | |
ARG_POS, | |
ARG_POS, | |
ARG_POS, | |
]: | |
return expand_type_by_instance(setter_type.arg_types[2], t) | |
else: | |
self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context) | |
else: | |
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context) | |
return default | |
def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None: | |
"""Adds a fields-aware `__init__` method to the class. | |
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. | |
""" | |
if '__init__' in self._cls.info.names and not self._cls.info.names['__init__'].plugin_generated: | |
return # Don't generate an __init__ if one already exists | |
typed = self.plugin_config.init_typed | |
use_alias = config.populate_by_name is not True | |
requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name) | |
with state.strict_optional_set(self._api.options.strict_optional): | |
args = self.get_field_arguments( | |
fields, | |
typed=typed, | |
requires_dynamic_aliases=requires_dynamic_aliases, | |
use_alias=use_alias, | |
is_settings=is_settings, | |
) | |
if is_settings: | |
base_settings_info = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node.defn.info | |
base_settings_init_arguments = base_settings_info.names['__init__'].node.arguments | |
settings_init_arguments = [] | |
a: Argument | |
for a in base_settings_init_arguments: | |
if a.variable.name.startswith('__') or not a.variable.name.startswith('_'): | |
continue | |
analyzed_variable_type = self._api.anal_type(a.variable.type) | |
variable = Var(a.variable.name, analyzed_variable_type) | |
settings_init_arguments.append(Argument(variable, analyzed_variable_type, None, ARG_OPT)) | |
args.extend(settings_init_arguments) | |
if not self.should_init_forbid_extra(fields, config): | |
var = Var('kwargs') | |
args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) | |
add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType()) | |
def add_model_construct_method( | |
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool | |
) -> None: | |
"""Adds a fully typed `model_construct` classmethod to the class. | |
Similar to the fields-aware __init__ method, but always uses the field names (not aliases), | |
and does not treat settings fields as optional. | |
""" | |
set_str = self._api.named_type(f'{BUILTINS_NAME}.set', [self._api.named_type(f'{BUILTINS_NAME}.str')]) | |
optional_set_str = UnionType([set_str, NoneType()]) | |
fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT) | |
with state.strict_optional_set(self._api.options.strict_optional): | |
args = self.get_field_arguments( | |
fields, typed=True, requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings | |
) | |
if not self.should_init_forbid_extra(fields, config): | |
var = Var('kwargs') | |
args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) | |
args = [fields_set_argument] + args | |
add_method( | |
self._api, | |
self._cls, | |
'model_construct', | |
args=args, | |
return_type=fill_typevars(self._cls.info), | |
is_classmethod=True, | |
) | |
def set_frozen(self, fields: list[PydanticModelField], frozen: bool) -> None: | |
"""Marks all fields as properties so that attempts to set them trigger mypy errors. | |
This is the same approach used by the attrs and dataclasses plugins. | |
""" | |
info = self._cls.info | |
for field in fields: | |
sym_node = info.names.get(field.name) | |
if sym_node is not None: | |
var = sym_node.node | |
if isinstance(var, Var): | |
var.is_property = frozen | |
elif isinstance(var, PlaceholderNode) and not self._api.final_iteration: | |
# See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage | |
self._api.defer() | |
else: # pragma: no cover | |
# I don't know whether it's possible to hit this branch, but I've added it for safety | |
try: | |
var_str = str(var) | |
except TypeError: | |
# This happens for PlaceholderNode; perhaps it will happen for other types in the future.. | |
var_str = repr(var) | |
detail = f'sym_node.node: {var_str} (of type {var.__class__})' | |
error_unexpected_behavior(detail, self._api, self._cls) | |
else: | |
var = field.to_var(info, use_alias=False) | |
var.info = info | |
var.is_property = frozen | |
var._fullname = info.fullname + '.' + var.name | |
info.names[var.name] = SymbolTableNode(MDEF, var) | |
def get_config_update(self, name: str, arg: Expression) -> ModelConfigData | None: | |
"""Determines the config update due to a single kwarg in the ConfigDict definition. | |
Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) | |
""" | |
if name not in self.tracked_config_fields: | |
return None | |
if name == 'extra': | |
if isinstance(arg, StrExpr): | |
forbid_extra = arg.value == 'forbid' | |
elif isinstance(arg, MemberExpr): | |
forbid_extra = arg.name == 'forbid' | |
else: | |
error_invalid_config_value(name, self._api, arg) | |
return None | |
return ModelConfigData(forbid_extra=forbid_extra) | |
if name == 'alias_generator': | |
has_alias_generator = True | |
if isinstance(arg, NameExpr) and arg.fullname == 'builtins.None': | |
has_alias_generator = False | |
return ModelConfigData(has_alias_generator=has_alias_generator) | |
if isinstance(arg, NameExpr) and arg.fullname in ('builtins.True', 'builtins.False'): | |
return ModelConfigData(**{name: arg.fullname == 'builtins.True'}) | |
error_invalid_config_value(name, self._api, arg) | |
return None | |
def get_has_default(stmt: AssignmentStmt) -> bool: | |
"""Returns a boolean indicating whether the field defined in `stmt` is a required field.""" | |
expr = stmt.rvalue | |
if isinstance(expr, TempNode): | |
# TempNode means annotation-only, so has no default | |
return False | |
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: | |
# The "default value" is a call to `Field`; at this point, the field has a default if and only if: | |
# * there is a positional argument that is not `...` | |
# * there is a keyword argument named "default" that is not `...` | |
# * there is a "default_factory" that is not `None` | |
for arg, name in zip(expr.args, expr.arg_names): | |
# If name is None, then this arg is the default because it is the only positional argument. | |
if name is None or name == 'default': | |
return arg.__class__ is not EllipsisExpr | |
if name == 'default_factory': | |
return not (isinstance(arg, NameExpr) and arg.fullname == 'builtins.None') | |
return False | |
# Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) | |
return not isinstance(expr, EllipsisExpr) | |
def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]: | |
"""Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. | |
`has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal. | |
If `has_dynamic_alias` is True, `alias` will be None. | |
""" | |
expr = stmt.rvalue | |
if isinstance(expr, TempNode): | |
# TempNode means annotation-only | |
return None, False | |
if not ( | |
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME | |
): | |
# Assigned value is not a call to pydantic.fields.Field | |
return None, False | |
for i, arg_name in enumerate(expr.arg_names): | |
if arg_name != 'alias': | |
continue | |
arg = expr.args[i] | |
if isinstance(arg, StrExpr): | |
return arg.value, False | |
else: | |
return None, True | |
return None, False | |
def get_field_arguments( | |
self, | |
fields: list[PydanticModelField], | |
typed: bool, | |
use_alias: bool, | |
requires_dynamic_aliases: bool, | |
is_settings: bool, | |
) -> list[Argument]: | |
"""Helper function used during the construction of the `__init__` and `model_construct` method signatures. | |
Returns a list of mypy Argument instances for use in the generated signatures. | |
""" | |
info = self._cls.info | |
arguments = [ | |
field.to_argument( | |
info, typed=typed, force_optional=requires_dynamic_aliases or is_settings, use_alias=use_alias | |
) | |
for field in fields | |
if not (use_alias and field.has_dynamic_alias) | |
] | |
return arguments | |
def should_init_forbid_extra(self, fields: list[PydanticModelField], config: ModelConfigData) -> bool: | |
"""Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature. | |
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, | |
*unless* a required dynamic alias is present (since then we can't determine a valid signature). | |
""" | |
if not config.populate_by_name: | |
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): | |
return False | |
if config.forbid_extra: | |
return True | |
return self.plugin_config.init_forbid_extra | |
def is_dynamic_alias_present(fields: list[PydanticModelField], has_alias_generator: bool) -> bool: | |
"""Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be | |
determined during static analysis. | |
""" | |
for field in fields: | |
if field.has_dynamic_alias: | |
return True | |
if has_alias_generator: | |
for field in fields: | |
if field.alias is None: | |
return True | |
return False | |
class ModelConfigData: | |
"""Pydantic mypy plugin model config class.""" | |
def __init__( | |
self, | |
forbid_extra: bool | None = None, | |
frozen: bool | None = None, | |
from_attributes: bool | None = None, | |
populate_by_name: bool | None = None, | |
has_alias_generator: bool | None = None, | |
): | |
self.forbid_extra = forbid_extra | |
self.frozen = frozen | |
self.from_attributes = from_attributes | |
self.populate_by_name = populate_by_name | |
self.has_alias_generator = has_alias_generator | |
def get_values_dict(self) -> dict[str, Any]: | |
"""Returns a dict of Pydantic model config names to their values. | |
It includes the config if config value is not `None`. | |
""" | |
return {k: v for k, v in self.__dict__.items() if v is not None} | |
def update(self, config: ModelConfigData | None) -> None: | |
"""Update Pydantic model config values.""" | |
if config is None: | |
return | |
for k, v in config.get_values_dict().items(): | |
setattr(self, k, v) | |
def setdefault(self, key: str, value: Any) -> None: | |
"""Set default value for Pydantic model config if config value is `None`.""" | |
if getattr(self, key) is None: | |
setattr(self, key, value) | |
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic') | |
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic') | |
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic') | |
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') | |
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') | |
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') | |
def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None: | |
"""Emits an error when the model does not have `from_attributes=True`.""" | |
api.fail(f'"{model_name}" does not have from_attributes=True', context, code=ERROR_ORM) | |
def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None: | |
"""Emits an error when the config value is invalid.""" | |
api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG) | |
def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None: | |
"""Emits required dynamic aliases error. | |
This will be called when `warn_required_dynamic_aliases=True`. | |
""" | |
api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS) | |
def error_unexpected_behavior( | |
detail: str, api: CheckerPluginInterface | SemanticAnalyzerPluginInterface, context: Context | |
) -> None: # pragma: no cover | |
"""Emits unexpected behavior error.""" | |
# Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path | |
link = 'https://github.com/pydantic/pydantic/issues/new/choose' | |
full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n' | |
full_message += f'Please consider reporting this bug at {link} so we can try to fix it!' | |
api.fail(full_message, context, code=ERROR_UNEXPECTED) | |
def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None: | |
"""Emits an error when there is an untyped field in the model.""" | |
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) | |
def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: | |
"""Emits an error when `Field` has both `default` and `default_factory` together.""" | |
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) | |
def add_method( | |
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, | |
cls: ClassDef, | |
name: str, | |
args: list[Argument], | |
return_type: Type, | |
self_type: Type | None = None, | |
tvar_def: TypeVarDef | None = None, | |
is_classmethod: bool = False, | |
) -> None: | |
"""Very closely related to `mypy.plugins.common.add_method_to_class`, with a few pydantic-specific changes.""" | |
info = cls.info | |
# First remove any previously generated methods with the same name | |
# to avoid clashes and problems in the semantic analyzer. | |
if name in info.names: | |
sym = info.names[name] | |
if sym.plugin_generated and isinstance(sym.node, FuncDef): | |
cls.defs.body.remove(sym.node) # pragma: no cover | |
if isinstance(api, SemanticAnalyzerPluginInterface): | |
function_type = api.named_type('builtins.function') | |
else: | |
function_type = api.named_generic_type('builtins.function', []) | |
if is_classmethod: | |
self_type = self_type or TypeType(fill_typevars(info)) | |
first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)] | |
else: | |
self_type = self_type or fill_typevars(info) | |
first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)] | |
args = first + args | |
arg_types, arg_names, arg_kinds = [], [], [] | |
for arg in args: | |
assert arg.type_annotation, 'All arguments must be fully typed.' | |
arg_types.append(arg.type_annotation) | |
arg_names.append(arg.variable.name) | |
arg_kinds.append(arg.kind) | |
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) | |
if tvar_def: | |
signature.variables = [tvar_def] | |
func = FuncDef(name, args, Block([PassStmt()])) | |
func.info = info | |
func.type = set_callable_name(signature, func) | |
func.is_class = is_classmethod | |
func._fullname = info.fullname + '.' + name | |
func.line = info.line | |
# NOTE: we would like the plugin generated node to dominate, but we still | |
# need to keep any existing definitions so they get semantically analyzed. | |
if name in info.names: | |
# Get a nice unique name instead. | |
r_name = get_unique_redefinition_name(name, info.names) | |
info.names[r_name] = info.names[name] | |
# Add decorator for is_classmethod | |
# The dataclasses plugin claims this is unnecessary for classmethods, but not including it results in a | |
# signature incompatible with the superclass, which causes mypy errors to occur for every subclass of BaseModel. | |
if is_classmethod: | |
func.is_decorated = True | |
v = Var(name, func.type) | |
v.info = info | |
v._fullname = func._fullname | |
v.is_classmethod = True | |
dec = Decorator(func, [NameExpr('classmethod')], v) | |
dec.line = info.line | |
sym = SymbolTableNode(MDEF, dec) | |
else: | |
sym = SymbolTableNode(MDEF, func) | |
sym.plugin_generated = True | |
info.names[name] = sym | |
info.defn.defs.body.append(func) | |
def parse_toml(config_file: str) -> dict[str, Any] | None: | |
"""Returns a dict of config keys to values. | |
It reads configs from toml file and returns `None` if the file is not a toml file. | |
""" | |
if not config_file.endswith('.toml'): | |
return None | |
if sys.version_info >= (3, 11): | |
import tomllib as toml_ | |
else: | |
try: | |
import tomli as toml_ | |
except ImportError: # pragma: no cover | |
import warnings | |
warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.') | |
return None | |
with open(config_file, 'rb') as rf: | |
return toml_.load(rf) | |