Skip to content
Merged
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pyotp = "==2.9.0"
psycopg2-binary = "==2.9.9"
redis = {version = "==5.2.1", extras = ["hiredis"]}
regex = "==2024.11.6"
requests = "==2.33.1"
requests = "==2.34.2"
pyjwt = "==2.12.1"
psutil = "==7.0.0"
google-auth = "==2.48.0"
Expand Down
306 changes: 153 additions & 153 deletions Pipfile.lock

Large diffs are not rendered by default.

34 changes: 28 additions & 6 deletions codeforlife/models/encrypted.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,21 @@ class Manager(
that would bypass field-level encryption.
"""

def _is_encrypted_field(self, field_name: str):
    """Check whether a field name refers to one of the encrypted fields.

    Args:
        field_name: The field name to look for.

    Returns:
        True if an entry of the model's ENCRYPTED_FIELDS has this name,
        otherwise False.
    """
    for encrypted_field in self.model.ENCRYPTED_FIELDS:
        if encrypted_field.name == field_name:
            return True

    return False

def _is_none_or_empty(self, value: t.Any):
    """Check whether a value is None or compares equal to empty bytes.

    Args:
        value: The value to inspect.

    Returns:
        True for None; otherwise the result of comparing the value to b"".
    """
    if value is None:
        return True

    return value == b""

def update(self, **kwargs):
"""Ensure encrypted fields are not updated via 'update()'."""
for name in kwargs:
if any(
field.name == name for field in self.model.ENCRYPTED_FIELDS
):
for name, value in kwargs.items():
if self._is_encrypted_field(
name
) and not self._is_none_or_empty(value):
raise ValidationError(
f"Cannot update encrypted field '{name}' via"
" 'update()'. Set the property on each instance"
Expand All @@ -91,9 +100,22 @@ def update(self, **kwargs):

return super().update(**kwargs)

def bulk_update(self, objs, fields, batch_size=None):
    """Bulk update instances without writing encrypted-field values.

    An encrypted field may only be listed in 'fields' when every object's
    current value for it is None or empty bytes.

    Args:
        objs: The model instances to update. Any iterable is accepted.
        fields: The names of the fields to update. Any iterable is accepted.
        batch_size: Forwarded unchanged to the parent 'bulk_update()'.

    Returns:
        Whatever the parent 'bulk_update()' returns.

    Raises:
        ValidationError: With code 'cannot_bulk_update' if an encrypted
            field is listed and any object holds a value for it other than
            None or b"".
    """
    # Both iterables are walked here for validation and again by the parent
    # implementation. Materialize them once so a one-shot iterable (e.g. a
    # generator) does not reach the parent already exhausted.
    objs = tuple(objs)
    fields = tuple(fields)

    for name in fields:
        if self._is_encrypted_field(name) and not all(
            self._is_none_or_empty(getattr(obj, name)) for obj in objs
        ):
            raise ValidationError(
                f"Cannot bulk update encrypted field '{name}' via"
                " 'bulk_update()'. Set the property on each instance"
                " instead.",
                code="cannot_bulk_update",
            )

    return super().bulk_update(objs, fields, batch_size)

# Disable bulk operations that would bypass field-level encryption.
bulk_update: t.Never = None # type: ignore[assignment]
abulk_update: t.Never = None # type: ignore[assignment]
bulk_create: t.Never = None # type: ignore[assignment]
abulk_create: t.Never = None # type: ignore[assignment]
in_bulk: t.Never = None # type: ignore[assignment]
Expand Down
14 changes: 8 additions & 6 deletions codeforlife/models/encrypted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Person(EncryptedModel):

name = EncryptedTextField(associated_data="name")

objects: EncryptedModel.Manager["Person"] # type: ignore[assignment]

class Meta(TypedModelMeta):
app_label = "codeforlife.user"

Expand All @@ -37,13 +39,13 @@ def test_objects___update__cannot_update(self):
with self.assert_raises_validation_error(code="cannot_update"):
Person.objects.update(name="Alice")

def test_objects___bulk_update(self):
def test_objects___bulk_update__cannot_bulk_update(self):
"""Cannot bulk update encrypted field via objects.bulk_update()."""
assert Person.objects.bulk_update is None

def test_objects___abulk_update(self):
"""Cannot abulk_update encrypted field via objects.abulk_update()."""
assert Person.objects.abulk_update is None
with self.assert_raises_validation_error(code="cannot_bulk_update"):
Person.objects.bulk_update(
[Person(name="Alice"), Person(name="Bob")],
fields=["name"],
)

def test_objects___bulk_create(self):
"""Cannot bulk create encrypted field via objects.bulk_create()."""
Expand Down
1 change: 1 addition & 0 deletions codeforlife/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .data_encryption_key import DataEncryptionKeyField
from .deferred_attribute import DeferredAttribute
from .encrypted_text import EncryptedTextField
from .normalized import NormalizedField
from .sha256 import Sha256Field
51 changes: 39 additions & 12 deletions codeforlife/models/fields/base_encrypted.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..encrypted import EncryptedModel
from ..utils import is_real_model_class
from .deferred_attribute import DeferredAttribute
from .normalized import Normalize, NormalizedField

T = t.TypeVar("T")
Ciphertext: t.TypeAlias = t.Union[bytes, memoryview]
Expand All @@ -51,22 +52,35 @@ def __set__(self, instance, value):
super().__set__(instance, value)


class BaseEncryptedField(BinaryField, t.Generic[T]):
class BaseEncryptedField(
NormalizedField[EncryptedModel, T], BinaryField, t.Generic[T]
):
"""Binary field base class for storing encrypted typed values."""

model: t.Type[EncryptedModel]

descriptor_class = EncryptedAttribute

def __init__(self, associated_data: str, **kwargs):
def __init__(
self,
associated_data: str,
normalize: None | Normalize[T] = None,
unique: t.Literal[False] = False,
**kwargs,
):
if unique:
raise ValidationError(
f"{self.__class__.__name__} does not support unique=True.",
code="unique_not_supported",
)
if not associated_data:
raise ValidationError(
"Associated data cannot be empty.",
code="no_associated_data",
)
self.associated_data = associated_data

super().__init__(**kwargs)
super().__init__(normalize=normalize, unique=unique, **kwargs)

def deconstruct(self):
name, path, args, kwargs = t.cast(
Expand Down Expand Up @@ -182,18 +196,28 @@ def full_associated_data(self):

def _decrypt(self, instance: EncryptedModel, ciphertext: bytes):
"""Decrypts a single value using the DEK and associated data."""
data = instance.dek_aead.decrypt(
ciphertext=ciphertext,
associated_data=self.full_associated_data,
data = (
b""
if ciphertext == b""
else instance.dek_aead.decrypt(
ciphertext=ciphertext,
associated_data=self.full_associated_data,
)
)

return self.bytes_to_value(data)

def _encrypt(self, instance: EncryptedModel, plaintext: T):
def _encrypt(self, instance: EncryptedModel, value: T):
"""Encrypts a single value using the DEK and associated data."""
return instance.dek_aead.encrypt(
plaintext=self.value_to_bytes(plaintext),
associated_data=self.full_associated_data,
plaintext = self.value_to_bytes(value)

return (
b""
if plaintext == b""
else instance.dek_aead.encrypt(
plaintext=plaintext,
associated_data=self.full_associated_data,
)
)

@staticmethod
Expand Down Expand Up @@ -239,8 +263,8 @@ def get(instance: EncryptedModel, field_name: str):

return decrypted_value

@staticmethod
def set(instance: EncryptedModel, value: t.Optional[T], field_name: str):
@classmethod
def set(cls, instance, value, field_name, **kwargs):
"""Set a typed plaintext value for an encrypted field.

The plaintext is staged in pending-encryption storage and encrypted at
Expand All @@ -250,6 +274,7 @@ def set(instance: EncryptedModel, value: t.Optional[T], field_name: str):
instance: The model instance on which to set the value.
value: The plaintext value to set. If None, the field is cleared.
field_name: The name of the encrypted field to set.
normalize: Whether to normalize the value before setting it.
"""
field = t.cast(
BaseEncryptedField[T], instance._meta.get_field(field_name)
Expand All @@ -259,6 +284,8 @@ def set(instance: EncryptedModel, value: t.Optional[T], field_name: str):
if value is None:
instance.__pending_encryption_values__.pop(field.attname, None)
else:
if kwargs.get("normalize", True) and field.normalize is not None:
value = field.normalize(value)
instance.__pending_encryption_values__[field.attname] = value

# In all cases we need to clear the internal and cached-decrypted value.
Expand Down
8 changes: 8 additions & 0 deletions codeforlife/models/fields/base_encrypted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ def test_init__no_associated_data(self):
with self.assert_raises_validation_error(code="no_associated_data"):
BaseEncryptedField(associated_data="")

def test_init__unique_not_supported(self):
    """Constructing a BaseEncryptedField with unique=True is rejected."""
    expect_rejection = self.assert_raises_validation_error(
        code="unique_not_supported"
    )

    with expect_rejection:
        BaseEncryptedField(
            unique=True,  # type: ignore[arg-type]
            associated_data="test",
        )

def test_init(self):
"""BaseEncryptedField is constructed correctly."""
assert self.field.associated_data == self.field_associated_data
Expand Down
42 changes: 42 additions & 0 deletions codeforlife/models/fields/normalized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
© Ocado Group
Created on 18/05/2026 at 15:38:05(+01:00).
"""

import typing as t

from django.db.models import Field, Model

# Type variable for a model class; constrained to Django `Model` subclasses.
AnyModel = t.TypeVar("AnyModel", bound=Model)
# Unconstrained type variable for the value type a field works with.
T = t.TypeVar("T")
# A normalizer: a callable that takes a value and returns one of the same type.
Normalize: t.TypeAlias = t.Callable[[T], T]


class NormalizedField(Field, t.Generic[AnyModel, T]):
    """A Django model field that can normalize values as they are assigned.

    Within this class, normalization is applied only by the 'set()' helper;
    nothing here hooks plain attribute assignment or saving.
    """

    def __init__(
        self, normalize: None | Normalize[T] = None, *args, **kwargs
    ):
        """Initialize the field.

        Args:
            normalize: Optional callable applied by 'set()' to non-None
                values. None disables normalization.
            *args: Forwarded to the parent field's initializer.
            **kwargs: Forwarded to the parent field's initializer.
        """
        # 'normalize' defaults to None because this class does not add it to
        # deconstruct() output. Django re-creates fields from that output
        # (Field.clone(), migrations), and a required argument would make the
        # re-creation raise TypeError. Fields re-created that way carry no
        # normalizer.
        super().__init__(*args, **kwargs)
        self.normalize = normalize

    @classmethod
    def set(
        cls, instance: AnyModel, value: None | T, field_name: str, **kwargs
    ):
        """
        Normalize and assign a value to a NormalizedField.

        None is assigned as-is, without looking up the field.

        Args:
            instance: The model instance on which to set the value.
            value: The value to normalize and set.
            field_name: The name of the NormalizedField on the model.
            **kwargs: Not used by this implementation.
        """
        if value is not None:
            field = t.cast(
                NormalizedField[AnyModel, T],
                instance._meta.get_field(field_name),
            )
            if field.normalize is not None:
                value = field.normalize(value)

        setattr(instance, field_name, value)
Loading
Loading