Skip to content

Commit

Permalink
make generic again
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 23, 2024
1 parent 557ac0f commit 859e0d6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 40 deletions.
76 changes: 42 additions & 34 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
from narwhals.dtypes import DType
from narwhals.functions import ArrowStreamExportable
from narwhals.typing import IntoExpr
from narwhals.typing import IntoSeries

T = TypeVar("T")

Expand All @@ -92,7 +91,7 @@ class DataFrame(NwDataFrame[IntoDataFrameT]):
# annotations are correct.

@property
def _series(self) -> type[Series]:
def _series(self) -> type[Series[Any]]:
return Series

@property
Expand All @@ -106,23 +105,23 @@ def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[Sequence[int], str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[slice, str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[Sequence[int], int]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: tuple[slice, int]) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self, item: str) -> Series: ... # type: ignore[overload-overlap]
def __getitem__(self, item: str) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self, item: Sequence[str]) -> Self: ...
Expand Down Expand Up @@ -188,14 +187,16 @@ def lazy(self) -> LazyFrame[Any]:
# Not sure what mypy is complaining about, probably some fancy
# thing that I need to understand category theory for
@overload # type: ignore[override]
def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series]: ...
def to_dict(self, *, as_series: Literal[True] = ...) -> dict[str, Series[Any]]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
@overload
def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, Series[Any]] | dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool = True
) -> dict[str, Series] | dict[str, list[Any]]:
) -> dict[str, Series[Any]] | dict[str, list[Any]]:
"""Convert DataFrame to a dictionary mapping column name to values.
Arguments:
Expand Down Expand Up @@ -242,7 +243,7 @@ def to_dict(
"""
return super().to_dict(as_series=as_series) # type: ignore[return-value]

def is_duplicated(self: Self) -> Series:
def is_duplicated(self: Self) -> Series[Any]:
r"""Get a mask of all duplicated rows in this DataFrame.
Returns:
Expand Down Expand Up @@ -292,7 +293,7 @@ def is_duplicated(self: Self) -> Series:
"""
return super().is_duplicated() # type: ignore[return-value]

def is_unique(self: Self) -> Series:
def is_unique(self: Self) -> Series[Any]:
r"""Get a mask of all unique rows in this DataFrame.
Returns:
Expand Down Expand Up @@ -410,7 +411,7 @@ def _l1_norm(self: Self) -> Self:
return self.select(all()._l1_norm())


class Series(NwSeries[Any]):
class Series(NwSeries[IntoSeriesT]):
"""Narwhals Series, backed by a native series.
The native series might be pandas.Series, polars.Series, ...
Expand Down Expand Up @@ -1150,7 +1151,7 @@ def _stableify(obj: NwDataFrame[IntoFrameT]) -> DataFrame[IntoFrameT]: ...
@overload
def _stableify(obj: NwLazyFrame[IntoFrameT]) -> LazyFrame[IntoFrameT]: ...
@overload
def _stableify(obj: NwSeries[IntoSeriesT]) -> Series: ...
def _stableify(obj: NwSeries[IntoSeriesT]) -> Series[IntoSeriesT]: ...
@overload
def _stableify(obj: NwExpr) -> Expr: ...
@overload
Expand All @@ -1163,7 +1164,7 @@ def _stableify(
| NwSeries[IntoSeriesT]
| NwExpr
| Any,
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series | Expr | Any:
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series[Any] | Expr | Any:
if isinstance(obj, NwDataFrame):
return DataFrame(
obj._compliant_frame,
Expand Down Expand Up @@ -1193,7 +1194,7 @@ def from_native(
eager_or_interchange_only: Literal[True],
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoDataFrameT] | Series: ...
) -> DataFrame[IntoDataFrameT] | Series[IntoSeriesT]: ...


@overload
Expand All @@ -1205,7 +1206,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoDataFrameT] | Series: ...
) -> DataFrame[IntoDataFrameT] | Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -1265,7 +1266,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series: ...
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series[IntoSeriesT]: ...


@overload
Expand All @@ -1277,7 +1278,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[True],
allow_series: None = ...,
) -> Series: ...
) -> Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -1337,7 +1338,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[Any] | LazyFrame[Any] | Series: ...
) -> DataFrame[Any] | LazyFrame[Any] | Series[IntoSeriesT]: ...


@overload
Expand All @@ -1349,7 +1350,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[True],
allow_series: None = ...,
) -> Series: ...
) -> Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -1385,7 +1386,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoDataFrameT] | Series: ...
) -> DataFrame[IntoDataFrameT] | Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -1445,7 +1446,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series: ...
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series[IntoSeriesT]: ...


@overload
Expand All @@ -1457,7 +1458,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[True],
allow_series: None = ...,
) -> Series: ...
) -> Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -1517,7 +1518,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[False] = ...,
allow_series: Literal[True],
) -> DataFrame[Any] | LazyFrame[Any] | Series: ...
) -> DataFrame[Any] | LazyFrame[Any] | Series[IntoSeriesT]: ...


@overload
Expand All @@ -1529,7 +1530,7 @@ def from_native(
eager_or_interchange_only: Literal[False] = ...,
series_only: Literal[True],
allow_series: None = ...,
) -> Series: ...
) -> Series[IntoSeriesT]: ...


@overload
Expand Down Expand Up @@ -1557,16 +1558,16 @@ def from_native(
) -> Any: ...


def from_native(
native_object: IntoFrameT | IntoSeries | T,
def from_native( # type: ignore[misc]
native_object: IntoFrameT | IntoSeriesT | T,
*,
strict: bool | None = None,
pass_through: bool | None = None,
eager_only: bool = False,
eager_or_interchange_only: bool = False,
series_only: bool = False,
allow_series: bool | None = None,
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series | T:
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series[IntoSeriesT] | T:
"""Convert `native_object` to Narwhals Dataframe, Lazyframe, or Series.
Arguments:
Expand Down Expand Up @@ -1649,7 +1650,9 @@ def to_native(
narwhals_object: LazyFrame[IntoFrameT], *, strict: Literal[True] = ...
) -> IntoFrameT: ...
@overload
def to_native(narwhals_object: Series, *, strict: Literal[True] = ...) -> Any: ...
def to_native(
narwhals_object: Series[IntoSeriesT], *, strict: Literal[True] = ...
) -> IntoSeriesT: ...
@overload
def to_native(narwhals_object: Any, *, strict: bool) -> Any: ...
@overload
Expand All @@ -1661,17 +1664,22 @@ def to_native(
narwhals_object: LazyFrame[IntoFrameT], *, pass_through: Literal[False] = ...
) -> IntoFrameT: ...
@overload
def to_native(narwhals_object: Series, *, pass_through: Literal[False] = ...) -> Any: ...
def to_native(
narwhals_object: Series[IntoSeriesT], *, pass_through: Literal[False] = ...
) -> IntoSeriesT: ...
@overload
def to_native(narwhals_object: Any, *, pass_through: bool) -> Any: ...


def to_native(
narwhals_object: DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series,
narwhals_object: DataFrame[IntoDataFrameT]
| LazyFrame[IntoFrameT]
| Series[IntoSeriesT]
| Any,
*,
strict: bool | None = None,
pass_through: bool | None = None,
) -> IntoFrameT | Any:
) -> IntoDataFrameT | IntoFrameT | IntoSeriesT | Any:
"""Convert Narwhals object to native one.
Arguments:
Expand Down Expand Up @@ -2999,7 +3007,7 @@ def new_series(
dtype: DType | type[DType] | None = None,
*,
native_namespace: ModuleType,
) -> Series:
) -> Series[Any]:
"""Instantiate Narwhals Series from iterable (e.g. list or array).
Arguments:
Expand Down
12 changes: 6 additions & 6 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def to_native(
@overload
def to_native(
narwhals_object: Series[IntoSeriesT], *, pass_through: Literal[False] = ...
) -> Any: ...
) -> IntoSeriesT: ...
@overload
def to_native(narwhals_object: Any, *, pass_through: bool) -> Any: ...

Expand All @@ -76,7 +76,7 @@ def to_native(
*,
strict: bool | None = None,
pass_through: bool | None = None,
) -> IntoFrameT | Any:
) -> IntoFrameT | IntoSeriesT | Any:
"""Convert Narwhals object to native one.
Arguments:
Expand Down Expand Up @@ -310,16 +310,16 @@ def from_native(
) -> Any: ...


def from_native( # type: ignore[misc]
native_object: IntoFrameT | IntoSeriesT | T,
def from_native(
native_object: IntoFrameT | IntoSeriesT | Any,
*,
strict: bool | None = None,
pass_through: bool | None = None,
eager_only: bool = False,
eager_or_interchange_only: bool = False,
series_only: bool = False,
allow_series: bool | None = None,
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series[IntoSeriesT] | T:
) -> LazyFrame[IntoFrameT] | DataFrame[IntoFrameT] | Series[IntoSeriesT] | Any:
"""Convert `native_object` to Narwhals Dataframe, Lazyframe, or Series.
Arguments:
Expand Down Expand Up @@ -376,7 +376,7 @@ def from_native( # type: ignore[misc]
strict, pass_through, pass_through_default=False, emit_deprecation_warning=True
)

return _from_native_impl( # type: ignore[no-any-return]
return _from_native_impl(
native_object,
pass_through=pass_through,
eager_only=eager_only,
Expand Down

0 comments on commit 859e0d6

Please sign in to comment.