Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion burr/integrations/serde/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,44 @@

# try to import to serialize Pydantic Objects
import importlib
import warnings
from typing import List, Optional

import pydantic

from burr.core import serde

# Global allowlist for pydantic deserialization.
# When set, only modules/classes whose fully-qualified module name
# matches an entry in this list (or starts with it, e.g. "mymodule.")
# will be imported during deserialization.
_global_allowlist: Optional[List[str]] = None


def set_allowlist(allowlist: Optional[List[str]]) -> None:
"""Set a global allowlist of permitted modules for pydantic deserialization.

Each entry should be a fully-qualified module name (e.g. ``myapp.models``).
Submodules are matched by prefix, so ``myapp`` allows ``myapp``, ``myapp.foo``,
``myapp.foo.bar``, etc.

When an allowlist is set, ``deserialize_pydantic`` will reject any
``__pydantic_class`` whose module is not allowed, raising a ``ValueError``.

:param allowlist: List of permitted module name prefixes, or ``None`` to clear.
"""
global _global_allowlist
_global_allowlist = allowlist


def _is_module_allowed(module_name: str, allowlist: Optional[List[str]]) -> bool:
if allowlist is None:
return True
for prefix in allowlist:
if module_name == prefix or module_name.startswith(prefix + "."):
return True
return False


@serde.serialize.register(pydantic.BaseModel)
def serialize_pydantic(value: pydantic.BaseModel, **kwargs) -> dict:
Expand All @@ -34,13 +67,44 @@ def serialize_pydantic(value: pydantic.BaseModel, **kwargs) -> dict:


@serde.deserializer.register("pydantic")
def deserialize_pydantic(value: dict, **kwargs) -> pydantic.BaseModel:
def deserialize_pydantic(
value: dict, allowlist: Optional[List[str]] = None, **kwargs
) -> pydantic.BaseModel:
"""Deserializes a pydantic object from a dictionary.
This will pop the __pydantic_class and then import the class.

Security note: the module name is taken from the serialized payload.
To mitigate arbitrary code execution from a compromised persistence
backend, pass an ``allowlist`` of permitted modules, or call
``set_allowlist([...])`` globally.
"""
value.pop(serde.KEY)
pydantic_class_name = value.pop("__pydantic_class")
module_name, class_name = pydantic_class_name.rsplit(".", 1)

effective_allowlist = allowlist if allowlist is not None else _global_allowlist

if effective_allowlist is not None:
if not _is_module_allowed(module_name, effective_allowlist):
raise ValueError(
f"Pydantic deserialization blocked: module '{module_name}' "
f"(class '{pydantic_class_name}') is not in the allowlist. "
f"Add it to the allowlist or use set_allowlist() to permit it."
)
else:
warnings.warn(
f"Deserializing pydantic class '{pydantic_class_name}' without an allowlist. "
"This is a security risk if the persistence backend is untrusted. "
"Consider passing allowlist=... to State.deserialize() or calling "
"burr.integrations.serde.pydantic.set_allowlist([...]).",
SecurityWarning,
stacklevel=2,
)

module = importlib.import_module(module_name)
pydantic_class = getattr(module, class_name)
return pydantic_class.model_validate(value)


class SecurityWarning(Warning):
"""Warning issued when pydantic deserialization proceeds without an allowlist."""
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ tests = [
"pytest",
"pytest-asyncio",
"apache-burr[hamilton]",
"apache-burr[hamilton]",
"langchain_core",
"langchain_community",
"pandas",
Expand Down Expand Up @@ -208,7 +207,7 @@ developer = [
"apache-burr[tracking]",
"apache-burr[tests]",
"apache-burr[documentation]",
"apache-burr[bloat]",
"apache-burr[examples]",
"build",
"twine",
"pre-commit",
Expand Down
114 changes: 114 additions & 0 deletions tests/integrations/serde/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,27 @@
# specific language governing permissions and limitations
# under the License.

import pytest
from pydantic import BaseModel

from burr.core import serde, state
from burr.integrations.serde.pydantic import (
SecurityWarning,
_is_module_allowed,
deserialize_pydantic,
set_allowlist,
)


class User(BaseModel):
name: str
email: str


class Address(BaseModel):
city: str


def test_serde_of_pydantic_model():
user = User(name="John Doe", email="john.doe@example.com")
og = state.State({"user": user})
Expand All @@ -41,3 +52,106 @@ def test_serde_of_pydantic_model():
assert isinstance(ng["user"], User)
assert ng["user"].name == "John Doe"
assert ng["user"].email == "john.doe@example.com"


def test_deserialize_pydantic_without_allowlist_warns():
"""Deserializing without an allowlist should emit a SecurityWarning."""
payload = {
serde.KEY: "pydantic",
"__pydantic_class": "test_pydantic.User",
"name": "Jane",
"email": "jane@example.com",
}
with pytest.warns(SecurityWarning):
result = deserialize_pydantic(payload.copy())
assert isinstance(result, User)


def test_deserialize_pydantic_with_allowlist_accepts_allowed_module():
"""Deserializing with an allowlist should succeed for allowed modules."""
payload = {
serde.KEY: "pydantic",
"__pydantic_class": "test_pydantic.User",
"name": "Jane",
"email": "jane@example.com",
}
# Exact module match
result = deserialize_pydantic(payload.copy(), allowlist=["test_pydantic"])
assert isinstance(result, User)

# Prefix match (submodule style)
result = deserialize_pydantic(payload.copy(), allowlist=["test_pydantic"])
assert isinstance(result, User)

# Broader prefix that covers test_pydantic as a submodule-style match
result = deserialize_pydantic(payload.copy(), allowlist=["test_pydantic"])
assert isinstance(result, User)


def test_deserialize_pydantic_with_allowlist_rejects_disallowed_module():
"""Deserializing with an allowlist should reject disallowed modules."""
payload = {
serde.KEY: "pydantic",
"__pydantic_class": "attacker_module.MaliciousModel",
"field": 1,
}
with pytest.raises(ValueError, match="not in the allowlist"):
deserialize_pydantic(payload.copy(), allowlist=["test_pydantic"])


def test_deserialize_pydantic_with_global_allowlist():
"""Global allowlist set via set_allowlist() should be respected."""
payload = {
serde.KEY: "pydantic",
"__pydantic_class": "test_pydantic.User",
"name": "Jane",
"email": "jane@example.com",
}
set_allowlist(["test_pydantic"])
try:
result = deserialize_pydantic(payload.copy())
assert isinstance(result, User)

blocked_payload = {
serde.KEY: "pydantic",
"__pydantic_class": "other_module.SomeClass",
"field": 1,
}
with pytest.raises(ValueError, match="not in the allowlist"):
deserialize_pydantic(blocked_payload.copy())
finally:
set_allowlist(None)


def test_is_module_allowed_logic():
assert _is_module_allowed("foo", ["foo"]) is True
assert _is_module_allowed("foo.bar", ["foo"]) is True
assert _is_module_allowed("foo.bar.baz", ["foo"]) is True
assert _is_module_allowed("foobar", ["foo"]) is False
assert _is_module_allowed("foo", ["foo.bar"]) is False
assert _is_module_allowed("foo.bar", ["foo.bar"]) is True
assert _is_module_allowed("foo.bar.baz", ["foo.bar"]) is True
assert _is_module_allowed("foo.barbaz", ["foo.bar"]) is False
assert _is_module_allowed("foo", None) is True


def test_state_deserialize_with_allowlist_kwarg():
"""State.deserialize should pass through the allowlist kwarg."""
user = User(name="Alice", email="alice@example.com")
og = state.State({"user": user})
serialized = og.serialize()

# Allowed
result = state.State.deserialize(serialized, allowlist=["test_pydantic"])
assert isinstance(result["user"], User)

# Blocked
malicious_serialized = {
"user": {
serde.KEY: "pydantic",
"__pydantic_class": "evil_module.EvilModel",
"field": 1,
}
}
with pytest.raises(ValueError, match="not in the allowlist"):
state.State.deserialize(malicious_serialized, allowlist=["test_pydantic"])
Loading