From 3bce0d5b14287b7ccea7ecd63f420e6f7fb8a9bc Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 8 Jan 2023 18:21:15 +0100 Subject: [PATCH 1/4] Use `__new__` instead of `__init__` in serializers --- rest_framework-stubs/serializers.pyi | 58 ++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/rest_framework-stubs/serializers.pyi b/rest_framework-stubs/serializers.pyi index d17daccc3..fc78e0036 100644 --- a/rest_framework-stubs/serializers.pyi +++ b/rest_framework-stubs/serializers.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence -from typing import Any, Generic, NoReturn, TypeVar +from typing import Any, Generic, NoReturn, TypeVar, overload from _typeshed import Self from django.core.exceptions import ValidationError as DjangoValidationError @@ -81,13 +81,57 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): instance: _IN | None initial_data: Any _context: dict[str, Any] - def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ... def __class_getitem__(cls, *args, **kwargs): ... - def __init__( - self, + # When both __init__ and __new__ are present, mypy will prefer __init__ + @overload + def __new__( + cls: type[Self], + instance: Iterable[_IN] | None = ..., + data: Any = ..., + partial: bool = ..., + many: Literal[True] = ..., + allow_empty: bool = ..., + context: dict[str, Any] = ..., + read_only: bool = ..., + write_only: bool = ..., + required: bool = ..., + default: Any = ..., + initial: Any = ..., + source: str = ..., + label: str = ..., + help_text: str = ..., + style: dict[str, Any] = ..., + error_messages: dict[str, str] = ..., + validators: Sequence[Validator[Any]] | None = ..., + allow_null: bool = ..., + ) -> ListSerializer[_IN]: ... + @overload + def __new__( + cls: type[Self], instance: _IN | None = ..., data: Any = ..., partial: bool = ..., + many: Literal[False] = ..., + allow_empty: bool = ..., + context: dict[str, Any] = ..., + read_only: bool = ..., + write_only: bool = ..., + required: bool = ..., + default: Any = ..., + initial: Any = ..., + source: str = ..., + label: str = ..., + help_text: str = ..., + style: dict[str, Any] = ..., + error_messages: dict[str, str] = ..., + validators: Sequence[Validator[Any]] | None = ..., + allow_null: bool = ..., + ) -> Self: ... + def __new__( + cls, + instance: _IN | Iterable[_IN] | None = ..., + data: Any = ..., + partial: bool = ..., many: bool = ..., allow_empty: bool = ..., context: dict[str, Any] = ..., @@ -103,7 +147,7 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): error_messages: dict[str, str] = ..., validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., - ): ... + ) -> ListSerializer[_IN] | Self: ... @classmethod def many_init(cls, *args: Any, **kwargs: Any) -> BaseSerializer: ... def is_valid(self, raise_exception: bool = ...) -> bool: ... @@ -159,7 +203,7 @@ class ListSerializer( allow_empty: bool | None def __init__( self, - instance: _IN | None = ..., + instance: Iterable[_IN] | None = ..., data: Any = ..., partial: bool = ..., context: dict[str, Any] = ..., @@ -177,7 +221,7 @@ class ListSerializer( error_messages: dict[str, str] = ..., validators: Sequence[Validator[list[Any]]] | None = ..., allow_null: bool = ..., - ): ... + ) -> None: ... def get_initial(self) -> list[Mapping[Any, Any]]: ... def validate(self, attrs: Any) -> Any: ... @property From 5853faed4adf2094d4d17538385616079d85c0ba Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Mon, 9 Jan 2023 23:02:28 +0100 Subject: [PATCH 2/4] Use `__new__` in `ModelSerializer` --- rest_framework-stubs/serializers.pyi | 58 ++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/rest_framework-stubs/serializers.pyi b/rest_framework-stubs/serializers.pyi index fc78e0036..009d6982f 100644 --- a/rest_framework-stubs/serializers.pyi +++ b/rest_framework-stubs/serializers.pyi @@ -247,27 +247,71 @@ class ModelSerializer(Serializer, BaseSerializer[_MT]): exclude: Sequence[str] | None depth: int | None extra_kwargs: dict[str, dict[str, Any]] # type: ignore[override] - def __init__( - self, + @overload + def __new__( + cls: type[Self], instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., data: Any = ..., partial: bool = ..., - many: bool = ..., + many: Literal[True] = ..., + allow_empty: bool = ..., context: dict[str, Any] = ..., read_only: bool = ..., write_only: bool = ..., required: bool = ..., - default: _MT | Sequence[_MT] | Callable[[], _MT | Sequence[_MT]] = ..., - initial: _MT | Sequence[_MT] | Callable[[], _MT | Sequence[_MT]] = ..., + default: Any = ..., + initial: Any = ..., source: str = ..., label: str = ..., help_text: str = ..., style: dict[str, Any] = ..., error_messages: dict[str, str] = ..., - validators: Sequence[Validator[_MT]] | None = ..., + validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., + ) -> ListSerializer[_IN]: ... + @overload + def __new__( + cls: type[Self], + instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., + data: Any = ..., + partial: bool = ..., + many: Literal[False] = ..., allow_empty: bool = ..., - ): ... + context: dict[str, Any] = ..., + read_only: bool = ..., + write_only: bool = ..., + required: bool = ..., + default: Any = ..., + initial: Any = ..., + source: str = ..., + label: str = ..., + help_text: str = ..., + style: dict[str, Any] = ..., + error_messages: dict[str, str] = ..., + validators: Sequence[Validator[Any]] | None = ..., + allow_null: bool = ..., + ) -> Self: ... + def __new__( + cls, + instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., + data: Any = ..., + partial: bool = ..., + many: bool = ..., + allow_empty: bool = ..., + context: dict[str, Any] = ..., + read_only: bool = ..., + write_only: bool = ..., + required: bool = ..., + default: Any = ..., + initial: Any = ..., + source: str = ..., + label: str = ..., + help_text: str = ..., + style: dict[str, Any] = ..., + error_messages: dict[str, str] = ..., + validators: Sequence[Validator[Any]] | None = ..., + allow_null: bool = ..., + ) -> ListSerializer[_MT] | Self: ... def update(self, instance: _MT, validated_data: Any) -> _MT: ... # type: ignore[override] def create(self, validated_data: Any) -> _MT: ... # type: ignore[override] def save(self, **kwargs: Any) -> _MT: ... # type: ignore[override] From a58d4dc452ad34d07b7be04326be9c9998da277b Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Tue, 10 Jan 2023 20:45:03 +0100 Subject: [PATCH 3/4] Fix return type of `ModelSerializer` --- rest_framework-stubs/serializers.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework-stubs/serializers.pyi b/rest_framework-stubs/serializers.pyi index 009d6982f..4733e9b8f 100644 --- a/rest_framework-stubs/serializers.pyi +++ b/rest_framework-stubs/serializers.pyi @@ -268,7 +268,7 @@ class ModelSerializer(Serializer, BaseSerializer[_MT]): error_messages: dict[str, str] = ..., validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., - ) -> ListSerializer[_IN]: ... + ) -> ListSerializer[_MT]: ... @overload def __new__( cls: type[Self], From 3e0970ef72439b537ead3771cd2c1792f2990108 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Wed, 11 Jan 2023 00:44:05 +0100 Subject: [PATCH 4/4] wip: Fix overload and add tests overload definition is subject to change, as it is missing some arguments. --- rest_framework-stubs/serializers.pyi | 76 ++++------------------------ tests/typecheck/test_serializers.yml | 18 +++++++ 2 files changed, 28 insertions(+), 66 deletions(-) diff --git a/rest_framework-stubs/serializers.pyi b/rest_framework-stubs/serializers.pyi index 4733e9b8f..f73920f8b 100644 --- a/rest_framework-stubs/serializers.pyi +++ b/rest_framework-stubs/serializers.pyi @@ -86,10 +86,8 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): @overload def __new__( cls: type[Self], - instance: Iterable[_IN] | None = ..., - data: Any = ..., - partial: bool = ..., - many: Literal[True] = ..., + instance: Iterable[_IN] | None, + many: Literal[True], allow_empty: bool = ..., context: dict[str, Any] = ..., read_only: bool = ..., @@ -109,8 +107,6 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): def __new__( cls: type[Self], instance: _IN | None = ..., - data: Any = ..., - partial: bool = ..., many: Literal[False] = ..., allow_empty: bool = ..., context: dict[str, Any] = ..., @@ -127,27 +123,6 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., ) -> Self: ... - def __new__( - cls, - instance: _IN | Iterable[_IN] | None = ..., - data: Any = ..., - partial: bool = ..., - many: bool = ..., - allow_empty: bool = ..., - context: dict[str, Any] = ..., - read_only: bool = ..., - write_only: bool = ..., - required: bool = ..., - default: Any = ..., - initial: Any = ..., - source: str = ..., - label: str = ..., - help_text: str = ..., - style: dict[str, Any] = ..., - error_messages: dict[str, str] = ..., - validators: Sequence[Validator[Any]] | None = ..., - allow_null: bool = ..., - ) -> ListSerializer[_IN] | Self: ... @classmethod def many_init(cls, *args: Any, **kwargs: Any) -> BaseSerializer: ... def is_valid(self, raise_exception: bool = ...) -> bool: ... @@ -248,51 +223,20 @@ class ModelSerializer(Serializer, BaseSerializer[_MT]): depth: int | None extra_kwargs: dict[str, dict[str, Any]] # type: ignore[override] @overload - def __new__( + def __new__( # type: ignore[misc] cls: type[Self], - instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., - data: Any = ..., - partial: bool = ..., - many: Literal[True] = ..., - allow_empty: bool = ..., - context: dict[str, Any] = ..., - read_only: bool = ..., - write_only: bool = ..., - required: bool = ..., - default: Any = ..., - initial: Any = ..., - source: str = ..., - label: str = ..., - help_text: str = ..., - style: dict[str, Any] = ..., - error_messages: dict[str, str] = ..., - validators: Sequence[Validator[Any]] | None = ..., - allow_null: bool = ..., + instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT], + many: Literal[True], ) -> ListSerializer[_MT]: ... @overload def __new__( cls: type[Self], - instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., - data: Any = ..., - partial: bool = ..., - many: Literal[False] = ..., - allow_empty: bool = ..., - context: dict[str, Any] = ..., - read_only: bool = ..., - write_only: bool = ..., - required: bool = ..., - default: Any = ..., - initial: Any = ..., - source: str = ..., - label: str = ..., - help_text: str = ..., - style: dict[str, Any] = ..., - error_messages: dict[str, str] = ..., - validators: Sequence[Validator[Any]] | None = ..., - allow_null: bool = ..., + instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT], + many: Literal[False], ) -> Self: ... + @overload def __new__( - cls, + cls: type[Self], instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., data: Any = ..., partial: bool = ..., @@ -311,7 +255,7 @@ class ModelSerializer(Serializer, BaseSerializer[_MT]): error_messages: dict[str, str] = ..., validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., - ) -> ListSerializer[_MT] | Self: ... + ) -> Self: ... def update(self, instance: _MT, validated_data: Any) -> _MT: ... # type: ignore[override] def create(self, validated_data: Any) -> _MT: ... # type: ignore[override] def save(self, **kwargs: Any) -> _MT: ... # type: ignore[override] diff --git a/tests/typecheck/test_serializers.yml b/tests/typecheck/test_serializers.yml index 59e6d5179..f5fd11472 100644 --- a/tests/typecheck/test_serializers.yml +++ b/tests/typecheck/test_serializers.yml @@ -75,3 +75,21 @@ @cached_property def fields(self) -> BindingDict: return super().fields +- case: test_serializer_many_equals_false + main: | + from rest_framework import serializers + + class TestSerializer(serializers.Serializer[int]): + pass + + test_serializer = TestSerializer(1) + reveal_type(test_serializer) # N: Revealed type is "main.TestSerializer" +- case: test_serializer_many_equals_true + main: | + from rest_framework import serializers + + class TestSerializer(serializers.Serializer[int]): + pass + + test_serializer = TestSerializer(instance=[1, 2], many=True) + reveal_type(test_serializer) # N: Revealed type is "rest_framework.serializers.ListSerializer[builtins.int]"