Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions drivers/python/age/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import psycopg.conninfo as conninfo
from . import age
from .age import *
from .age import AgeLoader, ClientCursor, configure_connection
from .models import *
from .builder import ResultHandler, DummyResultHandler, parseAgeValue, newResultHandler
from . import VERSION
Expand Down
65 changes: 65 additions & 0 deletions drivers/python/age/age.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,71 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals
if graphName != None:
checkGraphCreated(conn, graphName)


def configure_connection(
conn: psycopg.connection,
graph_name: str | None = None,
load: bool = False,
load_from_plugins: bool = False,
) -> None:
Comment on lines +166 to +171
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description and issue summary mention a skip_load parameter, but the implemented public API uses load: bool = False instead. Please align the PR description/docs with the actual signature (or rename the parameter) to avoid confusion for users upgrading based on the PR notes.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional — configure_connection(load=False) uses positive opt-in semantics (load=True to enable loading), unlike setUpAge(skip_load=True) which uses opt-out. The load parameter was chosen because the default use case for configure_connection is connection pooling where AGE is already loaded. The docstring documents this clearly.

"""Register AGE agtype adapters on an existing connection.

This enables use of AGE with externally-managed connections, such as
those from psycopg_pool.ConnectionPool. By default the function does
**not** execute ``LOAD 'age'``, making it safe for managed PostgreSQL
services (Azure, AWS RDS) where the extension is pre-loaded via
``shared_preload_libraries``.

Performs:
- ``SET search_path`` to include ``ag_catalog``
- Fetches agtype OIDs and registers ``AgeLoader``
- Optionally loads the AGE extension (``load=True``)
- Optionally checks/creates the graph

Args:
conn: An existing psycopg connection.
graph_name: Optional graph name to check/create.
load: If True, execute ``LOAD 'age'`` (or the plugins path).
Default False — suitable for environments where AGE is
already loaded.
load_from_plugins: If True (and ``load=True``), use
``LOAD '$libdir/plugins/age'`` instead of ``LOAD 'age'``.

Raises:
ValueError: If ``load_from_plugins=True`` but ``load=False``.
AgeNotSet: If the agtype type is not found in the database.
"""
if load_from_plugins and not load:
raise ValueError(
"load_from_plugins=True requires load=True. "
"Set load=True to enable extension loading."
)

with conn.cursor() as cursor:
if load:
if load_from_plugins:
cursor.execute("LOAD '$libdir/plugins/age';")
else:
cursor.execute("LOAD 'age';")

cursor.execute("SET search_path = ag_catalog, '$user', public;")

ag_info = TypeInfo.fetch(conn, 'agtype')

if not ag_info:
raise AgeNotSet(
"AGE agtype type not found. Ensure the AGE extension is "
"installed and loaded in the current database. "
"Run CREATE EXTENSION age; first."
)

conn.adapters.register_loader(ag_info.oid, AgeLoader)
conn.adapters.register_loader(ag_info.array_oid, AgeLoader)

if graph_name is not None:
checkGraphCreated(conn, graph_name)

Comment on lines +166 to +228
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

configure_connection() largely duplicates setUpAge() (LOAD/search_path/agtype TypeInfo fetch/loader registration/graph check). This duplication risks the two code paths drifting over time (e.g., changes to error handling or adapter registration). Consider refactoring so setUpAge() calls configure_connection(load=True, ...) (or vice versa) and share a single implementation.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid observation. Refactoring setUpAge() to delegate to configure_connection() (or vice versa) would reduce duplication, but it changes the behavior semantics of setUpAge which is the existing public API. Better addressed in a follow-up PR to avoid scope creep.


# Create the graph, if it does not exist
def checkGraphCreated(conn:psycopg.connection, graphName:str):
validate_graph_name(graphName)
Expand Down
40 changes: 40 additions & 0 deletions drivers/python/age/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ def toJson(self) -> str:

return buf.getvalue()

def to_dict(self) -> list:
# AGObj elements are recursively converted; JSON-native types
# (dict, list, str, int, float, bool, None) pass through unchanged.
# Non-serializable objects fall back to str() as a safety net.
result = []
for e in (self.entities or []):
if isinstance(e, AGObj):
result.append(e.to_dict())
elif isinstance(e, (dict, list, str, int, float, bool, type(None))):
result.append(e)
else:
result.append(str(e))
return result




Expand Down Expand Up @@ -146,6 +160,18 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.toString()

def to_dict(self) -> dict:
"""Return a plain dict suitable for JSON serialization.

Properties are shallow-copied; nested mutable values will share
references with the original Vertex.
"""
return {
"id": self.id,
"label": self.label,
"properties": dict(self.properties) if self.properties else {},
}

def toString(self) -> str:
return nodeToString(self)

Expand Down Expand Up @@ -186,6 +212,20 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.toString()

def to_dict(self) -> dict:
"""Return a plain dict suitable for JSON serialization.

Properties are shallow-copied; nested mutable values will share
references with the original Edge.
"""
return {
"id": self.id,
"label": self.label,
"start_id": self.start_id,
"end_id": self.end_id,
"properties": dict(self.properties) if self.properties else {},
}

def extraStrFormat(node, buf):
if node.start_id != None:
buf.write(", start_id:")
Expand Down
163 changes: 162 additions & 1 deletion drivers/python/test_age_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# under the License.
import json

from age.models import Vertex
from age.models import Vertex, Edge, Path
import unittest
import unittest.mock
import decimal
import age
import argparse
Expand All @@ -28,6 +29,160 @@
TEST_GRAPH_NAME = "test_graph"


class TestModelToDict(unittest.TestCase):
"""Unit tests for Vertex/Edge/Path to_dict() — no DB required."""

def test_vertex_to_dict(self):
v = Vertex(id=123, label="Person", properties={"name": "Alice", "age": 30})
d = v.to_dict()
self.assertEqual(d["id"], 123)
self.assertEqual(d["label"], "Person")
self.assertEqual(d["properties"], {"name": "Alice", "age": 30})
# Verify it's a plain dict (JSON-serializable)
json_str = json.dumps(d)
self.assertIn("Alice", json_str)

def test_vertex_to_dict_empty_properties(self):
v = Vertex(id=1, label="Empty", properties=None)
d = v.to_dict()
self.assertEqual(d["properties"], {})

def test_edge_to_dict(self):
e = Edge(id=456, label="KNOWS", properties={"since": 2020})
e.start_id = 123
e.end_id = 789
d = e.to_dict()
self.assertEqual(d["id"], 456)
self.assertEqual(d["label"], "KNOWS")
self.assertEqual(d["start_id"], 123)
self.assertEqual(d["end_id"], 789)
self.assertEqual(d["properties"], {"since": 2020})
json_str = json.dumps(d)
self.assertIn("KNOWS", json_str)

def test_path_to_dict(self):
v1 = Vertex(id=1, label="A", properties={"name": "start"})
e = Edge(id=10, label="r", properties={"w": 1})
e.start_id = 1
e.end_id = 2
v2 = Vertex(id=2, label="B", properties={"name": "end"})
p = Path([v1, e, v2])
d = p.to_dict()
self.assertEqual(len(d), 3)
self.assertEqual(d[0]["label"], "A")
self.assertEqual(d[1]["label"], "r")
self.assertEqual(d[1]["start_id"], 1)
self.assertEqual(d[2]["label"], "B")
# Verify the whole path is JSON-serializable
json_str = json.dumps(d)
self.assertIn("start", json_str)

def test_vertex_to_dict_is_plain_dict(self):
"""to_dict() returns standard dict, not a model object."""
v = Vertex(id=1, label="X", properties={"k": "v"})
d = v.to_dict()
self.assertIsInstance(d, dict)
self.assertIsInstance(d["properties"], dict)


class TestPublicImports(unittest.TestCase):
"""Verify that public API symbols are importable without type: ignore."""

def test_import_configure_connection(self):
from age import configure_connection
self.assertTrue(callable(configure_connection))

def test_import_age_loader(self):
from age import AgeLoader
self.assertIsNotNone(AgeLoader)

def test_import_client_cursor(self):
from age import ClientCursor
self.assertIsNotNone(ClientCursor)


class TestConfigureConnection(unittest.TestCase):
"""Unit tests for configure_connection() — no DB required."""

Comment on lines +32 to +106
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new unit test classes (TestModelToDict, TestPublicImports, TestConfigureConnection) won’t run in CI when executing python test_age_py.py ... because the __main__ block constructs a TestSuite that only includes TestAgeBasic tests. Consider adding these new tests to the explicit suite, or switching this file to unittest.main() (and moving the DB/integration tests behind an opt-in flag) so the new tests are actually executed.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed — added TestModelToDict, TestPublicImports, and TestConfigureConnection to the __main__ suite via unittest.TestLoader().loadTestsFromTestCase(). They now run both via python -m unittest discovery and direct script execution.

def _make_mock_conn(self):
mock_conn = unittest.mock.MagicMock()
mock_cursor = unittest.mock.MagicMock()
mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False)
mock_conn.adapters = unittest.mock.MagicMock()
mock_type_info = unittest.mock.MagicMock()
mock_type_info.oid = 1
mock_type_info.array_oid = 2
return mock_conn, mock_cursor, mock_type_info

def test_default_does_not_load(self):
"""By default, configure_connection should NOT execute LOAD."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.configure_connection(mock_conn)
mock_cursor.execute.assert_called_once_with(
"SET search_path = ag_catalog, '$user', public;"
)

def test_load_true_executes_load(self):
"""When load=True, LOAD 'age' must be executed."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.configure_connection(mock_conn, load=True)
mock_cursor.execute.assert_any_call("LOAD 'age';")

def test_load_from_plugins(self):
"""When load=True and load_from_plugins=True, use plugins path."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.configure_connection(mock_conn, load=True, load_from_plugins=True)
mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';")

def test_load_from_plugins_without_load_raises(self):
"""load_from_plugins=True without load=True must raise ValueError."""
mock_conn, _, _ = self._make_mock_conn()
with self.assertRaises(ValueError):
age.age.configure_connection(mock_conn, load_from_plugins=True)

def test_always_sets_search_path(self):
"""search_path must always be set regardless of load parameter."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.configure_connection(mock_conn)
mock_cursor.execute.assert_any_call(
"SET search_path = ag_catalog, '$user', public;"
)

def test_registers_agtype_adapters(self):
"""AgeLoader must be registered for agtype OIDs."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.configure_connection(mock_conn)
mock_conn.adapters.register_loader.assert_any_call(1, age.age.AgeLoader)
mock_conn.adapters.register_loader.assert_any_call(2, age.age.AgeLoader)

def test_graph_name_triggers_check(self):
"""When graph_name is provided, checkGraphCreated must be called."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated") as mock_check:
age.age.configure_connection(mock_conn, graph_name="my_graph")
mock_check.assert_called_once_with(mock_conn, "my_graph")

def test_age_not_set_when_type_info_is_none(self):
"""AgeNotSet must be raised when TypeInfo.fetch returns None."""
from age.exceptions import AgeNotSet
mock_conn, _, _ = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=None):
with self.assertRaises(AgeNotSet):
age.age.configure_connection(mock_conn)


class TestAgeBasic(unittest.TestCase):
ag = None
args: argparse.Namespace = argparse.Namespace(
Expand Down Expand Up @@ -485,6 +640,12 @@ def testSerialization(self):

args = parser.parse_args()
suite = unittest.TestSuite()
# Unit tests (no DB required)
loader = unittest.TestLoader()
suite.addTests(loader.loadTestsFromTestCase(TestModelToDict))
suite.addTests(loader.loadTestsFromTestCase(TestPublicImports))
suite.addTests(loader.loadTestsFromTestCase(TestConfigureConnection))
# Integration tests (require DB)
suite.addTest(TestAgeBasic("testExec"))
suite.addTest(TestAgeBasic("testQuery"))
suite.addTest(TestAgeBasic("testChangeData"))
Expand Down