Skip to content

Commit

Permalink
feat: add Series|Expr.rolling_sum method (#1395)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Nov 18, 2024
1 parent 836f086 commit bbf2aa3
Show file tree
Hide file tree
Showing 13 changed files with 766 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api-reference/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
- ColumnNotFoundError
- InvalidIntoExprError
- InvalidOperationError
- NarwhalsUnstableWarning
show_source: false
show_bases: false
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
- pipe
- quantile
- replace_strict
- rolling_sum
- round
- sample
- shift
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
- quantile
- rename
- replace_strict
- rolling_sum
- round
- sample
- scatter
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
def cum_prod(self: Self, *, reverse: bool) -> Self:
return reuse_series_implementation(self, "cum_prod", reverse=reverse)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
return reuse_series_implementation(
self,
"rolling_sum",
window_size=window_size,
min_periods=min_periods,
center=center,
)

@property
def dt(self: Self) -> ArrowExprDateTimeNamespace:
return ArrowExprDateTimeNamespace(self)
Expand Down
48 changes: 48 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,54 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

min_periods = min_periods if min_periods is not None else window_size
if center:
offset_left = window_size // 2
offset_right = offset_left - (
window_size % 2 == 0
) # subtract one if window_size is even

native_series = self._native_series

pad_left = pa.array([None] * offset_left, type=native_series.type)
pad_right = pa.array([None] * offset_right, type=native_series.type)
padded_arr = self._from_native_series(
pa.concat_arrays([pad_left, native_series.combine_chunks(), pad_right])
)
else:
padded_arr = self

cum_sum = padded_arr.cum_sum(reverse=False).fill_null(strategy="forward")
rolling_sum = (
cum_sum - cum_sum.shift(window_size).fill_null(0)
if window_size != 0
else cum_sum
)

valid_count = padded_arr.cum_count(reverse=False)
count_in_window = valid_count - valid_count.shift(window_size).fill_null(0)

result = self._from_native_series(
pc.if_else(
(count_in_window >= min_periods)._native_series,
rolling_sum._native_series,
None,
)
)
if center:
result = result[offset_left + offset_right :]
return result

def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

Expand Down
26 changes: 26 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,32 @@ def is_finite(self: Self) -> Self:
returns_scalar=False,
)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
def func(
_input: dask_expr.Series,
_window: int,
_min_periods: int | None,
_center: bool, # noqa: FBT001
) -> dask_expr.Series:
return _input.rolling(
window=_window, min_periods=_min_periods, center=_center
).sum()

return self._from_call(
func,
"rolling_sum",
window_size,
min_periods,
center,
returns_scalar=False,
)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,21 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
def cum_prod(self: Self, *, reverse: bool) -> Self:
return reuse_series_implementation(self, "cum_prod", reverse=reverse)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
return reuse_series_implementation(
self,
"rolling_sum",
window_size=window_size,
min_periods=min_periods,
center=center,
)

@property
def str(self: Self) -> PandasLikeExprStringNamespace:
return PandasLikeExprStringNamespace(self)
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def fill_null(
value: Any | None = None,
strategy: Literal["forward", "backward"] | None = None,
limit: int | None = None,
) -> PandasLikeSeries:
) -> Self:
ser = self._native_series
if value is not None:
res_ser = self._from_native_series(ser.fillna(value=value))
Expand Down Expand Up @@ -798,6 +798,18 @@ def cum_prod(self: Self, *, reverse: bool) -> Self:
)
return self._from_native_series(result)

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None,
center: bool,
) -> Self:
result = self._native_series.rolling(
window=window_size, min_periods=min_periods, center=center
).sum()
return self._from_native_series(result)

def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

Expand Down
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ def from_invalid_type(cls, invalid_type: type) -> InvalidIntoExprError:
" column with literal value `0`."
)
return InvalidIntoExprError(message)


class NarwhalsUnstableWarning(UserWarning):
"""Warning issued when a method or function is considered unstable in the stable api."""
118 changes: 118 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TypeVar

from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import flatten

if TYPE_CHECKING:
Expand Down Expand Up @@ -2990,6 +2991,123 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).cum_prod(reverse=reverse))

def rolling_sum(
self: Self,
window_size: int,
*,
min_periods: int | None = None,
center: bool = False,
) -> Self:
"""Apply a rolling sum (moving sum) over the values.
!!! warning
This functionality is considered **unstable**. It may be changed at any point
without it being considered a breaking change.
A window of length `window_size` will traverse the values. The resulting values
will be aggregated to their sum.
The window at a given row will include the row itself and the `window_size - 1`
elements before it.
Arguments:
window_size: The length of the window in number of elements. It must be a
strictly positive integer.
min_periods: The number of values in the window that should be non-null before
computing a result. If set to `None` (default), it will be set equal to
`window_size`. If provided, it must be a strictly positive integer, and
less than or equal to `window_size`
center: Set the labels at the center of the window.
Returns:
A new expression.
Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = {"a": [1.0, 2.0, None, 4.0]}
>>> df_pd = pd.DataFrame(data)
>>> df_pl = pl.DataFrame(data)
>>> df_pa = pa.table(data)
We define a library agnostic function:
>>> @nw.narwhalify
... def func(df):
... return df.with_columns(
... b=nw.col("a").rolling_sum(window_size=3, min_periods=1)
... )
We can then pass any supported library such as Pandas, Polars, or PyArrow to `func`:
>>> func(df_pd)
a b
0 1.0 1.0
1 2.0 3.0
2 NaN 3.0
3 4.0 6.0
>>> func(df_pl)
shape: (4, 2)
┌──────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞══════╪═════╡
│ 1.0 ┆ 1.0 │
│ 2.0 ┆ 3.0 │
│ null ┆ 3.0 │
│ 4.0 ┆ 6.0 │
└──────┴─────┘
>>> func(df_pa) # doctest:+ELLIPSIS
pyarrow.Table
a: double
b: double
----
a: [[1,2,null,4]]
b: [[1,3,3,6]]
"""
if window_size < 1:
msg = "window_size must be greater or equal than 1"
raise ValueError(msg)

if not isinstance(window_size, int):
_type = window_size.__class__.__name__
msg = (
f"argument 'window_size': '{_type}' object cannot be "
"interpreted as an integer"
)
raise TypeError(msg)

if min_periods is not None:
if min_periods < 1:
msg = "min_periods must be greater or equal than 1"
raise ValueError(msg)

if not isinstance(min_periods, int):
_type = min_periods.__class__.__name__
msg = (
f"argument 'min_periods': '{_type}' object cannot be "
"interpreted as an integer"
)
raise TypeError(msg)
if min_periods > window_size:
msg = "`min_periods` must be less or equal than `window_size`"
raise InvalidOperationError(msg)
else:
min_periods = window_size

return self.__class__(
lambda plx: self._call(plx).rolling_sum(
window_size=window_size,
min_periods=min_periods,
center=center,
)
)

@property
def str(self: Self) -> ExprStringNamespace[Self]:
return ExprStringNamespace(self)
Expand Down
Loading

0 comments on commit bbf2aa3

Please sign in to comment.