diff --git a/src/storage/local.py b/src/storage/local.py index 9935daf..8e2a820 100644 --- a/src/storage/local.py +++ b/src/storage/local.py @@ -13,10 +13,22 @@ import sqlite3 import uuid from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence from src.config import settings from src.storage.base import BaseVectorStore, IndexStats, SearchResult +from src.storage.memory_lifecycle import ( + CONTENT_HASH_KEY, + FORGET_REASON_KEY, + FORGOTTEN_AT_KEY, + IS_CURRENT_KEY, + PARENT_MEMORY_ID_KEY, + VERSION_KEY, + build_lifecycle_metadata, + compute_memory_hash, + is_retrievable_memory, + utc_now_iso, +) from src.utils.exceptions import VectorStoreValidationError @@ -43,6 +55,9 @@ def _cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float: return max(0.0, min(1.0, (dot / (norm_a * norm_b) + 1.0) / 2.0)) +_DEDUP_SCOPE_KEYS = ("user_id", "tenant_id", "org_id", "workspace_id", "project_id") + + class SQLiteVectorStore(BaseVectorStore): """Small embedded vector store for single-user local testing. @@ -101,16 +116,29 @@ def add( ids = ids or [str(uuid.uuid4()) for _ in texts] metadata = metadata or [{} for _ in texts] - rows = [ - ( - self._namespace, - vec_id, - text, - json.dumps([float(v) for v in embedding]), - json.dumps(meta or {}), + output_ids: List[str] = [] + rows = [] + for text, embedding, vec_id, meta in zip(texts, embeddings, ids, metadata): + lifecycle_meta = build_lifecycle_metadata(text, meta) + existing_id = self._find_current_by_hash( + lifecycle_meta[CONTENT_HASH_KEY], + lifecycle_meta, ) - for text, embedding, vec_id, meta in zip(texts, embeddings, ids, metadata) - ] + if existing_id: + output_ids.append(existing_id) + continue + output_ids.append(vec_id) + rows.append( + ( + self._namespace, + vec_id, + text, + json.dumps([float(v) for v in embedding]), + json.dumps(lifecycle_meta), + ) + ) + if not rows: + return output_ids self._conn.executemany( """ INSERT INTO xmem_vectors(namespace, id, content, embedding, metadata) @@ -124,7 +152,7 @@ def add( rows, ) self._conn.commit() - return ids + return output_ids def search( self, @@ -145,7 +173,7 @@ def search( results: List[SearchResult] = [] for row in rows: meta = json.loads(row["metadata"] or "{}") - if not _metadata_matches(meta, filters): + if not is_retrievable_memory(meta) or not _metadata_matches(meta, filters): continue embedding = json.loads(row["embedding"]) results.append( @@ -175,6 +203,11 @@ def update( return False current_meta = json.loads(row["metadata"] or "{}") current_meta.update(metadata or {}) + new_text = text if text is not None else row["content"] + current_meta[CONTENT_HASH_KEY] = compute_memory_hash(new_text) + existing_id = self._find_current_by_hash(current_meta[CONTENT_HASH_KEY], current_meta) + if existing_id and existing_id != id: + return False new_embedding = embedding if embedding is not None else json.loads(row["embedding"]) if len(new_embedding) != self._dimension: raise VectorStoreValidationError( @@ -188,7 +221,7 @@ def update( WHERE namespace = ? AND id = ? """, ( - text if text is not None else row["content"], + new_text, json.dumps([float(v) for v in new_embedding]), json.dumps(current_meta), self._namespace, @@ -198,6 +231,117 @@ def update( self._conn.commit() return True + def add_version( + self, + parent_id: str, + text: str, + embedding: List[float], + id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + """Create a new current memory version and keep the parent as history.""" + + parent = self.get([parent_id]) + if not parent: + return None + if not is_retrievable_memory(parent[0]["metadata"] or {}): + return None + if len(embedding) != self._dimension: + raise VectorStoreValidationError( + f"Embedding dimension {len(embedding)} doesn't match {self._dimension}", + operation="add_version", + ) + + parent_meta = dict(parent[0]["metadata"] or {}) + root_parent_id = parent_meta.get(PARENT_MEMORY_ID_KEY) or parent_id + next_version = int(parent_meta.get(VERSION_KEY) or 1) + 1 + new_id = id or str(uuid.uuid4()) + new_meta = build_lifecycle_metadata( + text, + metadata, + parent_memory_id=root_parent_id, + version=next_version, + is_current=True, + ) + existing_id = self._find_current_by_hash(new_meta[CONTENT_HASH_KEY], new_meta) + if existing_id: + if existing_id == parent_id: + return existing_id + parent_meta[IS_CURRENT_KEY] = False + with self._conn: + self._conn.execute( + """ + UPDATE xmem_vectors + SET metadata = ?, updated_at = CURRENT_TIMESTAMP + WHERE namespace = ? AND id = ? + """, + (json.dumps(parent_meta), self._namespace, parent_id), + ) + return existing_id + + parent_meta[IS_CURRENT_KEY] = False + with self._conn: + self._conn.execute( + """ + UPDATE xmem_vectors + SET metadata = ?, updated_at = CURRENT_TIMESTAMP + WHERE namespace = ? AND id = ? + """, + (json.dumps(parent_meta), self._namespace, parent_id), + ) + self._conn.execute( + """ + INSERT INTO xmem_vectors(namespace, id, content, embedding, metadata) + VALUES (?, ?, ?, ?, ?) + """, + ( + self._namespace, + new_id, + text, + json.dumps([float(v) for v in embedding]), + json.dumps(new_meta), + ), + ) + return new_id + + def forget( + self, + ids: List[str], + reason: Optional[str] = None, + hard_delete: bool = False, + ) -> bool: + """Soft-forget memories by default, preserving audit history.""" + + if hard_delete: + return self.delete(ids) + if not ids: + return True + + placeholders = ",".join("?" for _ in ids) + rows = self._conn.execute( + f"SELECT id, metadata FROM xmem_vectors " + f"WHERE namespace = ? AND id IN ({placeholders})", + [self._namespace, *ids], + ).fetchall() + + now = utc_now_iso() + updates = [] + for row in rows: + meta = json.loads(row["metadata"] or "{}") + meta[IS_CURRENT_KEY] = False + meta[FORGOTTEN_AT_KEY] = now + meta[FORGET_REASON_KEY] = reason + updates.append((json.dumps(meta), self._namespace, row["id"])) + + if updates: + self._conn.executemany( + "UPDATE xmem_vectors SET metadata = ?, updated_at = CURRENT_TIMESTAMP " + "WHERE namespace = ? AND id = ?", + updates, + ) + self._conn.commit() + return True + def delete(self, ids: List[str]) -> bool: if not ids: return True @@ -239,10 +383,41 @@ def search_by_metadata( results: List[SearchResult] = [] for row in rows: meta = json.loads(row["metadata"] or "{}") - if _metadata_matches(meta, filters): + if is_retrievable_memory(meta) and _metadata_matches(meta, filters): results.append(SearchResult(id=row["id"], content=row["content"], score=1.0, metadata=meta)) return results[:top_k] + def _find_current_by_hash( + self, + content_hash: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> Optional[str]: + clauses = [ + "namespace = ?", + f"json_extract(metadata, '$.{CONTENT_HASH_KEY}') = ?", + f"json_extract(metadata, '$.{IS_CURRENT_KEY}') = 1", + f"json_extract(metadata, '$.{FORGOTTEN_AT_KEY}') IS NULL", + ] + params: List[Any] = [self._namespace, content_hash] + scope = { + key: (metadata or {}).get(key) + for key in _DEDUP_SCOPE_KEYS + if (metadata or {}).get(key) is not None + } + if scope: + for key, value in scope.items(): + clauses.append(f"json_extract(metadata, '$.{key}') = ?") + params.append(value) + else: + for key in _DEDUP_SCOPE_KEYS: + clauses.append(f"json_type(metadata, '$.{key}') IS NULL") + + row = self._conn.execute( + f"SELECT id FROM xmem_vectors WHERE {' AND '.join(clauses)} LIMIT 1", + params, + ).fetchone() + return row["id"] if row else None + async def search_by_text( self, query_text: str, diff --git a/src/storage/memory_lifecycle.py b/src/storage/memory_lifecycle.py new file mode 100644 index 0000000..f790583 --- /dev/null +++ b/src/storage/memory_lifecycle.py @@ -0,0 +1,63 @@ +"""Memory lifecycle metadata helpers. + +These helpers keep duplicate detection, version lineage, and soft-forget +metadata consistent across vector-store implementations. +""" + +from __future__ import annotations + +import hashlib +import re +from datetime import datetime, timezone +from typing import Any, Dict, Optional + +CONTENT_HASH_KEY = "content_hash" +PARENT_MEMORY_ID_KEY = "parent_memory_id" +VERSION_KEY = "version" +IS_CURRENT_KEY = "is_current" +FORGOTTEN_AT_KEY = "forgotten_at" +FORGET_REASON_KEY = "forget_reason" + + +def normalize_memory_content(content: str) -> str: + """Normalize memory text before hashing to catch whitespace-only duplicates.""" + + return re.sub(r"\s+", " ", content.strip()).casefold() + + +def compute_memory_hash(content: str) -> str: + """Return a stable SHA-256 digest for normalized memory content.""" + + normalized = normalize_memory_content(content) + return hashlib.sha256(normalized.encode("utf-8")).hexdigest() + + +def utc_now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def build_lifecycle_metadata( + content: str, + metadata: Optional[Dict[str, Any]] = None, + *, + parent_memory_id: Optional[str] = None, + version: int = 1, + is_current: bool = True, +) -> Dict[str, Any]: + """Merge caller metadata with lifecycle fields without losing custom keys.""" + + merged = dict(metadata or {}) + merged[CONTENT_HASH_KEY] = compute_memory_hash(content) + merged[PARENT_MEMORY_ID_KEY] = parent_memory_id + merged[VERSION_KEY] = version + merged[IS_CURRENT_KEY] = is_current + merged[FORGOTTEN_AT_KEY] = None + merged[FORGET_REASON_KEY] = None + return merged + + +def is_retrievable_memory(metadata: Optional[Dict[str, Any]]) -> bool: + """Return False for superseded or soft-forgotten memory records.""" + + meta = metadata or {} + return meta.get(IS_CURRENT_KEY, True) is not False and not meta.get(FORGOTTEN_AT_KEY) diff --git a/tests/unit/test_memory_lifecycle.py b/tests/unit/test_memory_lifecycle.py new file mode 100644 index 0000000..0e8a766 --- /dev/null +++ b/tests/unit/test_memory_lifecycle.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +from src.storage.local import SQLiteVectorStore +from src.storage.memory_lifecycle import ( + CONTENT_HASH_KEY, + FORGET_REASON_KEY, + FORGOTTEN_AT_KEY, + IS_CURRENT_KEY, + PARENT_MEMORY_ID_KEY, + VERSION_KEY, + compute_memory_hash, +) + + +def _store(tmp_path): + return SQLiteVectorStore( + path=str(tmp_path / "vectors.sqlite3"), + namespace="test", + dimension=3, + ) + + +def test_sqlite_add_reuses_current_memory_with_same_normalized_hash(tmp_path): + store = _store(tmp_path) + + first_ids = store.add( + ["Remember that Alice likes XMem."], + [[1.0, 0.0, 0.0]], + ids=["memory-1"], + metadata=[{"user_id": "alice"}], + ) + duplicate_ids = store.add( + [" remember THAT alice likes xmem. "], + [[0.0, 1.0, 0.0]], + ids=["memory-duplicate"], + metadata=[{"user_id": "alice"}], + ) + + assert first_ids == ["memory-1"] + assert duplicate_ids == ["memory-1"] + assert store.search_by_metadata({"user_id": "alice"}, top_k=10)[0].id == "memory-1" + + stored = store.get(["memory-1"])[0] + assert stored["metadata"][CONTENT_HASH_KEY] == compute_memory_hash("Remember that Alice likes XMem.") + assert stored["metadata"][VERSION_KEY] == 1 + assert stored["metadata"][IS_CURRENT_KEY] is True + + +def test_sqlite_hash_dedup_is_scoped_by_user_id(tmp_path): + store = _store(tmp_path) + + alice_ids = store.add( + ["Shared wording."], + [[1.0, 0.0, 0.0]], + ids=["alice-memory"], + metadata=[{"user_id": "alice"}], + ) + bob_ids = store.add( + ["shared wording."], + [[0.0, 1.0, 0.0]], + ids=["bob-memory"], + metadata=[{"user_id": "bob"}], + ) + + assert alice_ids == ["alice-memory"] + assert bob_ids == ["bob-memory"] + assert [r.id for r in store.search_by_metadata({"user_id": "alice"}, top_k=10)] == ["alice-memory"] + assert [r.id for r in store.search_by_metadata({"user_id": "bob"}, top_k=10)] == ["bob-memory"] + + +def test_sqlite_update_rejects_current_hash_collision(tmp_path): + store = _store(tmp_path) + store.add( + ["Alice works at XMem."], + [[1.0, 0.0, 0.0]], + ids=["profile-1"], + metadata=[{"user_id": "alice"}], + ) + store.add( + ["Alice works at XortexAI."], + [[0.0, 1.0, 0.0]], + ids=["profile-2"], + metadata=[{"user_id": "alice"}], + ) + + assert store.update("profile-1", text="Alice works at XortexAI.") is False + + visible = store.search_by_metadata({"user_id": "alice"}, top_k=10) + assert {result.id for result in visible} == {"profile-1", "profile-2"} + assert store.get(["profile-1"])[0]["content"] == "Alice works at XMem." + + +def test_lifecycle_fields_cannot_be_overridden_by_caller_metadata(tmp_path): + store = _store(tmp_path) + + ids = store.add( + ["Visible memory."], + [[1.0, 0.0, 0.0]], + ids=["visible-1"], + metadata=[{ + "user_id": "alice", + CONTENT_HASH_KEY: "caller-hash", + IS_CURRENT_KEY: False, + FORGOTTEN_AT_KEY: "2024-01-01T00:00:00+00:00", + }], + ) + + assert ids == ["visible-1"] + stored = store.get(["visible-1"])[0] + assert stored["metadata"][CONTENT_HASH_KEY] == compute_memory_hash("Visible memory.") + assert stored["metadata"][IS_CURRENT_KEY] is True + assert stored["metadata"][FORGOTTEN_AT_KEY] is None + assert store.search_by_metadata({"user_id": "alice"}, top_k=10)[0].id == "visible-1" + + +def test_sqlite_add_version_supersedes_parent_but_keeps_history(tmp_path): + store = _store(tmp_path) + store.add( + ["Alice works at XMem."], + [[1.0, 0.0, 0.0]], + ids=["profile-1"], + metadata=[{"user_id": "alice", "domain": "profile"}], + ) + + version_id = store.add_version( + "profile-1", + "Alice works at XortexAI.", + [0.0, 1.0, 0.0], + id="profile-2", + metadata={"user_id": "alice", "domain": "profile"}, + ) + + assert version_id == "profile-2" + parent = store.get(["profile-1"])[0] + child = store.get(["profile-2"])[0] + assert parent["metadata"][IS_CURRENT_KEY] is False + assert child["metadata"][PARENT_MEMORY_ID_KEY] == "profile-1" + assert child["metadata"][VERSION_KEY] == 2 + + visible = store.search_by_metadata({"user_id": "alice"}, top_k=10) + assert [result.id for result in visible] == ["profile-2"] + + +def test_sqlite_add_version_duplicate_supersedes_parent(tmp_path): + store = _store(tmp_path) + store.add( + ["Alice works at XMem."], + [[1.0, 0.0, 0.0]], + ids=["profile-1"], + metadata=[{"user_id": "alice", "domain": "profile"}], + ) + store.add( + ["Alice works at XortexAI."], + [[0.0, 1.0, 0.0]], + ids=["profile-existing"], + metadata=[{"user_id": "alice", "domain": "profile"}], + ) + + version_id = store.add_version( + "profile-1", + "Alice works at XortexAI.", + [0.0, 0.0, 1.0], + id="profile-2", + metadata={"user_id": "alice", "domain": "profile"}, + ) + + assert version_id == "profile-existing" + assert store.get(["profile-1"])[0]["metadata"][IS_CURRENT_KEY] is False + visible = store.search_by_metadata({"user_id": "alice"}, top_k=10) + assert [result.id for result in visible] == ["profile-existing"] + + +def test_sqlite_add_version_rejects_forgotten_parent(tmp_path): + store = _store(tmp_path) + store.add( + ["Alice revoked this memory."], + [[1.0, 0.0, 0.0]], + ids=["profile-1"], + metadata=[{"user_id": "alice", "domain": "profile"}], + ) + store.forget(["profile-1"], reason="user requested deletion") + + version_id = store.add_version( + "profile-1", + "Alice revoked this memory but changed.", + [0.0, 1.0, 0.0], + id="profile-2", + metadata={"user_id": "alice", "domain": "profile"}, + ) + + assert version_id is None + assert store.get(["profile-2"]) == [] + assert store.search_by_metadata({"user_id": "alice"}, top_k=10) == [] + + +def test_sqlite_add_version_rejects_superseded_parent(tmp_path): + store = _store(tmp_path) + store.add( + ["Alice works at XMem."], + [[1.0, 0.0, 0.0]], + ids=["profile-1"], + metadata=[{"user_id": "alice", "domain": "profile"}], + ) + store.add_version( + "profile-1", + "Alice works at XortexAI.", + [0.0, 1.0, 0.0], + id="profile-2", + metadata={"user_id": "alice", "domain": "profile"}, + ) + + version_id = store.add_version( + "profile-1", + "Alice works somewhere else.", + [0.0, 0.0, 1.0], + id="profile-3", + metadata={"user_id": "alice", "domain": "profile"}, + ) + + assert version_id is None + assert store.get(["profile-3"]) == [] + visible = store.search_by_metadata({"user_id": "alice"}, top_k=10) + assert [result.id for result in visible] == ["profile-2"] + + +def test_sqlite_forget_soft_deletes_memory_from_retrieval(tmp_path): + store = _store(tmp_path) + store.add( + ["Alice's temporary preference."], + [[1.0, 0.0, 0.0]], + ids=["temp-1"], + metadata=[{"user_id": "alice"}], + ) + + assert store.forget(["temp-1"], reason="user requested deletion") is True + + assert store.search_by_metadata({"user_id": "alice"}, top_k=10) == [] + assert store.search([1.0, 0.0, 0.0], top_k=10) == [] + + stored = store.get(["temp-1"])[0] + assert stored["metadata"][IS_CURRENT_KEY] is False + assert stored["metadata"][FORGOTTEN_AT_KEY] + assert stored["metadata"][FORGET_REASON_KEY] == "user requested deletion"