Skip to content

Commit

Permalink
implement or and ior operators (pallets#2979)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism authored Oct 30, 2024
2 parents 862cb19 + b65b587 commit 1eb7ada
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Unreleased
:issue:`2970`
- ``MultiDict.getlist`` catches ``TypeError`` in addition to ``ValueError``
when doing type conversion. :issue:`2976`
- Implement ``|`` and ``|=`` operators for ``MultiDict``, ``Headers``, and
``CallbackDict``, and disallow ``|=`` on immutable types. :issue:`2977`


Version 3.0.6
Expand Down
31 changes: 31 additions & 0 deletions src/werkzeug/datastructures/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class Headers(cabc.MutableMapping[str, str]):
:param defaults: The list of default values for the :class:`Headers`.
.. versionchanged:: 3.1
Implement ``|`` and ``|=`` operators.
.. versionchanged:: 2.1.0
Default values are validated the same as values added later.
Expand Down Expand Up @@ -524,6 +527,31 @@ def update( # type: ignore[override]
else:
self.set(key, value)

def __or__(
self, other: cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
) -> te.Self:
if not isinstance(other, cabc.Mapping):
return NotImplemented

rv = self.copy()
rv.update(other)
return rv

def __ior__(
self,
other: (
cabc.Mapping[str, t.Any | cabc.Collection[t.Any]]
| cabc.Iterable[tuple[str, t.Any]]
),
) -> te.Self:
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
other, str
):
return NotImplemented

self.update(other)
return self

def to_wsgi_list(self) -> list[tuple[str, str]]:
"""Convert the headers into a list suitable for WSGI.
Expand Down Expand Up @@ -620,6 +648,9 @@ def __iter__(self) -> cabc.Iterator[tuple[str, str]]: # type: ignore[override]
def copy(self) -> t.NoReturn:
raise TypeError(f"cannot create {type(self).__name__!r} copies")

def __or__(self, other: t.Any) -> t.NoReturn:
raise TypeError(f"cannot create {type(self).__name__!r} copies")


# circular dependencies
from .. import http
21 changes: 21 additions & 0 deletions src/werkzeug/datastructures/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def sort(self, key: t.Any = None, reverse: t.Any = False) -> t.NoReturn:
class ImmutableDictMixin(t.Generic[K, V]):
"""Makes a :class:`dict` immutable.
.. versionchanged:: 3.1
Disallow ``|=`` operator.
.. versionadded:: 0.5
:private:
Expand Down Expand Up @@ -117,6 +120,9 @@ def setdefault(self, key: t.Any, default: t.Any = None) -> t.NoReturn:
def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
_immutable_error(self)

def __ior__(self, other: t.Any) -> t.NoReturn:
_immutable_error(self)

def pop(self, key: t.Any, default: t.Any = None) -> t.NoReturn:
_immutable_error(self)

Expand Down Expand Up @@ -168,6 +174,9 @@ class ImmutableHeadersMixin:
hashable though since the only usecase for this datastructure
in Werkzeug is a view on a mutable structure.
.. versionchanged:: 3.1
Disallow ``|=`` operator.
.. versionadded:: 0.5
:private:
Expand Down Expand Up @@ -200,6 +209,9 @@ def extend(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn:
_immutable_error(self)

def __ior__(self, other: t.Any) -> t.NoReturn:
_immutable_error(self)

def insert(self, pos: t.Any, value: t.Any) -> t.NoReturn:
_immutable_error(self)

Expand Down Expand Up @@ -233,6 +245,9 @@ def wrapper(
class UpdateDictMixin(dict[K, V]):
"""Makes dicts call `self.on_update` on modifications.
.. versionchanged:: 3.1
Implement ``|=`` operator.
.. versionadded:: 0.5
:private:
Expand Down Expand Up @@ -294,3 +309,9 @@ def update( # type: ignore[override]
super().update(**kwargs)
else:
super().update(arg, **kwargs)

@_always_update
def __ior__( # type: ignore[override]
self, other: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]]
) -> te.Self:
return super().__ior__(other)
25 changes: 25 additions & 0 deletions src/werkzeug/datastructures/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ class MultiDict(TypeConversionDict[K, V]):
:param mapping: the initial value for the :class:`MultiDict`. Either a
regular dict, an iterable of ``(key, value)`` tuples
or `None`.
.. versionchanged:: 3.1
Implement ``|`` and ``|=`` operators.
"""

def __init__(
Expand Down Expand Up @@ -435,6 +438,28 @@ def update( # type: ignore[override]
for key, value in iter_multi_items(mapping):
self.add(key, value)

def __or__( # type: ignore[override]
self, other: cabc.Mapping[K, V | cabc.Collection[V]]
) -> MultiDict[K, V]:
if not isinstance(other, cabc.Mapping):
return NotImplemented

rv = self.copy()
rv.update(other)
return rv

def __ior__( # type: ignore[override]
self,
other: cabc.Mapping[K, V | cabc.Collection[V]] | cabc.Iterable[tuple[K, V]],
) -> te.Self:
if not isinstance(other, (cabc.Mapping, cabc.Iterable)) or isinstance(
other, str
):
return NotImplemented

self.update(other)
return self

@t.overload
def pop(self, key: K) -> V: ...
@t.overload
Expand Down
48 changes: 47 additions & 1 deletion tests/test_datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,17 @@ def test_basic_interface(self):
md.setlist("foo", [1, 2])
assert md.getlist("foo") == [1, 2]

def test_or(self) -> None:
a = self.storage_class({"x": 1})
b = a | {"y": 2}
assert isinstance(b, self.storage_class)
assert "x" in b and "y" in b

def test_ior(self) -> None:
a = self.storage_class({"x": 1})
a |= {"y": 2}
assert "x" in a and "y" in a


class _ImmutableDictTests:
storage_class: type[dict]
Expand Down Expand Up @@ -305,6 +316,17 @@ def test_dict_is_hashable(self):
assert immutable in x
assert immutable2 in x

def test_or(self) -> None:
a = self.storage_class({"x": 1})
b = a | {"y": 2}
assert "x" in b and "y" in b

def test_ior(self) -> None:
a = self.storage_class({"x": 1})

with pytest.raises(TypeError):
a |= {"y": 2}


class TestImmutableTypeConversionDict(_ImmutableDictTests):
storage_class = ds.ImmutableTypeConversionDict
Expand Down Expand Up @@ -799,6 +821,17 @@ def test_equality(self):

assert h1 == h2

def test_or(self) -> None:
a = ds.Headers({"x": 1})
b = a | {"y": 2}
assert isinstance(b, ds.Headers)
assert "x" in b and "y" in b

def test_ior(self) -> None:
a = ds.Headers({"x": 1})
a |= {"y": 2}
assert "x" in a and "y" in a


class TestEnvironHeaders:
storage_class = ds.EnvironHeaders
Expand Down Expand Up @@ -840,6 +873,18 @@ def test_return_type_is_str(self):
assert headers["Foo"] == "\xe2\x9c\x93"
assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93")

def test_or(self) -> None:
headers = ds.EnvironHeaders({"x": "1"})

with pytest.raises(TypeError):
headers | {"y": "2"}

def test_ior(self) -> None:
headers = ds.EnvironHeaders({})

with pytest.raises(TypeError):
headers |= {"y": "2"}


class TestHeaderSet:
storage_class = ds.HeaderSet
Expand Down Expand Up @@ -927,7 +972,7 @@ def test_callback_dict_writes(self):
assert_calls, func = make_call_asserter()
initial = {"a": "foo", "b": "bar"}
dct = self.storage_class(initial=initial, on_update=func)
with assert_calls(8, "callback not triggered by write method"):
with assert_calls(9, "callback not triggered by write method"):
# always-write methods
dct["z"] = 123
dct["z"] = 123 # must trigger again
Expand All @@ -937,6 +982,7 @@ def test_callback_dict_writes(self):
dct.popitem()
dct.update([])
dct.clear()
dct |= {}
with assert_calls(0, "callback triggered by failed del"):
pytest.raises(KeyError, lambda: dct.__delitem__("x"))
with assert_calls(0, "callback triggered by failed pop"):
Expand Down

0 comments on commit 1eb7ada

Please sign in to comment.