From c2052a9978e8aa1b5c82e75210703b9b1ab11b15 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 2 Mar 2026 14:53:07 -0800 Subject: [PATCH 1/5] PYTHON-5114 Test suite reduce killAllSessions calls --- test/asynchronous/unified_format.py | 14 +++++++++----- test/asynchronous/utils_spec_runner.py | 9 ++++++--- test/unified_format.py | 14 +++++++++----- test/utils_spec_runner.py | 9 ++++++--- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 6ce8f852cf..a4d64f8a28 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1464,11 +1464,6 @@ async def verify_outcome(self, spec): self.assertListEqual(sorted_expected_documents, actual_documents) async def run_scenario(self, spec, uri=None): - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - await self.kill_all_sessions() - # Handle flaky tests. flaky_tests = [ ("PYTHON-5170", ".*test_discovery_and_monitoring.*"), @@ -1504,6 +1499,15 @@ async def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") + # Kill all sessions after each test with transactions prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. 
+ for op in spec["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addAsyncCleanup(self.kill_all_sessions) + break + # process createEntities self._uri = uri self.entity_map = EntityMapUtil(self) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 63e7e9e150..c27b4c1c23 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -621,11 +621,14 @@ async def setup_scenario(self, scenario_def): async def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions before and after each test to prevent an open + # Kill all sessions after each test with transactions prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. - await self.kill_all_sessions() - self.addAsyncCleanup(self.kill_all_sessions) + for op in test["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addAsyncCleanup(self.kill_all_sessions) + break await self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) diff --git a/test/unified_format.py b/test/unified_format.py index 9aee287256..31ac178cc7 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1451,11 +1451,6 @@ def verify_outcome(self, spec): self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): - # Kill all sessions before and after each test to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - self.kill_all_sessions() - # Handle flaky tests. 
flaky_tests = [ ("PYTHON-5170", ".*test_discovery_and_monitoring.*"), @@ -1491,6 +1486,15 @@ def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") + # Kill all sessions after each test with transactions prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + for op in spec["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addCleanup(self.kill_all_sessions) + break + # process createEntities self._uri = uri self.entity_map = EntityMapUtil(self) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 9bf155e8f3..72788c4a1a 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -621,11 +621,14 @@ def setup_scenario(self, scenario_def): def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions before and after each test to prevent an open + # Kill all sessions after each test with transactions prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. 
- self.kill_all_sessions() - self.addCleanup(self.kill_all_sessions) + for op in test["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addCleanup(self.kill_all_sessions) + break self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) From 8f238421e05f1bcd943e690f340c4c7c85cb2dd2 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 5 Mar 2026 10:06:40 -0800 Subject: [PATCH 2/5] PYTHON-5114 Fix comment --- test/asynchronous/unified_format.py | 2 +- test/asynchronous/utils_spec_runner.py | 2 +- test/unified_format.py | 2 +- test/utils_spec_runner.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index a4d64f8a28..1fb93e7b86 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1499,7 +1499,7 @@ async def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. 
for op in spec["operations"]: diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index c27b4c1c23..f099eee12c 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -621,7 +621,7 @@ async def setup_scenario(self, scenario_def): async def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. for op in test["operations"]: diff --git a/test/unified_format.py b/test/unified_format.py index 31ac178cc7..5516a7adf1 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1486,7 +1486,7 @@ def _run_scenario(self, spec, uri=None): if skip_reason is not None: raise unittest.SkipTest(f"{skip_reason}") - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. for op in spec["operations"]: diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 72788c4a1a..34e1c95ef2 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -621,7 +621,7 @@ def setup_scenario(self, scenario_def): def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions after each test with transactions prevent an open + # Kill all sessions after each test with transactions to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. 
for op in test["operations"]: From c0ac7321858b2fecc95c3cd84d86e4bd12928b18 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 5 Mar 2026 11:18:28 -0800 Subject: [PATCH 3/5] PYTHON-5114 Remove unused SpecRunner class --- test/asynchronous/utils_spec_runner.py | 634 +------------------------ test/utils_spec_runner.py | 632 +----------------------- 2 files changed, 8 insertions(+), 1258 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f099eee12c..344fd97536 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -16,43 +16,14 @@ from __future__ import annotations import asyncio -import functools import os import time -import unittest -from collections import abc -from inspect import iscoroutinefunction -from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.asynchronous import async_client_context from test.asynchronous.helpers import ConcurrentRunner -from test.utils_shared import ( - CMAPListener, - CompareType, - EventListener, - OvertCommandListener, - ScenarioDict, - ServerAndTopologyEventListener, - camel_to_snake, - camel_to_snake_args, - parse_spec_options, - prepare_spec_arguments, -) -from typing import List - -from bson import ObjectId, decode, encode, json_util -from bson.binary import Binary -from bson.int64 import Int64 -from bson.son import SON -from gridfs import GridFSBucket -from gridfs.asynchronous.grid_file import AsyncGridFSBucket -from pymongo.asynchronous import client_session -from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.cursor import AsyncCursor -from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from test.utils_shared import ScenarioDict + +from bson import json_util from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock -from pymongo.read_concern import ReadConcern -from 
pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, _WriteResult -from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -219,600 +190,3 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) - - -class AsyncSpecRunner(AsyncIntegrationTest): - mongos_clients: List - knobs: client_knobs - listener: EventListener - - async def asyncSetUp(self) -> None: - await super().asyncSetUp() - self.mongos_clients = [] - - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - self.knobs.enable() - self.targets = {} - self.listener = None # type: ignore - self.pool_listener = None - self.server_listener = None - self.maxDiff = None - - async def asyncTearDown(self) -> None: - self.knobs.disable() - - async def set_fail_point(self, command_args): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - await self.configure_fail_point(client, command_args) - - async def targeted_fail_point(self, session, fail_point): - """Run the targetedFailPoint test operation. - - Enable the fail point on the session's pinned mongos. - """ - clients = {c.address: c for c in self.mongos_clients} - client = clients[session._pinned_address] - await self.configure_fail_point(client, fail_point) - self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) - - def assert_session_pinned(self, session): - """Run the assertSessionPinned test operation. - - Assert that the given session is pinned. - """ - self.assertIsNotNone(session._transaction.pinned_address) - - def assert_session_unpinned(self, session): - """Run the assertSessionUnpinned test operation. - - Assert that the given session is not pinned. 
- """ - self.assertIsNone(session._pinned_address) - self.assertIsNone(session._transaction.pinned_address) - - async def assert_collection_exists(self, database, collection): - """Run the assertCollectionExists test operation.""" - db = self.client[database] - self.assertIn(collection, await db.list_collection_names()) - - async def assert_collection_not_exists(self, database, collection): - """Run the assertCollectionNotExists test operation.""" - db = self.client[database] - self.assertNotIn(collection, await db.list_collection_names()) - - async def assert_index_exists(self, database, collection, index): - """Run the assertIndexExists test operation.""" - coll = self.client[database][collection] - self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()]) - - async def assert_index_not_exists(self, database, collection, index): - """Run the assertIndexNotExists test operation.""" - coll = self.client[database][collection] - self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()]) - - async def wait(self, ms): - """Run the "wait" test operation.""" - await asyncio.sleep(ms / 1000.0) - - def assertErrorLabelsContain(self, exc, expected_labels): - labels = [l for l in expected_labels if exc.has_error_label(l)] - self.assertEqual(labels, expected_labels) - - def assertErrorLabelsOmit(self, exc, omit_labels): - for label in omit_labels: - self.assertFalse( - exc.has_error_label(label), msg=f"error labels should not contain {label}" - ) - - async def kill_all_sessions(self): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - try: - await client.admin.command("killAllSessions", []) - except (OperationFailure, AutoReconnect): - # "operation was interrupted" by killing the command's - # own session. - # On 8.0+ killAllSessions sometimes returns a network error. - pass - - def check_command_result(self, expected_result, result): - # Only compare the keys in the expected result. 
- filtered_result = {} - for key in expected_result: - try: - filtered_result[key] = result[key] - except KeyError: - pass - self.assertEqual(filtered_result, expected_result) - - # TODO: factor the following function with test_crud.py. - def check_result(self, expected_result, result): - if isinstance(result, _WriteResult): - for res in expected_result: - prop = camel_to_snake(res) - # SPEC-869: Only BulkWriteResult has upserted_count. - if prop == "upserted_count" and not isinstance(result, BulkWriteResult): - if result.upserted_id is not None: - upserted_count = 1 - else: - upserted_count = 0 - self.assertEqual(upserted_count, expected_result[res], prop) - elif prop == "inserted_ids": - # BulkWriteResult does not have inserted_ids. - if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), result.inserted_count) - else: - # InsertManyResult may be compared to [id1] from the - # crud spec or {"0": id1} from the retryable write spec. - ids = expected_result[res] - if isinstance(ids, dict): - ids = [ids[str(i)] for i in range(len(ids))] - - self.assertEqual(ids, result.inserted_ids, prop) - elif prop == "upserted_ids": - # Convert indexes from strings to integers. 
- ids = expected_result[res] - expected_ids = {} - for str_index in ids: - expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) - else: - self.assertEqual(getattr(result, prop), expected_result[res], prop) - - return True - else: - - def _helper(expected_result, result): - if isinstance(expected_result, abc.Mapping): - for i in expected_result.keys(): - self.assertEqual(expected_result[i], result[i]) - - elif isinstance(expected_result, list): - for i, k in zip(expected_result, result): - _helper(i, k) - else: - self.assertEqual(expected_result, result) - - _helper(expected_result, result) - return None - - def get_object_name(self, op): - """Allow subclasses to override handling of 'object' - - Transaction spec says 'object' is required. - """ - return op["object"] - - @staticmethod - def parse_options(opts): - return parse_spec_options(opts) - - async def run_operation(self, sessions, collection, operation): - original_collection = collection - name = camel_to_snake(operation["name"]) - if name == "run_command": - name = "command" - elif name == "download_by_name": - name = "open_download_stream_by_name" - elif name == "download": - name = "open_download_stream" - elif name == "map_reduce": - self.skipTest("PyMongo does not support mapReduce") - elif name == "count": - self.skipTest("PyMongo does not support count") - - database = collection.database - collection = database.get_collection(collection.name) - if "collectionOptions" in operation: - collection = collection.with_options( - **self.parse_options(operation["collectionOptions"]) - ) - - object_name = self.get_object_name(operation) - if object_name == "gridfsbucket": - # Only create the GridFSBucket when we need it (for the gridfs - # retryable reads tests). 
- obj = AsyncGridFSBucket(database, bucket_name=collection.name) - else: - objects = { - "client": database.client, - "database": database, - "collection": collection, - "testRunner": self, - } - objects.update(sessions) - obj = objects[object_name] - - # Combine arguments with options and handle special cases. - arguments = operation.get("arguments", {}) - arguments.update(arguments.pop("options", {})) - self.parse_options(arguments) - - cmd = getattr(obj, name) - - with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, in_with_transaction=True - ) - prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - - if name == "run_on_thread": - args = {"sessions": sessions, "collection": collection} - args.update(arguments) - arguments = args - - if not _IS_SYNC and iscoroutinefunction(cmd): - result = await cmd(**dict(arguments)) - else: - result = cmd(**dict(arguments)) - # Cleanup open change stream cursors. - if name == "watch": - self.addAsyncCleanup(result.close) - - if name == "aggregate": - if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: - # Read from the primary to ensure causal consistency. 
- out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY - ) - return out.find() - if "download" in name: - result = Binary(result.read()) - - if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor): - return await result.to_list() - - return result - - def allowable_errors(self, op): - """Allow encryption spec to override expected error classes.""" - return (PyMongoError,) - - async def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get("result") - if expect_error(op): - with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - await self.run_operation(sessions, collection, op.copy()) - exc = context.exception - if expect_error_message(expected_result): - if isinstance(exc, BulkWriteError): - errmsg = str(exc.details).lower() - else: - errmsg = str(exc).lower() - self.assertIn(expected_result["errorContains"].lower(), errmsg) - if expect_error_code(expected_result): - self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) - if expect_error_labels_contain(expected_result): - self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) - if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) - if expect_timeout_error(expected_result): - self.assertIsInstance(exc, PyMongoError) - if not exc.timeout: - # Re-raise the exception for better diagnostics. - raise exc - - # Reraise the exception if we're in the with_transaction - # callback. 
- if in_with_transaction: - raise context.exception - else: - result = await self.run_operation(sessions, collection, op.copy()) - if "result" in op: - if op["name"] == "runCommand": - self.check_command_result(expected_result, result) - else: - self.check_result(expected_result, result) - - async def run_operations(self, sessions, collection, ops, in_with_transaction=False): - for op in ops: - await self._run_op(sessions, collection, op, in_with_transaction) - - # TODO: factor with test_command_monitoring.py - def check_events(self, test, listener, session_ids): - events = listener.started_events - if not len(test["expectations"]): - return - - # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in events]) - self.assertEqual(len(events), len(test["expectations"]), cmds) - for i, expectation in enumerate(test["expectations"]): - event_type = next(iter(expectation)) - event = events[i] - - # The tests substitute 42 for any number other than 0. - if event.command_name == "getMore" and event.command["getMore"]: - event.command["getMore"] = Int64(42) - elif event.command_name == "killCursors": - event.command["cursors"] = [Int64(42)] - elif event.command_name == "update": - # TODO: remove this once PYTHON-1744 is done. - # Add upsert and multi fields back into expectations. - updates = expectation[event_type]["command"]["updates"] - for update in updates: - update.setdefault("upsert", False) - update.setdefault("multi", False) - - # Replace afterClusterTime: 42 with actual afterClusterTime. 
- expected_cmd = expectation[event_type]["command"] - expected_read_concern = expected_cmd.get("readConcern") - if expected_read_concern is not None: - time = expected_read_concern.get("afterClusterTime") - if time == 42: - actual_time = event.command.get("readConcern", {}).get("afterClusterTime") - if actual_time is not None: - expected_read_concern["afterClusterTime"] = actual_time - - recovery_token = expected_cmd.get("recoveryToken") - if recovery_token == 42: - expected_cmd["recoveryToken"] = CompareType(dict) - - # Replace lsid with a name like "session0" to match test. - if "lsid" in event.command: - for name, lsid in session_ids.items(): - if event.command["lsid"] == lsid: - event.command["lsid"] = name - break - - for attr, expected in expectation[event_type].items(): - actual = getattr(event, attr) - expected = wrap_types(expected) - if isinstance(expected, dict): - for key, val in expected.items(): - if val is None: - if key in actual: - self.fail(f"Unexpected key [{key}] in {actual!r}") - elif key not in actual: - self.fail(f"Expected key [{key}] in {actual!r}") - else: - self.assertEqual( - val, decode_raw(actual[key]), f"Key [{key}] in {actual}" - ) - else: - self.assertEqual(actual, expected) - - def maybe_skip_scenario(self, test): - if test.get("skipReason"): - self.skipTest(test.get("skipReason")) - - def get_scenario_db_name(self, scenario_def): - """Allow subclasses to override a test's database name.""" - return scenario_def["database_name"] - - def get_scenario_coll_name(self, scenario_def): - """Allow subclasses to override a test's collection name.""" - return scenario_def["collection_name"] - - def get_outcome_coll_name(self, outcome, collection): - """Allow subclasses to override outcome collection.""" - return collection.name - - async def run_test_ops(self, sessions, collection, test): - """Added to allow retryable writes spec to override a test's - operation. 
- """ - await self.run_operations(sessions, collection, test["operations"]) - - def parse_client_options(self, opts): - """Allow encryption spec to override a clientOptions parsing.""" - return opts - - async def setup_scenario(self, scenario_def): - """Allow specs to override a test's setup.""" - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - documents = scenario_def["data"] - - # Setup the collection with as few majority writes as possible. - db = async_client_context.client.get_database(db_name) - coll_exists = bool(await db.list_collection_names(filter={"name": coll_name})) - if coll_exists: - await db[coll_name].delete_many({}) - # Only use majority wc only on the final write. - wc = WriteConcern(w="majority") - if documents: - db.get_collection(coll_name, write_concern=wc).insert_many(documents) - elif not coll_exists: - # Ensure collection exists. - await db.create_collection(coll_name, write_concern=wc) - - async def run_scenario(self, scenario_def, test): - self.maybe_skip_scenario(test) - - # Kill all sessions after each test with transactions to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - for op in test["operations"]: - name = op["name"] - if name == "startTransaction" or name == "withTransaction": - self.addAsyncCleanup(self.kill_all_sessions) - break - await self.setup_scenario(scenario_def) - database_name = self.get_scenario_db_name(scenario_def) - collection_name = self.get_scenario_coll_name(scenario_def) - # SPEC-1245 workaround StaleDbVersion on distinct - for c in self.mongos_clients: - await c[database_name][collection_name].distinct("x") - - # Configure the fail point before creating the client. 
- if "failPoint" in test: - fp = test["failPoint"] - await self.set_fail_point(fp) - self.addAsyncCleanup( - self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - listener = OvertCommandListener() - pool_listener = CMAPListener() - server_listener = ServerAndTopologyEventListener() - # Create a new client, to avoid interference from pooled sessions. - client_options = self.parse_client_options(test["clientOptions"]) - use_multi_mongos = test["useMultipleMongoses"] - host = None - if use_multi_mongos: - if async_client_context.load_balancer: - host = async_client_context.MULTI_MONGOS_LB_URI - elif async_client_context.is_mongos: - host = async_client_context.mongos_seeds() - client = await self.async_rs_client( - h=host, event_listeners=[listener, pool_listener, server_listener], **client_options - ) - self.scenario_client = client - self.listener = listener - self.pool_listener = pool_listener - self.server_listener = server_listener - - # Create session0 and session1. - sessions = {} - session_ids = {} - for i in range(2): - # Don't attempt to create sessions if they are not supported by - # the running server version. - if not async_client_context.sessions_enabled: - break - session_name = "session%d" % i - opts = camel_to_snake_args(test["sessionOptions"][session_name]) - if "default_transaction_options" in opts: - txn_opts = self.parse_options(opts["default_transaction_options"]) - txn_opts = client_session.TransactionOptions(**txn_opts) - opts["default_transaction_options"] = txn_opts - - s = client.start_session(**dict(opts)) - - sessions[session_name] = s - # Store lsid so we can access it after end_session, in check_events. 
- session_ids[session_name] = s.session_id - - self.addAsyncCleanup(end_sessions, sessions) - - collection = client[database_name][collection_name] - await self.run_test_ops(sessions, collection, test) - - await end_sessions(sessions) - - self.check_events(test, listener, session_ids) - - # Disable fail points. - if "failPoint" in test: - fp = test["failPoint"] - await self.set_fail_point( - {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - # Assert final state is expected. - outcome = test["outcome"] - expected_c = outcome.get("collection") - if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name(outcome, collection) - - # Read from the primary with local read concern to ensure causal - # consistency. - outcome_coll = async_client_context.client[collection.database.name].get_collection( - outcome_coll_name, - read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern("local"), - ) - actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() - - # The expected data needs to be the left hand side here otherwise - # CompareType(Binary) doesn't work. 
- self.assertEqual(wrap_types(expected_c["data"]), actual_data) - - -def expect_any_error(op): - if isinstance(op, dict): - return op.get("error") - - return False - - -def expect_error_message(expected_result): - if isinstance(expected_result, dict): - return isinstance(expected_result["errorContains"], str) - - return False - - -def expect_error_code(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorCodeName"] - - return False - - -def expect_error_labels_contain(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsContain"] - - return False - - -def expect_error_labels_omit(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsOmit"] - - return False - - -def expect_timeout_error(expected_result): - if isinstance(expected_result, dict): - return expected_result["isTimeoutError"] - - return False - - -def expect_error(op): - expected_result = op.get("result") - return ( - expect_any_error(op) - or expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result) - or expect_timeout_error(expected_result) - ) - - -async def end_sessions(sessions): - for s in sessions.values(): - # Aborts the transaction if it's open. 
- await s.end_session() - - -def decode_raw(val): - """Decode RawBSONDocuments in the given container.""" - if isinstance(val, (list, abc.Mapping)): - return decode(encode({"v": val}))["v"] - return val - - -TYPES = { - "binData": Binary, - "long": Int64, - "int": int, - "string": str, - "objectId": ObjectId, - "object": dict, - "array": list, -} - - -def wrap_types(val): - """Support $$type assertion in command results.""" - if isinstance(val, list): - return [wrap_types(v) for v in val] - if isinstance(val, abc.Mapping): - typ = val.get("$$type") - if typ: - if isinstance(typ, str): - types = TYPES[typ] - else: - types = tuple(TYPES[t] for t in typ) - return CompareType(types) - d = {} - for key in val: - d[key] = wrap_types(val[key]) - return d - return val diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 34e1c95ef2..95e580cef9 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -16,43 +16,14 @@ from __future__ import annotations import asyncio -import functools import os import time -import unittest -from collections import abc -from inspect import iscoroutinefunction -from test import IntegrationTest, client_context, client_knobs +from test import client_context from test.helpers import ConcurrentRunner -from test.utils_shared import ( - CMAPListener, - CompareType, - EventListener, - OvertCommandListener, - ScenarioDict, - ServerAndTopologyEventListener, - camel_to_snake, - camel_to_snake_args, - parse_spec_options, - prepare_spec_arguments, -) -from typing import List - -from bson import ObjectId, decode, encode, json_util -from bson.binary import Binary -from bson.int64 import Int64 -from bson.son import SON -from gridfs import GridFSBucket -from gridfs.synchronous.grid_file import GridFSBucket -from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from test.utils_shared import ScenarioDict + +from bson import json_util from pymongo.lock import _cond_wait, _create_condition, 
_create_lock -from pymongo.read_concern import ReadConcern -from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, _WriteResult -from pymongo.synchronous import client_session -from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.cursor import Cursor -from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -219,598 +190,3 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) - - -class SpecRunner(IntegrationTest): - mongos_clients: List - knobs: client_knobs - listener: EventListener - - def setUp(self) -> None: - super().setUp() - self.mongos_clients = [] - - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - self.knobs.enable() - self.targets = {} - self.listener = None # type: ignore - self.pool_listener = None - self.server_listener = None - self.maxDiff = None - - def tearDown(self) -> None: - self.knobs.disable() - - def set_fail_point(self, command_args): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - self.configure_fail_point(client, command_args) - - def targeted_fail_point(self, session, fail_point): - """Run the targetedFailPoint test operation. - - Enable the fail point on the session's pinned mongos. - """ - clients = {c.address: c for c in self.mongos_clients} - client = clients[session._pinned_address] - self.configure_fail_point(client, fail_point) - self.addCleanup(self.set_fail_point, {"mode": "off"}) - - def assert_session_pinned(self, session): - """Run the assertSessionPinned test operation. - - Assert that the given session is pinned. - """ - self.assertIsNotNone(session._transaction.pinned_address) - - def assert_session_unpinned(self, session): - """Run the assertSessionUnpinned test operation. - - Assert that the given session is not pinned. 
- """ - self.assertIsNone(session._pinned_address) - self.assertIsNone(session._transaction.pinned_address) - - def assert_collection_exists(self, database, collection): - """Run the assertCollectionExists test operation.""" - db = self.client[database] - self.assertIn(collection, db.list_collection_names()) - - def assert_collection_not_exists(self, database, collection): - """Run the assertCollectionNotExists test operation.""" - db = self.client[database] - self.assertNotIn(collection, db.list_collection_names()) - - def assert_index_exists(self, database, collection, index): - """Run the assertIndexExists test operation.""" - coll = self.client[database][collection] - self.assertIn(index, [doc["name"] for doc in coll.list_indexes()]) - - def assert_index_not_exists(self, database, collection, index): - """Run the assertIndexNotExists test operation.""" - coll = self.client[database][collection] - self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) - - def wait(self, ms): - """Run the "wait" test operation.""" - time.sleep(ms / 1000.0) - - def assertErrorLabelsContain(self, exc, expected_labels): - labels = [l for l in expected_labels if exc.has_error_label(l)] - self.assertEqual(labels, expected_labels) - - def assertErrorLabelsOmit(self, exc, omit_labels): - for label in omit_labels: - self.assertFalse( - exc.has_error_label(label), msg=f"error labels should not contain {label}" - ) - - def kill_all_sessions(self): - clients = self.mongos_clients if self.mongos_clients else [self.client] - for client in clients: - try: - client.admin.command("killAllSessions", []) - except (OperationFailure, AutoReconnect): - # "operation was interrupted" by killing the command's - # own session. - # On 8.0+ killAllSessions sometimes returns a network error. - pass - - def check_command_result(self, expected_result, result): - # Only compare the keys in the expected result. 
- filtered_result = {} - for key in expected_result: - try: - filtered_result[key] = result[key] - except KeyError: - pass - self.assertEqual(filtered_result, expected_result) - - # TODO: factor the following function with test_crud.py. - def check_result(self, expected_result, result): - if isinstance(result, _WriteResult): - for res in expected_result: - prop = camel_to_snake(res) - # SPEC-869: Only BulkWriteResult has upserted_count. - if prop == "upserted_count" and not isinstance(result, BulkWriteResult): - if result.upserted_id is not None: - upserted_count = 1 - else: - upserted_count = 0 - self.assertEqual(upserted_count, expected_result[res], prop) - elif prop == "inserted_ids": - # BulkWriteResult does not have inserted_ids. - if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), result.inserted_count) - else: - # InsertManyResult may be compared to [id1] from the - # crud spec or {"0": id1} from the retryable write spec. - ids = expected_result[res] - if isinstance(ids, dict): - ids = [ids[str(i)] for i in range(len(ids))] - - self.assertEqual(ids, result.inserted_ids, prop) - elif prop == "upserted_ids": - # Convert indexes from strings to integers. 
- ids = expected_result[res] - expected_ids = {} - for str_index in ids: - expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) - else: - self.assertEqual(getattr(result, prop), expected_result[res], prop) - - return True - else: - - def _helper(expected_result, result): - if isinstance(expected_result, abc.Mapping): - for i in expected_result.keys(): - self.assertEqual(expected_result[i], result[i]) - - elif isinstance(expected_result, list): - for i, k in zip(expected_result, result): - _helper(i, k) - else: - self.assertEqual(expected_result, result) - - _helper(expected_result, result) - return None - - def get_object_name(self, op): - """Allow subclasses to override handling of 'object' - - Transaction spec says 'object' is required. - """ - return op["object"] - - @staticmethod - def parse_options(opts): - return parse_spec_options(opts) - - def run_operation(self, sessions, collection, operation): - original_collection = collection - name = camel_to_snake(operation["name"]) - if name == "run_command": - name = "command" - elif name == "download_by_name": - name = "open_download_stream_by_name" - elif name == "download": - name = "open_download_stream" - elif name == "map_reduce": - self.skipTest("PyMongo does not support mapReduce") - elif name == "count": - self.skipTest("PyMongo does not support count") - - database = collection.database - collection = database.get_collection(collection.name) - if "collectionOptions" in operation: - collection = collection.with_options( - **self.parse_options(operation["collectionOptions"]) - ) - - object_name = self.get_object_name(operation) - if object_name == "gridfsbucket": - # Only create the GridFSBucket when we need it (for the gridfs - # retryable reads tests). 
- obj = GridFSBucket(database, bucket_name=collection.name) - else: - objects = { - "client": database.client, - "database": database, - "collection": collection, - "testRunner": self, - } - objects.update(sessions) - obj = objects[object_name] - - # Combine arguments with options and handle special cases. - arguments = operation.get("arguments", {}) - arguments.update(arguments.pop("options", {})) - self.parse_options(arguments) - - cmd = getattr(obj, name) - - with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, in_with_transaction=True - ) - prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - - if name == "run_on_thread": - args = {"sessions": sessions, "collection": collection} - args.update(arguments) - arguments = args - - if not _IS_SYNC and iscoroutinefunction(cmd): - result = cmd(**dict(arguments)) - else: - result = cmd(**dict(arguments)) - # Cleanup open change stream cursors. - if name == "watch": - self.addCleanup(result.close) - - if name == "aggregate": - if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: - # Read from the primary to ensure causal consistency. 
- out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY - ) - return out.find() - if "download" in name: - result = Binary(result.read()) - - if isinstance(result, Cursor) or isinstance(result, CommandCursor): - return result.to_list() - - return result - - def allowable_errors(self, op): - """Allow encryption spec to override expected error classes.""" - return (PyMongoError,) - - def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get("result") - if expect_error(op): - with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - self.run_operation(sessions, collection, op.copy()) - exc = context.exception - if expect_error_message(expected_result): - if isinstance(exc, BulkWriteError): - errmsg = str(exc.details).lower() - else: - errmsg = str(exc).lower() - self.assertIn(expected_result["errorContains"].lower(), errmsg) - if expect_error_code(expected_result): - self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) - if expect_error_labels_contain(expected_result): - self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) - if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) - if expect_timeout_error(expected_result): - self.assertIsInstance(exc, PyMongoError) - if not exc.timeout: - # Re-raise the exception for better diagnostics. - raise exc - - # Reraise the exception if we're in the with_transaction - # callback. 
- if in_with_transaction: - raise context.exception - else: - result = self.run_operation(sessions, collection, op.copy()) - if "result" in op: - if op["name"] == "runCommand": - self.check_command_result(expected_result, result) - else: - self.check_result(expected_result, result) - - def run_operations(self, sessions, collection, ops, in_with_transaction=False): - for op in ops: - self._run_op(sessions, collection, op, in_with_transaction) - - # TODO: factor with test_command_monitoring.py - def check_events(self, test, listener, session_ids): - events = listener.started_events - if not len(test["expectations"]): - return - - # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in events]) - self.assertEqual(len(events), len(test["expectations"]), cmds) - for i, expectation in enumerate(test["expectations"]): - event_type = next(iter(expectation)) - event = events[i] - - # The tests substitute 42 for any number other than 0. - if event.command_name == "getMore" and event.command["getMore"]: - event.command["getMore"] = Int64(42) - elif event.command_name == "killCursors": - event.command["cursors"] = [Int64(42)] - elif event.command_name == "update": - # TODO: remove this once PYTHON-1744 is done. - # Add upsert and multi fields back into expectations. - updates = expectation[event_type]["command"]["updates"] - for update in updates: - update.setdefault("upsert", False) - update.setdefault("multi", False) - - # Replace afterClusterTime: 42 with actual afterClusterTime. 
- expected_cmd = expectation[event_type]["command"] - expected_read_concern = expected_cmd.get("readConcern") - if expected_read_concern is not None: - time = expected_read_concern.get("afterClusterTime") - if time == 42: - actual_time = event.command.get("readConcern", {}).get("afterClusterTime") - if actual_time is not None: - expected_read_concern["afterClusterTime"] = actual_time - - recovery_token = expected_cmd.get("recoveryToken") - if recovery_token == 42: - expected_cmd["recoveryToken"] = CompareType(dict) - - # Replace lsid with a name like "session0" to match test. - if "lsid" in event.command: - for name, lsid in session_ids.items(): - if event.command["lsid"] == lsid: - event.command["lsid"] = name - break - - for attr, expected in expectation[event_type].items(): - actual = getattr(event, attr) - expected = wrap_types(expected) - if isinstance(expected, dict): - for key, val in expected.items(): - if val is None: - if key in actual: - self.fail(f"Unexpected key [{key}] in {actual!r}") - elif key not in actual: - self.fail(f"Expected key [{key}] in {actual!r}") - else: - self.assertEqual( - val, decode_raw(actual[key]), f"Key [{key}] in {actual}" - ) - else: - self.assertEqual(actual, expected) - - def maybe_skip_scenario(self, test): - if test.get("skipReason"): - self.skipTest(test.get("skipReason")) - - def get_scenario_db_name(self, scenario_def): - """Allow subclasses to override a test's database name.""" - return scenario_def["database_name"] - - def get_scenario_coll_name(self, scenario_def): - """Allow subclasses to override a test's collection name.""" - return scenario_def["collection_name"] - - def get_outcome_coll_name(self, outcome, collection): - """Allow subclasses to override outcome collection.""" - return collection.name - - def run_test_ops(self, sessions, collection, test): - """Added to allow retryable writes spec to override a test's - operation. 
- """ - self.run_operations(sessions, collection, test["operations"]) - - def parse_client_options(self, opts): - """Allow encryption spec to override a clientOptions parsing.""" - return opts - - def setup_scenario(self, scenario_def): - """Allow specs to override a test's setup.""" - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - documents = scenario_def["data"] - - # Setup the collection with as few majority writes as possible. - db = client_context.client.get_database(db_name) - coll_exists = bool(db.list_collection_names(filter={"name": coll_name})) - if coll_exists: - db[coll_name].delete_many({}) - # Only use majority wc only on the final write. - wc = WriteConcern(w="majority") - if documents: - db.get_collection(coll_name, write_concern=wc).insert_many(documents) - elif not coll_exists: - # Ensure collection exists. - db.create_collection(coll_name, write_concern=wc) - - def run_scenario(self, scenario_def, test): - self.maybe_skip_scenario(test) - - # Kill all sessions after each test with transactions to prevent an open - # transaction (from a test failure) from blocking collection/database - # operations during test set up and tear down. - for op in test["operations"]: - name = op["name"] - if name == "startTransaction" or name == "withTransaction": - self.addCleanup(self.kill_all_sessions) - break - self.setup_scenario(scenario_def) - database_name = self.get_scenario_db_name(scenario_def) - collection_name = self.get_scenario_coll_name(scenario_def) - # SPEC-1245 workaround StaleDbVersion on distinct - for c in self.mongos_clients: - c[database_name][collection_name].distinct("x") - - # Configure the fail point before creating the client. 
- if "failPoint" in test: - fp = test["failPoint"] - self.set_fail_point(fp) - self.addCleanup( - self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} - ) - - listener = OvertCommandListener() - pool_listener = CMAPListener() - server_listener = ServerAndTopologyEventListener() - # Create a new client, to avoid interference from pooled sessions. - client_options = self.parse_client_options(test["clientOptions"]) - use_multi_mongos = test["useMultipleMongoses"] - host = None - if use_multi_mongos: - if client_context.load_balancer: - host = client_context.MULTI_MONGOS_LB_URI - elif client_context.is_mongos: - host = client_context.mongos_seeds() - client = self.rs_client( - h=host, event_listeners=[listener, pool_listener, server_listener], **client_options - ) - self.scenario_client = client - self.listener = listener - self.pool_listener = pool_listener - self.server_listener = server_listener - - # Create session0 and session1. - sessions = {} - session_ids = {} - for i in range(2): - # Don't attempt to create sessions if they are not supported by - # the running server version. - if not client_context.sessions_enabled: - break - session_name = "session%d" % i - opts = camel_to_snake_args(test["sessionOptions"][session_name]) - if "default_transaction_options" in opts: - txn_opts = self.parse_options(opts["default_transaction_options"]) - txn_opts = client_session.TransactionOptions(**txn_opts) - opts["default_transaction_options"] = txn_opts - - s = client.start_session(**dict(opts)) - - sessions[session_name] = s - # Store lsid so we can access it after end_session, in check_events. - session_ids[session_name] = s.session_id - - self.addCleanup(end_sessions, sessions) - - collection = client[database_name][collection_name] - self.run_test_ops(sessions, collection, test) - - end_sessions(sessions) - - self.check_events(test, listener, session_ids) - - # Disable fail points. 
- if "failPoint" in test: - fp = test["failPoint"] - self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"}) - - # Assert final state is expected. - outcome = test["outcome"] - expected_c = outcome.get("collection") - if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name(outcome, collection) - - # Read from the primary with local read concern to ensure causal - # consistency. - outcome_coll = client_context.client[collection.database.name].get_collection( - outcome_coll_name, - read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern("local"), - ) - actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() - - # The expected data needs to be the left hand side here otherwise - # CompareType(Binary) doesn't work. - self.assertEqual(wrap_types(expected_c["data"]), actual_data) - - -def expect_any_error(op): - if isinstance(op, dict): - return op.get("error") - - return False - - -def expect_error_message(expected_result): - if isinstance(expected_result, dict): - return isinstance(expected_result["errorContains"], str) - - return False - - -def expect_error_code(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorCodeName"] - - return False - - -def expect_error_labels_contain(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsContain"] - - return False - - -def expect_error_labels_omit(expected_result): - if isinstance(expected_result, dict): - return expected_result["errorLabelsOmit"] - - return False - - -def expect_timeout_error(expected_result): - if isinstance(expected_result, dict): - return expected_result["isTimeoutError"] - - return False - - -def expect_error(op): - expected_result = op.get("result") - return ( - expect_any_error(op) - or expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result) - 
or expect_timeout_error(expected_result) - ) - - -def end_sessions(sessions): - for s in sessions.values(): - # Aborts the transaction if it's open. - s.end_session() - - -def decode_raw(val): - """Decode RawBSONDocuments in the given container.""" - if isinstance(val, (list, abc.Mapping)): - return decode(encode({"v": val}))["v"] - return val - - -TYPES = { - "binData": Binary, - "long": Int64, - "int": int, - "string": str, - "objectId": ObjectId, - "object": dict, - "array": list, -} - - -def wrap_types(val): - """Support $$type assertion in command results.""" - if isinstance(val, list): - return [wrap_types(v) for v in val] - if isinstance(val, abc.Mapping): - typ = val.get("$$type") - if typ: - if isinstance(typ, str): - types = TYPES[typ] - else: - types = tuple(TYPES[t] for t in typ) - return CompareType(types) - d = {} - for key in val: - d[key] = wrap_types(val[key]) - return d - return val From dd08d2d36a4c101723c2a96ac7ac6e91746f0bc9 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 9 Mar 2026 11:00:11 -0700 Subject: [PATCH 4/5] Revert "PYTHON-5114 Remove unused SpecRunner class" This reverts commit c0ac7321858b2fecc95c3cd84d86e4bd12928b18. 
--- test/asynchronous/utils_spec_runner.py | 634 ++++++++++++++++++++++++- test/utils_spec_runner.py | 632 +++++++++++++++++++++++- 2 files changed, 1258 insertions(+), 8 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 344fd97536..f099eee12c 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -16,14 +16,43 @@ from __future__ import annotations import asyncio +import functools import os import time -from test.asynchronous import async_client_context +import unittest +from collections import abc +from inspect import iscoroutinefunction +from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.asynchronous.helpers import ConcurrentRunner -from test.utils_shared import ScenarioDict - -from bson import json_util +from test.utils_shared import ( + CMAPListener, + CompareType, + EventListener, + OvertCommandListener, + ScenarioDict, + ServerAndTopologyEventListener, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, +) +from typing import List + +from bson import ObjectId, decode, encode, json_util +from bson.binary import Binary +from bson.int64 import Int64 +from bson.son import SON +from gridfs import GridFSBucket +from gridfs.asynchronous.grid_file import AsyncGridFSBucket +from pymongo.asynchronous import client_session +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.cursor import AsyncCursor +from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.results import BulkWriteResult, _WriteResult +from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -190,3 +219,600 @@ def create_tests(self): 
self._create_tests() else: asyncio.run(self._create_tests()) + + +class AsyncSpecRunner(AsyncIntegrationTest): + mongos_clients: List + knobs: client_knobs + listener: EventListener + + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] + + # Speed up the tests by decreasing the heartbeat frequency. + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.targets = {} + self.listener = None # type: ignore + self.pool_listener = None + self.server_listener = None + self.maxDiff = None + + async def asyncTearDown(self) -> None: + self.knobs.disable() + + async def set_fail_point(self, command_args): + clients = self.mongos_clients if self.mongos_clients else [self.client] + for client in clients: + await self.configure_fail_point(client, command_args) + + async def targeted_fail_point(self, session, fail_point): + """Run the targetedFailPoint test operation. + + Enable the fail point on the session's pinned mongos. + """ + clients = {c.address: c for c in self.mongos_clients} + client = clients[session._pinned_address] + await self.configure_fail_point(client, fail_point) + self.addAsyncCleanup(self.set_fail_point, {"mode": "off"}) + + def assert_session_pinned(self, session): + """Run the assertSessionPinned test operation. + + Assert that the given session is pinned. + """ + self.assertIsNotNone(session._transaction.pinned_address) + + def assert_session_unpinned(self, session): + """Run the assertSessionUnpinned test operation. + + Assert that the given session is not pinned. 
+ """ + self.assertIsNone(session._pinned_address) + self.assertIsNone(session._transaction.pinned_address) + + async def assert_collection_exists(self, database, collection): + """Run the assertCollectionExists test operation.""" + db = self.client[database] + self.assertIn(collection, await db.list_collection_names()) + + async def assert_collection_not_exists(self, database, collection): + """Run the assertCollectionNotExists test operation.""" + db = self.client[database] + self.assertNotIn(collection, await db.list_collection_names()) + + async def assert_index_exists(self, database, collection, index): + """Run the assertIndexExists test operation.""" + coll = self.client[database][collection] + self.assertIn(index, [doc["name"] async for doc in await coll.list_indexes()]) + + async def assert_index_not_exists(self, database, collection, index): + """Run the assertIndexNotExists test operation.""" + coll = self.client[database][collection] + self.assertNotIn(index, [doc["name"] async for doc in await coll.list_indexes()]) + + async def wait(self, ms): + """Run the "wait" test operation.""" + await asyncio.sleep(ms / 1000.0) + + def assertErrorLabelsContain(self, exc, expected_labels): + labels = [l for l in expected_labels if exc.has_error_label(l)] + self.assertEqual(labels, expected_labels) + + def assertErrorLabelsOmit(self, exc, omit_labels): + for label in omit_labels: + self.assertFalse( + exc.has_error_label(label), msg=f"error labels should not contain {label}" + ) + + async def kill_all_sessions(self): + clients = self.mongos_clients if self.mongos_clients else [self.client] + for client in clients: + try: + await client.admin.command("killAllSessions", []) + except (OperationFailure, AutoReconnect): + # "operation was interrupted" by killing the command's + # own session. + # On 8.0+ killAllSessions sometimes returns a network error. + pass + + def check_command_result(self, expected_result, result): + # Only compare the keys in the expected result. 
+ filtered_result = {} + for key in expected_result: + try: + filtered_result[key] = result[key] + except KeyError: + pass + self.assertEqual(filtered_result, expected_result) + + # TODO: factor the following function with test_crud.py. + def check_result(self, expected_result, result): + if isinstance(result, _WriteResult): + for res in expected_result: + prop = camel_to_snake(res) + # SPEC-869: Only BulkWriteResult has upserted_count. + if prop == "upserted_count" and not isinstance(result, BulkWriteResult): + if result.upserted_id is not None: + upserted_count = 1 + else: + upserted_count = 0 + self.assertEqual(upserted_count, expected_result[res], prop) + elif prop == "inserted_ids": + # BulkWriteResult does not have inserted_ids. + if isinstance(result, BulkWriteResult): + self.assertEqual(len(expected_result[res]), result.inserted_count) + else: + # InsertManyResult may be compared to [id1] from the + # crud spec or {"0": id1} from the retryable write spec. + ids = expected_result[res] + if isinstance(ids, dict): + ids = [ids[str(i)] for i in range(len(ids))] + + self.assertEqual(ids, result.inserted_ids, prop) + elif prop == "upserted_ids": + # Convert indexes from strings to integers. 
+ ids = expected_result[res] + expected_ids = {} + for str_index in ids: + expected_ids[int(str_index)] = ids[str_index] + self.assertEqual(expected_ids, result.upserted_ids, prop) + else: + self.assertEqual(getattr(result, prop), expected_result[res], prop) + + return True + else: + + def _helper(expected_result, result): + if isinstance(expected_result, abc.Mapping): + for i in expected_result.keys(): + self.assertEqual(expected_result[i], result[i]) + + elif isinstance(expected_result, list): + for i, k in zip(expected_result, result): + _helper(i, k) + else: + self.assertEqual(expected_result, result) + + _helper(expected_result, result) + return None + + def get_object_name(self, op): + """Allow subclasses to override handling of 'object' + + Transaction spec says 'object' is required. + """ + return op["object"] + + @staticmethod + def parse_options(opts): + return parse_spec_options(opts) + + async def run_operation(self, sessions, collection, operation): + original_collection = collection + name = camel_to_snake(operation["name"]) + if name == "run_command": + name = "command" + elif name == "download_by_name": + name = "open_download_stream_by_name" + elif name == "download": + name = "open_download_stream" + elif name == "map_reduce": + self.skipTest("PyMongo does not support mapReduce") + elif name == "count": + self.skipTest("PyMongo does not support count") + + database = collection.database + collection = database.get_collection(collection.name) + if "collectionOptions" in operation: + collection = collection.with_options( + **self.parse_options(operation["collectionOptions"]) + ) + + object_name = self.get_object_name(operation) + if object_name == "gridfsbucket": + # Only create the GridFSBucket when we need it (for the gridfs + # retryable reads tests). 
+ obj = AsyncGridFSBucket(database, bucket_name=collection.name) + else: + objects = { + "client": database.client, + "database": database, + "collection": collection, + "testRunner": self, + } + objects.update(sessions) + obj = objects[object_name] + + # Combine arguments with options and handle special cases. + arguments = operation.get("arguments", {}) + arguments.update(arguments.pop("options", {})) + self.parse_options(arguments) + + cmd = getattr(obj, name) + + with_txn_callback = functools.partial( + self.run_operations, sessions, original_collection, in_with_transaction=True + ) + prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) + + if name == "run_on_thread": + args = {"sessions": sessions, "collection": collection} + args.update(arguments) + arguments = args + + if not _IS_SYNC and iscoroutinefunction(cmd): + result = await cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) + # Cleanup open change stream cursors. + if name == "watch": + self.addAsyncCleanup(result.close) + + if name == "aggregate": + if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: + # Read from the primary to ensure causal consistency. 
+                out = collection.database.get_collection(
+                    arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY
+                )
+                return out.find()
+        if "download" in name:
+            result = Binary(await result.read())
+
+        if isinstance(result, AsyncCursor) or isinstance(result, AsyncCommandCursor):
+            return await result.to_list()
+
+        return result
+
+    def allowable_errors(self, op):
+        """Allow encryption spec to override expected error classes."""
+        return (PyMongoError,)
+
+    async def _run_op(self, sessions, collection, op, in_with_transaction):
+        expected_result = op.get("result")
+        if expect_error(op):
+            with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
+                await self.run_operation(sessions, collection, op.copy())
+            exc = context.exception
+            if expect_error_message(expected_result):
+                if isinstance(exc, BulkWriteError):
+                    errmsg = str(exc.details).lower()
+                else:
+                    errmsg = str(exc).lower()
+                self.assertIn(expected_result["errorContains"].lower(), errmsg)
+            if expect_error_code(expected_result):
+                self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
+            if expect_error_labels_contain(expected_result):
+                self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
+            if expect_error_labels_omit(expected_result):
+                self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
+            if expect_timeout_error(expected_result):
+                self.assertIsInstance(exc, PyMongoError)
+                if not exc.timeout:
+                    # Re-raise the exception for better diagnostics.
+                    raise exc
+
+            # Reraise the exception if we're in the with_transaction
+            # callback.
+ if in_with_transaction: + raise context.exception + else: + result = await self.run_operation(sessions, collection, op.copy()) + if "result" in op: + if op["name"] == "runCommand": + self.check_command_result(expected_result, result) + else: + self.check_result(expected_result, result) + + async def run_operations(self, sessions, collection, ops, in_with_transaction=False): + for op in ops: + await self._run_op(sessions, collection, op, in_with_transaction) + + # TODO: factor with test_command_monitoring.py + def check_events(self, test, listener, session_ids): + events = listener.started_events + if not len(test["expectations"]): + return + + # Give a nicer message when there are missing or extra events + cmds = decode_raw([event.command for event in events]) + self.assertEqual(len(events), len(test["expectations"]), cmds) + for i, expectation in enumerate(test["expectations"]): + event_type = next(iter(expectation)) + event = events[i] + + # The tests substitute 42 for any number other than 0. + if event.command_name == "getMore" and event.command["getMore"]: + event.command["getMore"] = Int64(42) + elif event.command_name == "killCursors": + event.command["cursors"] = [Int64(42)] + elif event.command_name == "update": + # TODO: remove this once PYTHON-1744 is done. + # Add upsert and multi fields back into expectations. + updates = expectation[event_type]["command"]["updates"] + for update in updates: + update.setdefault("upsert", False) + update.setdefault("multi", False) + + # Replace afterClusterTime: 42 with actual afterClusterTime. 
+ expected_cmd = expectation[event_type]["command"] + expected_read_concern = expected_cmd.get("readConcern") + if expected_read_concern is not None: + time = expected_read_concern.get("afterClusterTime") + if time == 42: + actual_time = event.command.get("readConcern", {}).get("afterClusterTime") + if actual_time is not None: + expected_read_concern["afterClusterTime"] = actual_time + + recovery_token = expected_cmd.get("recoveryToken") + if recovery_token == 42: + expected_cmd["recoveryToken"] = CompareType(dict) + + # Replace lsid with a name like "session0" to match test. + if "lsid" in event.command: + for name, lsid in session_ids.items(): + if event.command["lsid"] == lsid: + event.command["lsid"] = name + break + + for attr, expected in expectation[event_type].items(): + actual = getattr(event, attr) + expected = wrap_types(expected) + if isinstance(expected, dict): + for key, val in expected.items(): + if val is None: + if key in actual: + self.fail(f"Unexpected key [{key}] in {actual!r}") + elif key not in actual: + self.fail(f"Expected key [{key}] in {actual!r}") + else: + self.assertEqual( + val, decode_raw(actual[key]), f"Key [{key}] in {actual}" + ) + else: + self.assertEqual(actual, expected) + + def maybe_skip_scenario(self, test): + if test.get("skipReason"): + self.skipTest(test.get("skipReason")) + + def get_scenario_db_name(self, scenario_def): + """Allow subclasses to override a test's database name.""" + return scenario_def["database_name"] + + def get_scenario_coll_name(self, scenario_def): + """Allow subclasses to override a test's collection name.""" + return scenario_def["collection_name"] + + def get_outcome_coll_name(self, outcome, collection): + """Allow subclasses to override outcome collection.""" + return collection.name + + async def run_test_ops(self, sessions, collection, test): + """Added to allow retryable writes spec to override a test's + operation. 
+ """ + await self.run_operations(sessions, collection, test["operations"]) + + def parse_client_options(self, opts): + """Allow encryption spec to override a clientOptions parsing.""" + return opts + + async def setup_scenario(self, scenario_def): + """Allow specs to override a test's setup.""" + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + documents = scenario_def["data"] + + # Setup the collection with as few majority writes as possible. + db = async_client_context.client.get_database(db_name) + coll_exists = bool(await db.list_collection_names(filter={"name": coll_name})) + if coll_exists: + await db[coll_name].delete_many({}) + # Only use majority wc only on the final write. + wc = WriteConcern(w="majority") + if documents: + db.get_collection(coll_name, write_concern=wc).insert_many(documents) + elif not coll_exists: + # Ensure collection exists. + await db.create_collection(coll_name, write_concern=wc) + + async def run_scenario(self, scenario_def, test): + self.maybe_skip_scenario(test) + + # Kill all sessions after each test with transactions to prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + for op in test["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addAsyncCleanup(self.kill_all_sessions) + break + await self.setup_scenario(scenario_def) + database_name = self.get_scenario_db_name(scenario_def) + collection_name = self.get_scenario_coll_name(scenario_def) + # SPEC-1245 workaround StaleDbVersion on distinct + for c in self.mongos_clients: + await c[database_name][collection_name].distinct("x") + + # Configure the fail point before creating the client. 
+ if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point(fp) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + listener = OvertCommandListener() + pool_listener = CMAPListener() + server_listener = ServerAndTopologyEventListener() + # Create a new client, to avoid interference from pooled sessions. + client_options = self.parse_client_options(test["clientOptions"]) + use_multi_mongos = test["useMultipleMongoses"] + host = None + if use_multi_mongos: + if async_client_context.load_balancer: + host = async_client_context.MULTI_MONGOS_LB_URI + elif async_client_context.is_mongos: + host = async_client_context.mongos_seeds() + client = await self.async_rs_client( + h=host, event_listeners=[listener, pool_listener, server_listener], **client_options + ) + self.scenario_client = client + self.listener = listener + self.pool_listener = pool_listener + self.server_listener = server_listener + + # Create session0 and session1. + sessions = {} + session_ids = {} + for i in range(2): + # Don't attempt to create sessions if they are not supported by + # the running server version. + if not async_client_context.sessions_enabled: + break + session_name = "session%d" % i + opts = camel_to_snake_args(test["sessionOptions"][session_name]) + if "default_transaction_options" in opts: + txn_opts = self.parse_options(opts["default_transaction_options"]) + txn_opts = client_session.TransactionOptions(**txn_opts) + opts["default_transaction_options"] = txn_opts + + s = client.start_session(**dict(opts)) + + sessions[session_name] = s + # Store lsid so we can access it after end_session, in check_events. 
+ session_ids[session_name] = s.session_id + + self.addAsyncCleanup(end_sessions, sessions) + + collection = client[database_name][collection_name] + await self.run_test_ops(sessions, collection, test) + + await end_sessions(sessions) + + self.check_events(test, listener, session_ids) + + # Disable fail points. + if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point( + {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + # Assert final state is expected. + outcome = test["outcome"] + expected_c = outcome.get("collection") + if expected_c is not None: + outcome_coll_name = self.get_outcome_coll_name(outcome, collection) + + # Read from the primary with local read concern to ensure causal + # consistency. + outcome_coll = async_client_context.client[collection.database.name].get_collection( + outcome_coll_name, + read_preference=ReadPreference.PRIMARY, + read_concern=ReadConcern("local"), + ) + actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() + + # The expected data needs to be the left hand side here otherwise + # CompareType(Binary) doesn't work. 
+ self.assertEqual(wrap_types(expected_c["data"]), actual_data) + + +def expect_any_error(op): + if isinstance(op, dict): + return op.get("error") + + return False + + +def expect_error_message(expected_result): + if isinstance(expected_result, dict): + return isinstance(expected_result["errorContains"], str) + + return False + + +def expect_error_code(expected_result): + if isinstance(expected_result, dict): + return expected_result["errorCodeName"] + + return False + + +def expect_error_labels_contain(expected_result): + if isinstance(expected_result, dict): + return expected_result["errorLabelsContain"] + + return False + + +def expect_error_labels_omit(expected_result): + if isinstance(expected_result, dict): + return expected_result["errorLabelsOmit"] + + return False + + +def expect_timeout_error(expected_result): + if isinstance(expected_result, dict): + return expected_result["isTimeoutError"] + + return False + + +def expect_error(op): + expected_result = op.get("result") + return ( + expect_any_error(op) + or expect_error_message(expected_result) + or expect_error_code(expected_result) + or expect_error_labels_contain(expected_result) + or expect_error_labels_omit(expected_result) + or expect_timeout_error(expected_result) + ) + + +async def end_sessions(sessions): + for s in sessions.values(): + # Aborts the transaction if it's open. 
+ await s.end_session() + + +def decode_raw(val): + """Decode RawBSONDocuments in the given container.""" + if isinstance(val, (list, abc.Mapping)): + return decode(encode({"v": val}))["v"] + return val + + +TYPES = { + "binData": Binary, + "long": Int64, + "int": int, + "string": str, + "objectId": ObjectId, + "object": dict, + "array": list, +} + + +def wrap_types(val): + """Support $$type assertion in command results.""" + if isinstance(val, list): + return [wrap_types(v) for v in val] + if isinstance(val, abc.Mapping): + typ = val.get("$$type") + if typ: + if isinstance(typ, str): + types = TYPES[typ] + else: + types = tuple(TYPES[t] for t in typ) + return CompareType(types) + d = {} + for key in val: + d[key] = wrap_types(val[key]) + return d + return val diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 95e580cef9..34e1c95ef2 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -16,14 +16,43 @@ from __future__ import annotations import asyncio +import functools import os import time -from test import client_context +import unittest +from collections import abc +from inspect import iscoroutinefunction +from test import IntegrationTest, client_context, client_knobs from test.helpers import ConcurrentRunner -from test.utils_shared import ScenarioDict - -from bson import json_util +from test.utils_shared import ( + CMAPListener, + CompareType, + EventListener, + OvertCommandListener, + ScenarioDict, + ServerAndTopologyEventListener, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, +) +from typing import List + +from bson import ObjectId, decode, encode, json_util +from bson.binary import Binary +from bson.int64 import Int64 +from bson.son import SON +from gridfs import GridFSBucket +from gridfs.synchronous.grid_file import GridFSBucket +from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError from pymongo.lock import _cond_wait, _create_condition, 
_create_lock +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.results import BulkWriteResult, _WriteResult +from pymongo.synchronous import client_session +from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.cursor import Cursor +from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -190,3 +219,598 @@ def create_tests(self): self._create_tests() else: asyncio.run(self._create_tests()) + + +class SpecRunner(IntegrationTest): + mongos_clients: List + knobs: client_knobs + listener: EventListener + + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] + + # Speed up the tests by decreasing the heartbeat frequency. + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.targets = {} + self.listener = None # type: ignore + self.pool_listener = None + self.server_listener = None + self.maxDiff = None + + def tearDown(self) -> None: + self.knobs.disable() + + def set_fail_point(self, command_args): + clients = self.mongos_clients if self.mongos_clients else [self.client] + for client in clients: + self.configure_fail_point(client, command_args) + + def targeted_fail_point(self, session, fail_point): + """Run the targetedFailPoint test operation. + + Enable the fail point on the session's pinned mongos. + """ + clients = {c.address: c for c in self.mongos_clients} + client = clients[session._pinned_address] + self.configure_fail_point(client, fail_point) + self.addCleanup(self.set_fail_point, {"mode": "off"}) + + def assert_session_pinned(self, session): + """Run the assertSessionPinned test operation. + + Assert that the given session is pinned. + """ + self.assertIsNotNone(session._transaction.pinned_address) + + def assert_session_unpinned(self, session): + """Run the assertSessionUnpinned test operation. + + Assert that the given session is not pinned. 
+ """ + self.assertIsNone(session._pinned_address) + self.assertIsNone(session._transaction.pinned_address) + + def assert_collection_exists(self, database, collection): + """Run the assertCollectionExists test operation.""" + db = self.client[database] + self.assertIn(collection, db.list_collection_names()) + + def assert_collection_not_exists(self, database, collection): + """Run the assertCollectionNotExists test operation.""" + db = self.client[database] + self.assertNotIn(collection, db.list_collection_names()) + + def assert_index_exists(self, database, collection, index): + """Run the assertIndexExists test operation.""" + coll = self.client[database][collection] + self.assertIn(index, [doc["name"] for doc in coll.list_indexes()]) + + def assert_index_not_exists(self, database, collection, index): + """Run the assertIndexNotExists test operation.""" + coll = self.client[database][collection] + self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) + + def wait(self, ms): + """Run the "wait" test operation.""" + time.sleep(ms / 1000.0) + + def assertErrorLabelsContain(self, exc, expected_labels): + labels = [l for l in expected_labels if exc.has_error_label(l)] + self.assertEqual(labels, expected_labels) + + def assertErrorLabelsOmit(self, exc, omit_labels): + for label in omit_labels: + self.assertFalse( + exc.has_error_label(label), msg=f"error labels should not contain {label}" + ) + + def kill_all_sessions(self): + clients = self.mongos_clients if self.mongos_clients else [self.client] + for client in clients: + try: + client.admin.command("killAllSessions", []) + except (OperationFailure, AutoReconnect): + # "operation was interrupted" by killing the command's + # own session. + # On 8.0+ killAllSessions sometimes returns a network error. + pass + + def check_command_result(self, expected_result, result): + # Only compare the keys in the expected result. 
+ filtered_result = {} + for key in expected_result: + try: + filtered_result[key] = result[key] + except KeyError: + pass + self.assertEqual(filtered_result, expected_result) + + # TODO: factor the following function with test_crud.py. + def check_result(self, expected_result, result): + if isinstance(result, _WriteResult): + for res in expected_result: + prop = camel_to_snake(res) + # SPEC-869: Only BulkWriteResult has upserted_count. + if prop == "upserted_count" and not isinstance(result, BulkWriteResult): + if result.upserted_id is not None: + upserted_count = 1 + else: + upserted_count = 0 + self.assertEqual(upserted_count, expected_result[res], prop) + elif prop == "inserted_ids": + # BulkWriteResult does not have inserted_ids. + if isinstance(result, BulkWriteResult): + self.assertEqual(len(expected_result[res]), result.inserted_count) + else: + # InsertManyResult may be compared to [id1] from the + # crud spec or {"0": id1} from the retryable write spec. + ids = expected_result[res] + if isinstance(ids, dict): + ids = [ids[str(i)] for i in range(len(ids))] + + self.assertEqual(ids, result.inserted_ids, prop) + elif prop == "upserted_ids": + # Convert indexes from strings to integers. 
+ ids = expected_result[res] + expected_ids = {} + for str_index in ids: + expected_ids[int(str_index)] = ids[str_index] + self.assertEqual(expected_ids, result.upserted_ids, prop) + else: + self.assertEqual(getattr(result, prop), expected_result[res], prop) + + return True + else: + + def _helper(expected_result, result): + if isinstance(expected_result, abc.Mapping): + for i in expected_result.keys(): + self.assertEqual(expected_result[i], result[i]) + + elif isinstance(expected_result, list): + for i, k in zip(expected_result, result): + _helper(i, k) + else: + self.assertEqual(expected_result, result) + + _helper(expected_result, result) + return None + + def get_object_name(self, op): + """Allow subclasses to override handling of 'object' + + Transaction spec says 'object' is required. + """ + return op["object"] + + @staticmethod + def parse_options(opts): + return parse_spec_options(opts) + + def run_operation(self, sessions, collection, operation): + original_collection = collection + name = camel_to_snake(operation["name"]) + if name == "run_command": + name = "command" + elif name == "download_by_name": + name = "open_download_stream_by_name" + elif name == "download": + name = "open_download_stream" + elif name == "map_reduce": + self.skipTest("PyMongo does not support mapReduce") + elif name == "count": + self.skipTest("PyMongo does not support count") + + database = collection.database + collection = database.get_collection(collection.name) + if "collectionOptions" in operation: + collection = collection.with_options( + **self.parse_options(operation["collectionOptions"]) + ) + + object_name = self.get_object_name(operation) + if object_name == "gridfsbucket": + # Only create the GridFSBucket when we need it (for the gridfs + # retryable reads tests). 
+ obj = GridFSBucket(database, bucket_name=collection.name) + else: + objects = { + "client": database.client, + "database": database, + "collection": collection, + "testRunner": self, + } + objects.update(sessions) + obj = objects[object_name] + + # Combine arguments with options and handle special cases. + arguments = operation.get("arguments", {}) + arguments.update(arguments.pop("options", {})) + self.parse_options(arguments) + + cmd = getattr(obj, name) + + with_txn_callback = functools.partial( + self.run_operations, sessions, original_collection, in_with_transaction=True + ) + prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) + + if name == "run_on_thread": + args = {"sessions": sessions, "collection": collection} + args.update(arguments) + arguments = args + + if not _IS_SYNC and iscoroutinefunction(cmd): + result = cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) + # Cleanup open change stream cursors. + if name == "watch": + self.addCleanup(result.close) + + if name == "aggregate": + if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: + # Read from the primary to ensure causal consistency. 
+ out = collection.database.get_collection( + arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY + ) + return out.find() + if "download" in name: + result = Binary(result.read()) + + if isinstance(result, Cursor) or isinstance(result, CommandCursor): + return result.to_list() + + return result + + def allowable_errors(self, op): + """Allow encryption spec to override expected error classes.""" + return (PyMongoError,) + + def _run_op(self, sessions, collection, op, in_with_transaction): + expected_result = op.get("result") + if expect_error(op): + with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: + self.run_operation(sessions, collection, op.copy()) + exc = context.exception + if expect_error_message(expected_result): + if isinstance(exc, BulkWriteError): + errmsg = str(exc.details).lower() + else: + errmsg = str(exc).lower() + self.assertIn(expected_result["errorContains"].lower(), errmsg) + if expect_error_code(expected_result): + self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName")) + if expect_error_labels_contain(expected_result): + self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"]) + if expect_error_labels_omit(expected_result): + self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"]) + if expect_timeout_error(expected_result): + self.assertIsInstance(exc, PyMongoError) + if not exc.timeout: + # Re-raise the exception for better diagnostics. + raise exc + + # Reraise the exception if we're in the with_transaction + # callback. 
+ if in_with_transaction: + raise context.exception + else: + result = self.run_operation(sessions, collection, op.copy()) + if "result" in op: + if op["name"] == "runCommand": + self.check_command_result(expected_result, result) + else: + self.check_result(expected_result, result) + + def run_operations(self, sessions, collection, ops, in_with_transaction=False): + for op in ops: + self._run_op(sessions, collection, op, in_with_transaction) + + # TODO: factor with test_command_monitoring.py + def check_events(self, test, listener, session_ids): + events = listener.started_events + if not len(test["expectations"]): + return + + # Give a nicer message when there are missing or extra events + cmds = decode_raw([event.command for event in events]) + self.assertEqual(len(events), len(test["expectations"]), cmds) + for i, expectation in enumerate(test["expectations"]): + event_type = next(iter(expectation)) + event = events[i] + + # The tests substitute 42 for any number other than 0. + if event.command_name == "getMore" and event.command["getMore"]: + event.command["getMore"] = Int64(42) + elif event.command_name == "killCursors": + event.command["cursors"] = [Int64(42)] + elif event.command_name == "update": + # TODO: remove this once PYTHON-1744 is done. + # Add upsert and multi fields back into expectations. + updates = expectation[event_type]["command"]["updates"] + for update in updates: + update.setdefault("upsert", False) + update.setdefault("multi", False) + + # Replace afterClusterTime: 42 with actual afterClusterTime. 
+ expected_cmd = expectation[event_type]["command"] + expected_read_concern = expected_cmd.get("readConcern") + if expected_read_concern is not None: + time = expected_read_concern.get("afterClusterTime") + if time == 42: + actual_time = event.command.get("readConcern", {}).get("afterClusterTime") + if actual_time is not None: + expected_read_concern["afterClusterTime"] = actual_time + + recovery_token = expected_cmd.get("recoveryToken") + if recovery_token == 42: + expected_cmd["recoveryToken"] = CompareType(dict) + + # Replace lsid with a name like "session0" to match test. + if "lsid" in event.command: + for name, lsid in session_ids.items(): + if event.command["lsid"] == lsid: + event.command["lsid"] = name + break + + for attr, expected in expectation[event_type].items(): + actual = getattr(event, attr) + expected = wrap_types(expected) + if isinstance(expected, dict): + for key, val in expected.items(): + if val is None: + if key in actual: + self.fail(f"Unexpected key [{key}] in {actual!r}") + elif key not in actual: + self.fail(f"Expected key [{key}] in {actual!r}") + else: + self.assertEqual( + val, decode_raw(actual[key]), f"Key [{key}] in {actual}" + ) + else: + self.assertEqual(actual, expected) + + def maybe_skip_scenario(self, test): + if test.get("skipReason"): + self.skipTest(test.get("skipReason")) + + def get_scenario_db_name(self, scenario_def): + """Allow subclasses to override a test's database name.""" + return scenario_def["database_name"] + + def get_scenario_coll_name(self, scenario_def): + """Allow subclasses to override a test's collection name.""" + return scenario_def["collection_name"] + + def get_outcome_coll_name(self, outcome, collection): + """Allow subclasses to override outcome collection.""" + return collection.name + + def run_test_ops(self, sessions, collection, test): + """Added to allow retryable writes spec to override a test's + operation. 
+ """ + self.run_operations(sessions, collection, test["operations"]) + + def parse_client_options(self, opts): + """Allow encryption spec to override a clientOptions parsing.""" + return opts + + def setup_scenario(self, scenario_def): + """Allow specs to override a test's setup.""" + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + documents = scenario_def["data"] + + # Setup the collection with as few majority writes as possible. + db = client_context.client.get_database(db_name) + coll_exists = bool(db.list_collection_names(filter={"name": coll_name})) + if coll_exists: + db[coll_name].delete_many({}) + # Only use majority wc only on the final write. + wc = WriteConcern(w="majority") + if documents: + db.get_collection(coll_name, write_concern=wc).insert_many(documents) + elif not coll_exists: + # Ensure collection exists. + db.create_collection(coll_name, write_concern=wc) + + def run_scenario(self, scenario_def, test): + self.maybe_skip_scenario(test) + + # Kill all sessions after each test with transactions to prevent an open + # transaction (from a test failure) from blocking collection/database + # operations during test set up and tear down. + for op in test["operations"]: + name = op["name"] + if name == "startTransaction" or name == "withTransaction": + self.addCleanup(self.kill_all_sessions) + break + self.setup_scenario(scenario_def) + database_name = self.get_scenario_db_name(scenario_def) + collection_name = self.get_scenario_coll_name(scenario_def) + # SPEC-1245 workaround StaleDbVersion on distinct + for c in self.mongos_clients: + c[database_name][collection_name].distinct("x") + + # Configure the fail point before creating the client. 
+ if "failPoint" in test: + fp = test["failPoint"] + self.set_fail_point(fp) + self.addCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + listener = OvertCommandListener() + pool_listener = CMAPListener() + server_listener = ServerAndTopologyEventListener() + # Create a new client, to avoid interference from pooled sessions. + client_options = self.parse_client_options(test["clientOptions"]) + use_multi_mongos = test["useMultipleMongoses"] + host = None + if use_multi_mongos: + if client_context.load_balancer: + host = client_context.MULTI_MONGOS_LB_URI + elif client_context.is_mongos: + host = client_context.mongos_seeds() + client = self.rs_client( + h=host, event_listeners=[listener, pool_listener, server_listener], **client_options + ) + self.scenario_client = client + self.listener = listener + self.pool_listener = pool_listener + self.server_listener = server_listener + + # Create session0 and session1. + sessions = {} + session_ids = {} + for i in range(2): + # Don't attempt to create sessions if they are not supported by + # the running server version. + if not client_context.sessions_enabled: + break + session_name = "session%d" % i + opts = camel_to_snake_args(test["sessionOptions"][session_name]) + if "default_transaction_options" in opts: + txn_opts = self.parse_options(opts["default_transaction_options"]) + txn_opts = client_session.TransactionOptions(**txn_opts) + opts["default_transaction_options"] = txn_opts + + s = client.start_session(**dict(opts)) + + sessions[session_name] = s + # Store lsid so we can access it after end_session, in check_events. + session_ids[session_name] = s.session_id + + self.addCleanup(end_sessions, sessions) + + collection = client[database_name][collection_name] + self.run_test_ops(sessions, collection, test) + + end_sessions(sessions) + + self.check_events(test, listener, session_ids) + + # Disable fail points. 
+ if "failPoint" in test: + fp = test["failPoint"] + self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"}) + + # Assert final state is expected. + outcome = test["outcome"] + expected_c = outcome.get("collection") + if expected_c is not None: + outcome_coll_name = self.get_outcome_coll_name(outcome, collection) + + # Read from the primary with local read concern to ensure causal + # consistency. + outcome_coll = client_context.client[collection.database.name].get_collection( + outcome_coll_name, + read_preference=ReadPreference.PRIMARY, + read_concern=ReadConcern("local"), + ) + actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() + + # The expected data needs to be the left hand side here otherwise + # CompareType(Binary) doesn't work. + self.assertEqual(wrap_types(expected_c["data"]), actual_data) + + +def expect_any_error(op): + if isinstance(op, dict): + return op.get("error") + + return False + + +def expect_error_message(expected_result): + if isinstance(expected_result, dict): + return isinstance(expected_result["errorContains"], str) + + return False + + +def expect_error_code(expected_result): + if isinstance(expected_result, dict): + return expected_result["errorCodeName"] + + return False + + +def expect_error_labels_contain(expected_result): + if isinstance(expected_result, dict): + return expected_result["errorLabelsContain"] + + return False + + +def expect_error_labels_omit(expected_result): + if isinstance(expected_result, dict): + return expected_result["errorLabelsOmit"] + + return False + + +def expect_timeout_error(expected_result): + if isinstance(expected_result, dict): + return expected_result["isTimeoutError"] + + return False + + +def expect_error(op): + expected_result = op.get("result") + return ( + expect_any_error(op) + or expect_error_message(expected_result) + or expect_error_code(expected_result) + or expect_error_labels_contain(expected_result) + or expect_error_labels_omit(expected_result) + 
or expect_timeout_error(expected_result) + ) + + +def end_sessions(sessions): + for s in sessions.values(): + # Aborts the transaction if it's open. + s.end_session() + + +def decode_raw(val): + """Decode RawBSONDocuments in the given container.""" + if isinstance(val, (list, abc.Mapping)): + return decode(encode({"v": val}))["v"] + return val + + +TYPES = { + "binData": Binary, + "long": Int64, + "int": int, + "string": str, + "objectId": ObjectId, + "object": dict, + "array": list, +} + + +def wrap_types(val): + """Support $$type assertion in command results.""" + if isinstance(val, list): + return [wrap_types(v) for v in val] + if isinstance(val, abc.Mapping): + typ = val.get("$$type") + if typ: + if isinstance(typ, str): + types = TYPES[typ] + else: + types = tuple(TYPES[t] for t in typ) + return CompareType(types) + d = {} + for key in val: + d[key] = wrap_types(val[key]) + return d + return val From 77182fee428631fbb38be182064f6d9ad344a339 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 9 Mar 2026 11:01:12 -0700 Subject: [PATCH 5/5] PYTHON-5114 Revert spec runner changes --- test/asynchronous/utils_spec_runner.py | 9 +++------ test/utils_spec_runner.py | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f099eee12c..63e7e9e150 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -621,14 +621,11 @@ async def setup_scenario(self, scenario_def): async def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions after each test with transactions to prevent an open + # Kill all sessions before and after each test to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. 
- for op in test["operations"]: - name = op["name"] - if name == "startTransaction" or name == "withTransaction": - self.addAsyncCleanup(self.kill_all_sessions) - break + await self.kill_all_sessions() + self.addAsyncCleanup(self.kill_all_sessions) await self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 34e1c95ef2..9bf155e8f3 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -621,14 +621,11 @@ def setup_scenario(self, scenario_def): def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) - # Kill all sessions after each test with transactions to prevent an open + # Kill all sessions before and after each test to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. - for op in test["operations"]: - name = op["name"] - if name == "startTransaction" or name == "withTransaction": - self.addCleanup(self.kill_all_sessions) - break + self.kill_all_sessions() + self.addCleanup(self.kill_all_sessions) self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def)