-
Notifications
You must be signed in to change notification settings - Fork 482
feat(python-driver): add public API for connection pooling and model dict conversion #2374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2a8a7c2
7d64707
1c1acd4
53ae0fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,6 +162,71 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals | |
| if graphName != None: | ||
| checkGraphCreated(conn, graphName) | ||
|
|
||
|
|
||
| def configure_connection( | ||
| conn: psycopg.connection, | ||
| graph_name: str | None = None, | ||
| load: bool = False, | ||
| load_from_plugins: bool = False, | ||
| ) -> None: | ||
| """Register AGE agtype adapters on an existing connection. | ||
|
|
||
| This enables use of AGE with externally-managed connections, such as | ||
| those from psycopg_pool.ConnectionPool. By default the function does | ||
| **not** execute ``LOAD 'age'``, making it safe for managed PostgreSQL | ||
| services (Azure, AWS RDS) where the extension is pre-loaded via | ||
| ``shared_preload_libraries``. | ||
|
|
||
| Performs: | ||
| - ``SET search_path`` to include ``ag_catalog`` | ||
| - Fetches agtype OIDs and registers ``AgeLoader`` | ||
| - Optionally loads the AGE extension (``load=True``) | ||
| - Optionally checks/creates the graph | ||
|
|
||
| Args: | ||
| conn: An existing psycopg connection. | ||
| graph_name: Optional graph name to check/create. | ||
| load: If True, execute ``LOAD 'age'`` (or the plugins path). | ||
| Default False — suitable for environments where AGE is | ||
| already loaded. | ||
| load_from_plugins: If True (and ``load=True``), use | ||
| ``LOAD '$libdir/plugins/age'`` instead of ``LOAD 'age'``. | ||
|
|
||
| Raises: | ||
| ValueError: If ``load_from_plugins=True`` but ``load=False``. | ||
| AgeNotSet: If the agtype type is not found in the database. | ||
| """ | ||
| if load_from_plugins and not load: | ||
| raise ValueError( | ||
| "load_from_plugins=True requires load=True. " | ||
| "Set load=True to enable extension loading." | ||
| ) | ||
|
|
||
| with conn.cursor() as cursor: | ||
| if load: | ||
| if load_from_plugins: | ||
| cursor.execute("LOAD '$libdir/plugins/age';") | ||
| else: | ||
| cursor.execute("LOAD 'age';") | ||
|
|
||
| cursor.execute("SET search_path = ag_catalog, '$user', public;") | ||
|
|
||
| ag_info = TypeInfo.fetch(conn, 'agtype') | ||
|
|
||
| if not ag_info: | ||
| raise AgeNotSet( | ||
| "AGE agtype type not found. Ensure the AGE extension is " | ||
| "installed and loaded in the current database. " | ||
| "Run CREATE EXTENSION age; first." | ||
| ) | ||
|
|
||
| conn.adapters.register_loader(ag_info.oid, AgeLoader) | ||
| conn.adapters.register_loader(ag_info.array_oid, AgeLoader) | ||
|
|
||
| if graph_name is not None: | ||
| checkGraphCreated(conn, graph_name) | ||
|
|
||
|
Comment on lines
+166
to
+228
|
||
|
|
||
| # Create the graph, if it does not exist | ||
| def checkGraphCreated(conn:psycopg.connection, graphName:str): | ||
| validate_graph_name(graphName) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,8 +14,9 @@ | |
| # under the License. | ||
| import json | ||
|
|
||
| from age.models import Vertex | ||
| from age.models import Vertex, Edge, Path | ||
| import unittest | ||
| import unittest.mock | ||
| import decimal | ||
| import age | ||
| import argparse | ||
|
|
@@ -28,6 +29,160 @@ | |
| TEST_GRAPH_NAME = "test_graph" | ||
|
|
||
|
|
||
| class TestModelToDict(unittest.TestCase): | ||
| """Unit tests for Vertex/Edge/Path to_dict() — no DB required.""" | ||
|
|
||
| def test_vertex_to_dict(self): | ||
| v = Vertex(id=123, label="Person", properties={"name": "Alice", "age": 30}) | ||
| d = v.to_dict() | ||
| self.assertEqual(d["id"], 123) | ||
| self.assertEqual(d["label"], "Person") | ||
| self.assertEqual(d["properties"], {"name": "Alice", "age": 30}) | ||
| # Verify it's a plain dict (JSON-serializable) | ||
| json_str = json.dumps(d) | ||
| self.assertIn("Alice", json_str) | ||
|
|
||
| def test_vertex_to_dict_empty_properties(self): | ||
| v = Vertex(id=1, label="Empty", properties=None) | ||
| d = v.to_dict() | ||
| self.assertEqual(d["properties"], {}) | ||
|
|
||
| def test_edge_to_dict(self): | ||
| e = Edge(id=456, label="KNOWS", properties={"since": 2020}) | ||
| e.start_id = 123 | ||
| e.end_id = 789 | ||
| d = e.to_dict() | ||
| self.assertEqual(d["id"], 456) | ||
| self.assertEqual(d["label"], "KNOWS") | ||
| self.assertEqual(d["start_id"], 123) | ||
| self.assertEqual(d["end_id"], 789) | ||
| self.assertEqual(d["properties"], {"since": 2020}) | ||
| json_str = json.dumps(d) | ||
| self.assertIn("KNOWS", json_str) | ||
|
|
||
| def test_path_to_dict(self): | ||
| v1 = Vertex(id=1, label="A", properties={"name": "start"}) | ||
| e = Edge(id=10, label="r", properties={"w": 1}) | ||
| e.start_id = 1 | ||
| e.end_id = 2 | ||
| v2 = Vertex(id=2, label="B", properties={"name": "end"}) | ||
| p = Path([v1, e, v2]) | ||
| d = p.to_dict() | ||
| self.assertEqual(len(d), 3) | ||
| self.assertEqual(d[0]["label"], "A") | ||
| self.assertEqual(d[1]["label"], "r") | ||
| self.assertEqual(d[1]["start_id"], 1) | ||
| self.assertEqual(d[2]["label"], "B") | ||
| # Verify the whole path is JSON-serializable | ||
| json_str = json.dumps(d) | ||
| self.assertIn("start", json_str) | ||
|
|
||
| def test_vertex_to_dict_is_plain_dict(self): | ||
| """to_dict() returns standard dict, not a model object.""" | ||
| v = Vertex(id=1, label="X", properties={"k": "v"}) | ||
| d = v.to_dict() | ||
| self.assertIsInstance(d, dict) | ||
| self.assertIsInstance(d["properties"], dict) | ||
|
|
||
|
|
||
| class TestPublicImports(unittest.TestCase): | ||
| """Verify that public API symbols are importable without type: ignore.""" | ||
|
|
||
| def test_import_configure_connection(self): | ||
| from age import configure_connection | ||
| self.assertTrue(callable(configure_connection)) | ||
|
|
||
| def test_import_age_loader(self): | ||
| from age import AgeLoader | ||
| self.assertIsNotNone(AgeLoader) | ||
|
|
||
| def test_import_client_cursor(self): | ||
| from age import ClientCursor | ||
| self.assertIsNotNone(ClientCursor) | ||
|
|
||
|
|
||
| class TestConfigureConnection(unittest.TestCase): | ||
| """Unit tests for configure_connection() — no DB required.""" | ||
|
|
||
|
Comment on lines
+32
to
+106
|
||
| def _make_mock_conn(self): | ||
| mock_conn = unittest.mock.MagicMock() | ||
| mock_cursor = unittest.mock.MagicMock() | ||
| mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor) | ||
| mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False) | ||
| mock_conn.adapters = unittest.mock.MagicMock() | ||
| mock_type_info = unittest.mock.MagicMock() | ||
| mock_type_info.oid = 1 | ||
| mock_type_info.array_oid = 2 | ||
| return mock_conn, mock_cursor, mock_type_info | ||
|
|
||
| def test_default_does_not_load(self): | ||
| """By default, configure_connection should NOT execute LOAD.""" | ||
| mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ | ||
| unittest.mock.patch("age.age.checkGraphCreated"): | ||
| age.age.configure_connection(mock_conn) | ||
| mock_cursor.execute.assert_called_once_with( | ||
| "SET search_path = ag_catalog, '$user', public;" | ||
| ) | ||
|
|
||
| def test_load_true_executes_load(self): | ||
| """When load=True, LOAD 'age' must be executed.""" | ||
| mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ | ||
| unittest.mock.patch("age.age.checkGraphCreated"): | ||
| age.age.configure_connection(mock_conn, load=True) | ||
| mock_cursor.execute.assert_any_call("LOAD 'age';") | ||
|
|
||
| def test_load_from_plugins(self): | ||
| """When load=True and load_from_plugins=True, use plugins path.""" | ||
| mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ | ||
| unittest.mock.patch("age.age.checkGraphCreated"): | ||
| age.age.configure_connection(mock_conn, load=True, load_from_plugins=True) | ||
| mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';") | ||
|
|
||
| def test_load_from_plugins_without_load_raises(self): | ||
| """load_from_plugins=True without load=True must raise ValueError.""" | ||
| mock_conn, _, _ = self._make_mock_conn() | ||
| with self.assertRaises(ValueError): | ||
| age.age.configure_connection(mock_conn, load_from_plugins=True) | ||
|
|
||
| def test_always_sets_search_path(self): | ||
| """search_path must always be set regardless of load parameter.""" | ||
| mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ | ||
| unittest.mock.patch("age.age.checkGraphCreated"): | ||
| age.age.configure_connection(mock_conn) | ||
| mock_cursor.execute.assert_any_call( | ||
| "SET search_path = ag_catalog, '$user', public;" | ||
| ) | ||
|
|
||
| def test_registers_agtype_adapters(self): | ||
| """AgeLoader must be registered for agtype OIDs.""" | ||
| mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ | ||
| unittest.mock.patch("age.age.checkGraphCreated"): | ||
| age.age.configure_connection(mock_conn) | ||
| mock_conn.adapters.register_loader.assert_any_call(1, age.age.AgeLoader) | ||
| mock_conn.adapters.register_loader.assert_any_call(2, age.age.AgeLoader) | ||
|
|
||
| def test_graph_name_triggers_check(self): | ||
| """When graph_name is provided, checkGraphCreated must be called.""" | ||
| mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ | ||
| unittest.mock.patch("age.age.checkGraphCreated") as mock_check: | ||
| age.age.configure_connection(mock_conn, graph_name="my_graph") | ||
| mock_check.assert_called_once_with(mock_conn, "my_graph") | ||
|
|
||
| def test_age_not_set_when_type_info_is_none(self): | ||
| """AgeNotSet must be raised when TypeInfo.fetch returns None.""" | ||
| from age.exceptions import AgeNotSet | ||
| mock_conn, _, _ = self._make_mock_conn() | ||
| with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=None): | ||
| with self.assertRaises(AgeNotSet): | ||
| age.age.configure_connection(mock_conn) | ||
|
|
||
|
|
||
| class TestAgeBasic(unittest.TestCase): | ||
| ag = None | ||
| args: argparse.Namespace = argparse.Namespace( | ||
|
|
@@ -485,6 +640,12 @@ def testSerialization(self): | |
|
|
||
| args = parser.parse_args() | ||
| suite = unittest.TestSuite() | ||
| # Unit tests (no DB required) | ||
| loader = unittest.TestLoader() | ||
| suite.addTests(loader.loadTestsFromTestCase(TestModelToDict)) | ||
| suite.addTests(loader.loadTestsFromTestCase(TestPublicImports)) | ||
| suite.addTests(loader.loadTestsFromTestCase(TestConfigureConnection)) | ||
| # Integration tests (require DB) | ||
| suite.addTest(TestAgeBasic("testExec")) | ||
| suite.addTest(TestAgeBasic("testQuery")) | ||
| suite.addTest(TestAgeBasic("testChangeData")) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR description and issue summary mention a
skip_loadparameter, but the implemented public API usesload: bool = Falseinstead. Please align the PR description/docs with the actual signature (or rename the parameter) to avoid confusion for users upgrading based on the PR notes.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is intentional —
configure_connection(load=False)uses positive opt-in semantics (load=Trueto enable loading), unlikesetUpAge(skip_load=True)which uses opt-out. Theloadparameter was chosen because the default use case forconfigure_connectionis connection pooling where AGE is already loaded. The docstring documents this clearly.