diff --git a/drivers/python/age/__init__.py b/drivers/python/age/__init__.py index 685f0fe74..b84bd162e 100644 --- a/drivers/python/age/__init__.py +++ b/drivers/python/age/__init__.py @@ -16,6 +16,7 @@ import psycopg.conninfo as conninfo from . import age from .age import * +from .age import AgeLoader, ClientCursor, configure_connection from .models import * from .builder import ResultHandler, DummyResultHandler, parseAgeValue, newResultHandler from . import VERSION diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index fad1f27b1..e9e8d003d 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -162,6 +162,71 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals if graphName != None: checkGraphCreated(conn, graphName) + +def configure_connection( + conn: psycopg.connection, + graph_name: str | None = None, + load: bool = False, + load_from_plugins: bool = False, +) -> None: + """Register AGE agtype adapters on an existing connection. + + This enables use of AGE with externally-managed connections, such as + those from psycopg_pool.ConnectionPool. By default the function does + **not** execute ``LOAD 'age'``, making it safe for managed PostgreSQL + services (Azure, AWS RDS) where the extension is pre-loaded via + ``shared_preload_libraries``. + + Performs: + - ``SET search_path`` to include ``ag_catalog`` + - Fetches agtype OIDs and registers ``AgeLoader`` + - Optionally loads the AGE extension (``load=True``) + - Optionally checks/creates the graph + + Args: + conn: An existing psycopg connection. + graph_name: Optional graph name to check/create. + load: If True, execute ``LOAD 'age'`` (or the plugins path). + Default False — suitable for environments where AGE is + already loaded. + load_from_plugins: If True (and ``load=True``), use + ``LOAD '$libdir/plugins/age'`` instead of ``LOAD 'age'``. + + Raises: + ValueError: If ``load_from_plugins=True`` but ``load=False``. + AgeNotSet: If the agtype type is not found in the database. + """ + if load_from_plugins and not load: + raise ValueError( + "load_from_plugins=True requires load=True. " + "Set load=True to enable extension loading." + ) + + with conn.cursor() as cursor: + if load: + if load_from_plugins: + cursor.execute("LOAD '$libdir/plugins/age';") + else: + cursor.execute("LOAD 'age';") + + cursor.execute("SET search_path = ag_catalog, '$user', public;") + + ag_info = TypeInfo.fetch(conn, 'agtype') + + if not ag_info: + raise AgeNotSet( + "AGE agtype type not found. Ensure the AGE extension is " + "installed and loaded in the current database. " + "Run CREATE EXTENSION age; first." + ) + + conn.adapters.register_loader(ag_info.oid, AgeLoader) + conn.adapters.register_loader(ag_info.array_oid, AgeLoader) + + if graph_name is not None: + checkGraphCreated(conn, graph_name) + + # Create the graph, if it does not exist def checkGraphCreated(conn:psycopg.connection, graphName:str): validate_graph_name(graphName) diff --git a/drivers/python/age/models.py b/drivers/python/age/models.py index 6d9095485..62215c160 100644 --- a/drivers/python/age/models.py +++ b/drivers/python/age/models.py @@ -118,6 +118,20 @@ def toJson(self) -> str: return buf.getvalue() + def to_dict(self) -> list: + # AGObj elements are recursively converted; JSON-native types + # (dict, list, str, int, float, bool, None) pass through unchanged. + # Non-serializable objects fall back to str() as a safety net. + result = [] + for e in (self.entities or []): + if isinstance(e, AGObj): + result.append(e.to_dict()) + elif isinstance(e, (dict, list, str, int, float, bool, type(None))): + result.append(e) + else: + result.append(str(e)) + return result + @@ -146,6 +160,18 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.toString() + def to_dict(self) -> dict: + """Return a plain dict suitable for JSON serialization. + + Properties are shallow-copied; nested mutable values will share + references with the original Vertex. + """ + return { + "id": self.id, + "label": self.label, + "properties": dict(self.properties) if self.properties else {}, + } + def toString(self) -> str: return nodeToString(self) @@ -186,6 +212,20 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.toString() + def to_dict(self) -> dict: + """Return a plain dict suitable for JSON serialization. + + Properties are shallow-copied; nested mutable values will share + references with the original Edge. + """ + return { + "id": self.id, + "label": self.label, + "start_id": self.start_id, + "end_id": self.end_id, + "properties": dict(self.properties) if self.properties else {}, + } + def extraStrFormat(node, buf): if node.start_id != None: buf.write(", start_id:") diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index f904fb9e3..169459889 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -14,8 +14,9 @@ # under the License. import json -from age.models import Vertex +from age.models import Vertex, Edge, Path import unittest +import unittest.mock import decimal import age import argparse @@ -28,6 +29,160 @@ TEST_GRAPH_NAME = "test_graph" +class TestModelToDict(unittest.TestCase): + """Unit tests for Vertex/Edge/Path to_dict() — no DB required.""" + + def test_vertex_to_dict(self): + v = Vertex(id=123, label="Person", properties={"name": "Alice", "age": 30}) + d = v.to_dict() + self.assertEqual(d["id"], 123) + self.assertEqual(d["label"], "Person") + self.assertEqual(d["properties"], {"name": "Alice", "age": 30}) + # Verify it's a plain dict (JSON-serializable) + json_str = json.dumps(d) + self.assertIn("Alice", json_str) + + def test_vertex_to_dict_empty_properties(self): + v = Vertex(id=1, label="Empty", properties=None) + d = v.to_dict() + self.assertEqual(d["properties"], {}) + + def test_edge_to_dict(self): + e = Edge(id=456, label="KNOWS", properties={"since": 2020}) + e.start_id = 123 + e.end_id = 789 + d = e.to_dict() + self.assertEqual(d["id"], 456) + self.assertEqual(d["label"], "KNOWS") + self.assertEqual(d["start_id"], 123) + self.assertEqual(d["end_id"], 789) + self.assertEqual(d["properties"], {"since": 2020}) + json_str = json.dumps(d) + self.assertIn("KNOWS", json_str) + + def test_path_to_dict(self): + v1 = Vertex(id=1, label="A", properties={"name": "start"}) + e = Edge(id=10, label="r", properties={"w": 1}) + e.start_id = 1 + e.end_id = 2 + v2 = Vertex(id=2, label="B", properties={"name": "end"}) + p = Path([v1, e, v2]) + d = p.to_dict() + self.assertEqual(len(d), 3) + self.assertEqual(d[0]["label"], "A") + self.assertEqual(d[1]["label"], "r") + self.assertEqual(d[1]["start_id"], 1) + self.assertEqual(d[2]["label"], "B") + # Verify the whole path is JSON-serializable + json_str = json.dumps(d) + self.assertIn("start", json_str) + + def test_vertex_to_dict_is_plain_dict(self): + """to_dict() returns standard dict, not a model object.""" + v = Vertex(id=1, label="X", properties={"k": "v"}) + d = v.to_dict() + self.assertIsInstance(d, dict) + self.assertIsInstance(d["properties"], dict) + + +class TestPublicImports(unittest.TestCase): + """Verify that public API symbols are importable without type: ignore.""" + + def test_import_configure_connection(self): + from age import configure_connection + self.assertTrue(callable(configure_connection)) + + def test_import_age_loader(self): + from age import AgeLoader + self.assertIsNotNone(AgeLoader) + + def test_import_client_cursor(self): + from age import ClientCursor + self.assertIsNotNone(ClientCursor) + + +class TestConfigureConnection(unittest.TestCase): + """Unit tests for configure_connection() — no DB required.""" + + def _make_mock_conn(self): + mock_conn = unittest.mock.MagicMock() + mock_cursor = unittest.mock.MagicMock() + mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False) + mock_conn.adapters = unittest.mock.MagicMock() + mock_type_info = unittest.mock.MagicMock() + mock_type_info.oid = 1 + mock_type_info.array_oid = 2 + return mock_conn, mock_cursor, mock_type_info + + def test_default_does_not_load(self): + """By default, configure_connection should NOT execute LOAD.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn) + mock_cursor.execute.assert_called_once_with( + "SET search_path = ag_catalog, '$user', public;" + ) + + def test_load_true_executes_load(self): + """When load=True, LOAD 'age' must be executed.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn, load=True) + mock_cursor.execute.assert_any_call("LOAD 'age';") + + def test_load_from_plugins(self): + """When load=True and load_from_plugins=True, use plugins path.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn, load=True, load_from_plugins=True) + mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';") + + def test_load_from_plugins_without_load_raises(self): + """load_from_plugins=True without load=True must raise ValueError.""" + mock_conn, _, _ = self._make_mock_conn() + with self.assertRaises(ValueError): + age.age.configure_connection(mock_conn, load_from_plugins=True) + + def test_always_sets_search_path(self): + """search_path must always be set regardless of load parameter.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn) + mock_cursor.execute.assert_any_call( + "SET search_path = ag_catalog, '$user', public;" + ) + + def test_registers_agtype_adapters(self): + """AgeLoader must be registered for agtype OIDs.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn) + mock_conn.adapters.register_loader.assert_any_call(1, age.age.AgeLoader) + mock_conn.adapters.register_loader.assert_any_call(2, age.age.AgeLoader) + + def test_graph_name_triggers_check(self): + """When graph_name is provided, checkGraphCreated must be called.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated") as mock_check: + age.age.configure_connection(mock_conn, graph_name="my_graph") + mock_check.assert_called_once_with(mock_conn, "my_graph") + + def test_age_not_set_when_type_info_is_none(self): + """AgeNotSet must be raised when TypeInfo.fetch returns None.""" + from age.exceptions import AgeNotSet + mock_conn, _, _ = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=None): + with self.assertRaises(AgeNotSet): + age.age.configure_connection(mock_conn) + + class TestAgeBasic(unittest.TestCase): ag = None args: argparse.Namespace = argparse.Namespace( @@ -485,6 +640,12 @@ def testSerialization(self): args = parser.parse_args() suite = unittest.TestSuite() + # Unit tests (no DB required) + loader = unittest.TestLoader() + suite.addTests(loader.loadTestsFromTestCase(TestModelToDict)) + suite.addTests(loader.loadTestsFromTestCase(TestPublicImports)) + suite.addTests(loader.loadTestsFromTestCase(TestConfigureConnection)) + # Integration tests (require DB) suite.addTest(TestAgeBasic("testExec")) suite.addTest(TestAgeBasic("testQuery")) suite.addTest(TestAgeBasic("testChangeData"))