From baeae2838d1f9a2f0dd28ed3767b410749bf96a2 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 7 May 2026 01:58:39 -0400 Subject: [PATCH] cluster: add control-connection query fallback Add an opt-in control-connection fallback for application queries when the driver cannot populate normal node pools, which happens in deployments that expose the cluster through a non-broadcast IP address such as a TCP proxy or a node public IP. In that mode the driver can still execute queries over the single control connection, but throughput is poor and connection churn increases the chance of request errors. This option is intentionally disabled by default and should not be used in production. Also propagate keyspace updates on the fallback path so USE keeps the control connection in sync. Tests: - tests/unit/test_cluster.py::ClusterTest::test_set_keyspace_for_all_pools_reports_all_errors - tests/unit/test_response_future.py::ResponseFutureTests::test_control_connection_fallback_updates_connection_keyspace --- cassandra/cluster.py | 233 +++++++++++++-- docs/api/cassandra/cluster.rst | 5 + .../integration/cqlengine/model/test_model.py | 10 +- tests/integration/standard/conftest.py | 1 + .../test_control_connection_query_fallback.py | 115 +++++++ tests/unit/test_cluster.py | 77 ++++- tests/unit/test_response_future.py | 281 +++++++++++++++++- 7 files changed, 689 insertions(+), 33 deletions(-) create mode 100644 tests/integration/standard/test_control_connection_query_fallback.py diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 483843c2a6..1181c6f686 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -28,6 +28,7 @@ from copy import copy from functools import partial, reduce, wraps from itertools import groupby, count, chain +import enum import json import logging from typing import Any, Dict, Optional, Union, Tuple @@ -514,8 +515,9 @@ def __init__(self, load_balancing_policy=None, retry_policy=None, class ProfileManager(object): - def 
__init__(self): + def __init__(self, pools_allowed: bool=True): self.profiles = dict() + self.pools_allowed = pools_allowed def _profiles_without_explicit_lbps(self): names = (profile_name for @@ -527,6 +529,8 @@ def _profiles_without_explicit_lbps(self): ) def distance(self, host): + if not self.pools_allowed: + return HostDistance.IGNORED distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \ HostDistance.LOCAL if HostDistance.LOCAL in distances else \ @@ -542,10 +546,14 @@ def check_supported(self): p.load_balancing_policy.check_supported() def on_up(self, host): + if not self.pools_allowed: + return for p in self.profiles.values(): p.load_balancing_policy.on_up(host) def on_down(self, host): + if not self.pools_allowed: + return for p in self.profiles.values(): p.load_balancing_policy.on_down(host) @@ -619,6 +627,31 @@ class _ConfigMode(object): PROFILES = 2 +class ControlConnectionQueryFallback(enum.Enum): + """ + Controls how application queries use the control connection when node pools + are unavailable. + + ``Disabled`` requires a usable node pool for application queries. If the + driver cannot establish one during session startup, it raises + :class:`NoHostAvailable`. + + ``Fallback`` still attempts to create node pools, but allows application + queries to fall back to the control connection when no usable node pool is + available. Session startup is allowed to proceed even if the initial pool + attempts all fail. + + ``SkipPoolCreation`` disables node-pool creation for the session and uses + the control-connection fallback path for application queries. + + The fallback path is not used for requests targeted to an explicit host. + """ + + Disabled = "Disabled" + Fallback = "Fallback" + SkipPoolCreation = "SkipPoolCreation" + + class Cluster(object): """ The main class to use when interacting with a Cassandra cluster. 
@@ -939,6 +972,16 @@ def default_retry_policy(self, policy): If set to :const:`None`, there will be no timeout for these queries. """ + allow_control_connection_query_fallback: ControlConnectionQueryFallback = ControlConnectionQueryFallback.Disabled + """ + Controls whether application queries may fall back to the control connection. + + ``Disabled`` keeps the old behavior. + ``Fallback`` enables control-connection fallback when no usable node pools exist. + ``SkipPoolCreation`` skips node-pool creation and uses the control connection fallback path. + This fallback is still not used for requests targeted to an explicit host. + """ + idle_heartbeat_interval = 30 """ Interval, in seconds, on which to heartbeat idle connections. This helps @@ -1225,7 +1268,8 @@ def __init__(self, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, application_info:Optional[ApplicationInfoBase]=None, - client_routes_config:Optional[ClientRoutesConfig]=None + client_routes_config:Optional[ClientRoutesConfig]=None, + allow_control_connection_query_fallback:Optional[ControlConnectionQueryFallback]=ControlConnectionQueryFallback.Disabled ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1243,6 +1287,10 @@ def __init__(self, if port < 1 or port > 65535: raise ValueError("Invalid port number (%s) (1-65535)" % port) + if not isinstance(allow_control_connection_query_fallback, ControlConnectionQueryFallback): + raise TypeError( + "allow_control_connection_query_fallback must be a ControlConnectionQueryFallback value") + if connection_class is not None: self.connection_class = connection_class @@ -1404,7 +1452,8 @@ def __init__(self, else: self.timestamp_generator = MonotonicTimestampGenerator() - self.profile_manager = ProfileManager() + self.profile_manager = ProfileManager( + pools_allowed=allow_control_connection_query_fallback != ControlConnectionQueryFallback.SkipPoolCreation) 
self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( self.load_balancing_policy, self.default_retry_policy, @@ -1473,6 +1522,7 @@ def __init__(self, self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait self.control_connection_timeout = control_connection_timeout + self.allow_control_connection_query_fallback = allow_control_connection_query_fallback self.metadata_request_timeout = self.control_connection_timeout if metadata_request_timeout is None else metadata_request_timeout self.idle_heartbeat_interval = idle_heartbeat_interval self.idle_heartbeat_timeout = idle_heartbeat_timeout @@ -1815,7 +1865,8 @@ def get_all_pools(self): return pools def is_shard_aware(self): - return bool(self.get_all_pools()[0].host.sharding_info) + pools = self.get_all_pools() + return bool(pools and pools[0].host.sharding_info) def shard_aware_stats(self): if self.is_shard_aware(): @@ -1920,7 +1971,7 @@ def on_up(self, host): """ Intended for internal use only. """ - if self.is_shutdown: + if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation: return log.debug("Waiting to acquire lock for handling up status of node %s", host) @@ -2028,7 +2079,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ Intended for internal use only. 
""" - if self.is_shutdown: + if self.is_shutdown or self.allow_control_connection_query_fallback == ControlConnectionQueryFallback.SkipPoolCreation: return with host.lock: @@ -2633,20 +2684,24 @@ def __init__(self, cluster, hosts, keyspace=None): # create connection pools in parallel self._initial_connect_futures = set() - for host in hosts: - future = self.add_or_renew_pool(host, is_host_addition=False) - if future: - self._initial_connect_futures.add(future) - - futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) - while futures.not_done and not any(f.result() for f in futures.done): - futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) - - if not any(f.result() for f in self._initial_connect_futures): - msg = "Unable to connect to any servers" - if self.keyspace: - msg += " using keyspace '%s'" % self.keyspace - raise NoHostAvailable(msg, [h.address for h in hosts]) + fallback_mode = self.cluster.allow_control_connection_query_fallback + if fallback_mode is not ControlConnectionQueryFallback.SkipPoolCreation: + for host in hosts: + future = self.add_or_renew_pool(host, is_host_addition=False) + if future: + self._initial_connect_futures.add(future) + + futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + while futures.not_done and not any(f.result() for f in futures.done): + futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) + + # Only Disabled requires an initial pool to come up. + if not any(f.result() for f in self._initial_connect_futures) and \ + fallback_mode is ControlConnectionQueryFallback.Disabled: + msg = "Unable to connect to any servers" + if self.keyspace: + msg += " using keyspace '%s'" % self.keyspace + raise NoHostAvailable(msg, [h.address for h in hosts]) self.session_id = uuid.uuid4() @@ -3245,6 +3300,9 @@ def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. 
""" + if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation: + return None + distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None @@ -3315,6 +3373,9 @@ def update_created_pools(self): For internal use only. """ + if self.cluster.allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation: + return set() + futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) @@ -4650,6 +4711,7 @@ class ResponseFuture(object): _spec_execution_plan = NoSpeculativeExecutionPlan() _continuous_paging_session = None _host = None + _control_connection_query_attempted = False _TABLET_ROUTING_CTYPE = None _warned_timeout = False @@ -4670,6 +4732,7 @@ def __init__(self, session, message, query, timeout, metrics=None, prepared_stat self._callback_lock = Lock() self._start_time = start_time or time.time() self._host = host + self._control_connection_query_attempted = False self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan self._make_query_plan() self._event = Event() @@ -4748,11 +4811,22 @@ def _on_timeout(self, _attempts=0): self._connection.orphaned_threshold_reached = True pool.return_connection(self._connection, stream_was_orphaned=True) + elif self._connection.is_control_connection: + with self._connection.lock: + self._connection.orphaned_request_ids.add(self._req_id) + if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold: + self._connection.orphaned_threshold_reached = True errors = self._errors if not errors: if self.is_schema_agreed: - key = str(self._current_host.endpoint) if self._current_host else 'no host queried before timeout' + if self._current_host is None: + key = 'no host queried before timeout' + elif self._connection is not None and self._connection.is_control_connection: + control_host = 
self.session.cluster.get_control_connection_host() + key = str(control_host.endpoint) if control_host is not None else str(self._connection.endpoint) + else: + key = str(self._current_host.endpoint) errors = {key: "Client request timeout. See Session.execute[_async](timeout)"} else: connection = self.session.cluster.control_connection._connection @@ -4810,14 +4884,110 @@ def send_request(self, error_no_hosts=True): self._on_timeout() return True if error_no_hosts: + if self._fallback_to_control_connection(): + req_id = self._query_control_connection() + if req_id is not None: + self._req_id = req_id + return True + self._set_final_exception(NoHostAvailable( "Unable to complete the operation against any hosts", self._errors)) return False + def _has_usable_node_pool(self): + try: + pools = tuple(self.session._pools.values()) + except (AttributeError, TypeError): + return False + + return any(pool and not pool.is_shutdown for pool in pools) + + def _fallback_to_control_connection(self): + fallback_mode = self.session.cluster.allow_control_connection_query_fallback + if fallback_mode is ControlConnectionQueryFallback.Disabled: + return False + if self._host or self._control_connection_query_attempted: + return False + if fallback_mode is ControlConnectionQueryFallback.SkipPoolCreation: + return True + return not self._has_usable_node_pool() + + def _borrow_control_connection(self, connection): + with connection.lock: + if connection.in_flight >= connection.max_request_id: + raise NoConnectionsAvailable("All request IDs are currently in use") + connection.in_flight += 1 + return connection.get_request_id() + + def _release_control_connection_request(self, connection, request_id): + with connection.lock: + connection.in_flight -= 1 + connection.request_ids.append(request_id) + connection._requests.pop(request_id, None) + + def _handle_control_connection_response(self, connection, cb, response): + with connection.lock: + connection.in_flight -= 1 + cb(response) + + def 
_query_control_connection(self, message=None, cb=None, connection=None, host=None): + self._control_connection_query_attempted = True + + if message is None: + message = self.message + + if connection is None: + control_connection = self.session.cluster.control_connection + connection = control_connection._connection if control_connection else None + if not connection: + self._errors['control connection'] = ConnectionException("Control connection is not connected") + return None + + if host is None: + host = self.session.cluster.get_control_connection_host() or connection.endpoint + self._current_host = host + + request_id = None + request_sent = False + try: + request_id = self._borrow_control_connection(connection) + self._connection = connection + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + if cb is None: + cb = partial(self._set_result, host, connection, None) + cb = partial(self._handle_control_connection_response, connection, cb) + + log.debug("No usable node pools; falling back to control connection for host %s", host) + self.request_encoded_size = connection.send_msg(message, request_id, cb=cb, + encoder=self._protocol_handler.encode_message, + decoder=self._protocol_handler.decode_message, + result_metadata=result_meta) + request_sent = True + self.attempted_hosts.append(host) + return request_id + except NoConnectionsAvailable as exc: + log.debug("Control connection is at capacity") + self._errors[host] = exc + except ConnectionBusy as exc: + log.debug("Control connection is busy") + self._errors[host] = exc + except Exception as exc: + log.debug("Error querying control connection", exc_info=True) + self._errors[host] = exc + if self._metrics is not None: + self._metrics.on_connection_error() + finally: + if request_id is not None and not request_sent: + self._release_control_connection_request(connection, request_id) + + return None + def _query(self, host, message=None, cb=None): if message is None: message 
= self.message + self._control_connection_query_attempted = False + pool = self.session._pools.get(host) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") @@ -4928,12 +5098,17 @@ def start_fetching_next_page(self): self._event.clear() self._final_result = _NOT_SET self._final_exception = None + self._control_connection_query_attempted = False self._start_timer() self.send_request() def _reprepare(self, prepare_message, host, connection, pool): cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool) - request_id = self._query(host, prepare_message, cb=cb) + if pool is None and connection is not None and connection.is_control_connection: + request_id = self._query_control_connection(prepare_message, cb=cb, + connection=connection, host=host) + else: + request_id = self._query(host, prepare_message, cb=cb) if request_id is None: # try to submit the original prepared statement on some other host self.send_request() @@ -4972,6 +5147,8 @@ def _set_result(self, host, connection, pool, response): if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) + if connection is not None: + connection.keyspace = response.new_keyspace # since we're running on the event loop thread, we need to # use a non-blocking method for setting the keyspace on # all connections in this session, otherwise the event @@ -5148,10 +5325,13 @@ def _execute_after_prepare(self, host, connection, pool, response): new_metadata_id = response.result_metadata_id if new_metadata_id is not None: self.prepared_statement.result_metadata_id = new_metadata_id - + # use self._query to re-use the same host and # at the same time properly borrow the connection - request_id = self._query(host) + if pool is None and connection is not None and connection.is_control_connection: + request_id = self._query_control_connection(connection=connection, host=host) + else: + request_id 
= self._query(host) if request_id is None: # this host errored out, move on to the next self.send_request() @@ -5264,6 +5444,11 @@ def _retry_task(self, reuse_connection, host): # to retry the operation return + if self._control_connection_query_attempted: + self._control_connection_query_attempted = False + self.send_request() + return + if reuse_connection and self._query(host) is not None: return diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index de8518d271..44b7b63f67 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -48,6 +48,8 @@ Clusters and Sessions .. autoattribute:: control_connection_timeout + .. autoattribute:: allow_control_connection_query_fallback + .. autoattribute:: idle_heartbeat_interval .. autoattribute:: idle_heartbeat_timeout @@ -106,6 +108,9 @@ Clusters and Sessions .. automethod:: set_meta_refresh_enabled +.. autoclass:: ControlConnectionQueryFallback + :members: + .. autoclass:: ExecutionProfile (load_balancing_policy=, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=, speculative_execution_policy=None) :members: :exclude-members: consistency_level diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py index cafe6ae9c9..98d71993fd 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -259,10 +259,8 @@ class SensitiveModel(Model): rows[-1] rows[-1:] - # ignore DeprecationWarning('The loop argument is deprecated since Python 3.8, and scheduled for removal in Python 3.10.') - relevant_warnings = [warn for warn in w if "The loop argument is deprecated" not in str(warn.message)] + warning_messages = [str(warn.message) for warn in w] - assert "__table_name_case_sensitive__ will be removed in 4.0." 
in str(relevant_warnings[0].message) - assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[1].message) - assert "ModelQuerySet indexing with negative indices support will be removed in 4.0." in str(relevant_warnings[2].message) - assert "ModelQuerySet slicing with negative indices support will be removed in 4.0." in str(relevant_warnings[3].message) + assert sum("__table_name_case_sensitive__ will be removed in 4.0." in message for message in warning_messages) == 2 + assert sum("ModelQuerySet indexing with negative indices support will be removed in 4.0." in message for message in warning_messages) == 1 + assert sum("ModelQuerySet slicing with negative indices support will be removed in 4.0." in message for message in warning_messages) == 1 diff --git a/tests/integration/standard/conftest.py b/tests/integration/standard/conftest.py index 3adaf371b0..9934cfcbbb 100644 --- a/tests/integration/standard/conftest.py +++ b/tests/integration/standard/conftest.py @@ -37,6 +37,7 @@ "test_ip_change": 4, "test_authentication": 4, "test_authentication_misconfiguration": 4, + "test_control_connection_query_fallback": 4, "test_custom_cluster": 4, "test_query": 4, # Group 5: tablets (destructive — decommissions a node) diff --git a/tests/integration/standard/test_control_connection_query_fallback.py b/tests/integration/standard/test_control_connection_query_fallback.py new file mode 100644 index 0000000000..e64763a72c --- /dev/null +++ b/tests/integration/standard/test_control_connection_query_fallback.py @@ -0,0 +1,115 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from cassandra.cluster import ControlConnectionQueryFallback, NoHostAvailable + +from tests.integration import USE_CASS_EXTERNAL, TestCluster, local, remove_cluster, use_cluster + + +_CLUSTER_NAME = "control_connection_query_fallback" +_UNREACHABLE_BROADCAST_RPC_ADDRESS = "127.255.255.1" + + +def setup_module(): + if USE_CASS_EXTERNAL: + return + + remove_cluster() + + ccm_cluster = use_cluster(_CLUSTER_NAME, [1], start=False) + ccm_cluster.nodes["node1"].set_configuration_options(values={ + "broadcast_rpc_address": _UNREACHABLE_BROADCAST_RPC_ADDRESS, + }) + ccm_cluster.start(wait_for_binary_proto=True, wait_other_notice=True) + + +def teardown_module(): + if USE_CASS_EXTERNAL: + return + + remove_cluster() + + +@local +class ControlConnectionQueryFallbackIntegrationTests(unittest.TestCase): + + def setUp(self): + self.cluster = None + + def tearDown(self): + if self.cluster is not None: + self.cluster.shutdown() + + def _assert_unreachable_broadcast_rpc_metadata(self): + hosts = self.cluster.metadata.all_hosts() + assert len(hosts) == 1 + + host = hosts[0] + assert host.broadcast_rpc_address == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert host.endpoint.address == _UNREACHABLE_BROADCAST_RPC_ADDRESS + return host + + def test_disabled_raises_when_broadcast_rpc_address_is_unreachable(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Disabled, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + with pytest.raises(NoHostAvailable): + 
self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert self.cluster.control_connection._connection is not None + assert self.cluster.get_all_pools() == [] + + def test_fallback_executes_queries_when_broadcast_rpc_address_is_unreachable(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + session = self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert session._initial_connect_futures + assert list(session.get_pools()) == [] + + row = session.execute( + "SELECT release_version, rpc_address FROM system.local WHERE key='local'").one() + assert str(row.rpc_address) == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert row.release_version + + def test_no_node_pool_fallback_executes_queries_without_creating_pools(self): + self.cluster = TestCluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation, + connect_timeout=1, + monitor_reporting_enabled=False, + ) + + session = self.cluster.connect() + + self._assert_unreachable_broadcast_rpc_metadata() + assert session._initial_connect_futures == set() + assert list(session.get_pools()) == [] + + row = session.execute( + "SELECT release_version, rpc_address FROM system.local WHERE key='local'").one() + assert str(row.rpc_address) == _UNREACHABLE_BROADCAST_RPC_ADDRESS + assert row.release_version diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index b6f2da5372..3d55bc1860 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -13,6 +13,7 @@ # limitations under the License. 
import unittest +from concurrent.futures import Future import logging import socket from types import SimpleNamespace @@ -22,9 +23,9 @@ from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion -from cassandra.cluster import _Scheduler, Session, Cluster, ResultSet, SchemaAgreementScope, default_lbp_factory, \ +from cassandra.cluster import _Scheduler, Session, Cluster, ResultSet, SchemaAgreementScope, ControlConnectionQueryFallback, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.connection import ConnectionBusy +from cassandra.connection import ConnectionBusy, ConnectionException from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory @@ -186,6 +187,52 @@ def test_port_range(self): with pytest.raises(ValueError): cluster = Cluster(contact_points=['127.0.0.1'], port=invalid_port) + def test_control_connection_query_fallback_modes(self): + assert Cluster().allow_control_connection_query_fallback is ControlConnectionQueryFallback.Disabled + with pytest.raises(TypeError): + Cluster(allow_control_connection_query_fallback=False) + with pytest.raises(TypeError): + Cluster(allow_control_connection_query_fallback=True) + assert ( + Cluster(allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback) + .allow_control_connection_query_fallback + is ControlConnectionQueryFallback.Fallback + ) + assert Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation + 
).allow_control_connection_query_fallback is ControlConnectionQueryFallback.SkipPoolCreation + + def test_control_connection_query_fallback_no_node_pool_mode_skips_pool_creation(self): + cluster = Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.SkipPoolCreation, + monitor_reporting_enabled=False, + ) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + with patch.object(Session, "add_or_renew_pool") as mocked_add_or_renew_pool: + session = Session(cluster, [host]) + + mocked_add_or_renew_pool.assert_not_called() + assert session._initial_connect_futures == set() + assert session._pools == {} + assert session.update_created_pools() == set() + + def test_control_connection_query_fallback_fallback_tolerates_empty_initial_pools(self): + cluster = Cluster( + allow_control_connection_query_fallback=ControlConnectionQueryFallback.Fallback, + monitor_reporting_enabled=False, + ) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + future = Future() + future.set_result(False) + + with patch.object(Session, "add_or_renew_pool", return_value=future) as mocked_add_or_renew_pool: + session = Session(cluster, [host]) + + mocked_add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + assert session._initial_connect_futures == {future} + assert session._pools == {} + def test_compression_autodisabled_without_libraries(self): with patch.dict('cassandra.cluster.locally_supported_compressions', {}, clear=True): with patch('cassandra.cluster.log') as patched_logger: @@ -551,6 +598,32 @@ def test_wait_for_schema_agreement_rejects_unknown_scope(self, *_): with pytest.raises(ValueError): session.wait_for_schema_agreement(wait_time=1, scope='planet') + @mock_session_pools + def test_set_keyspace_for_all_pools_reports_all_errors(self, *_): + cluster = Cluster() + session = Session( + cluster, + [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4())], + ) + + pool1 = 
Mock(host='host1') + pool2 = Mock(host='host2') + keyspace_error = ConnectionException("boom") + + pool1._set_keyspace_for_all_conns.side_effect = ( + lambda keyspace, callback: callback(pool1, [keyspace_error]) + ) + pool2._set_keyspace_for_all_conns.side_effect = ( + lambda keyspace, callback: callback(pool2, []) + ) + session._pools = {'host1': pool1, 'host2': pool2} + + callback = Mock() + session._set_keyspace_for_all_pools('ks', callback) + + callback.assert_called_once() + assert callback.call_args.args[0] == {'host1': [keyspace_error]} + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index dd7fa75045..9673b0d634 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -19,7 +19,7 @@ from unittest.mock import Mock, MagicMock, ANY from cassandra import ConsistencyLevel, Unavailable, SchemaTargetType, SchemaChangeType, OperationTimedOut -from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion +from cassandra.cluster import Session, ResponseFuture, NoHostAvailable, ProtocolVersion, ControlConnectionQueryFallback from cassandra.connection import Connection, ConnectionException from cassandra.protocol import (ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableErrorMessage, ResultMessage, QueryMessage, @@ -41,6 +41,7 @@ def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] s.cluster.control_connection._tablets_routing_v1 = False + s.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Disabled return s def make_pool(self): @@ -49,6 +50,22 @@ def make_pool(self): pool.borrow_connection.return_value = [Mock(), Mock()] return pool + def make_control_connection(self): + connection = Mock(spec=Connection) + connection.endpoint = 'control-host' + connection.lock = RLock() + 
+        connection.in_flight = 0
+        connection.max_request_id = 100
+        connection.request_ids = deque()
+        connection._requests = {}
+        connection.orphaned_request_ids = set()
+        connection.orphaned_threshold = 75
+        connection.orphaned_threshold_reached = False
+        connection.is_control_connection = True
+        connection.get_request_id.return_value = 7  # stream id the fallback tests assert on below
+        connection.send_msg.return_value = 128
+        return connection
+
     def make_session(self):
         session = self.make_basic_session()
         session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']
@@ -391,6 +408,268 @@ def test_all_pools_shutdown(self):
         with pytest.raises(NoHostAvailable):
             rf.result()
 
+    def test_control_connection_fallback_disabled_by_default(self):  # fallback must be opt-in: no send over control connection
+        session = self.make_basic_session()
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        session._pools = {}
+        connection = self.make_control_connection()
+        session.cluster.control_connection._connection = connection
+
+        rf = self.make_response_future(session)
+        rf.send_request()
+
+        connection.send_msg.assert_not_called()
+        with pytest.raises(NoHostAvailable):
+            rf.result()
+
+    def test_control_connection_fallback_updates_connection_keyspace(self):  # SET_KEYSPACE result must sync connection and session keyspace
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        session._pools = {}
+
+        def set_keyspace_for_all_pools(keyspace, callback):  # mimic the real callback contract: set then report no errors
+            session.keyspace = keyspace
+            callback({})
+
+        session._set_keyspace_for_all_pools.side_effect = set_keyspace_for_all_pools
+
+        connection = self.make_control_connection()
+        connection.keyspace = 'oldks'
+        session.cluster.control_connection._connection = connection
+        control_host = Mock(endpoint=connection.endpoint)
+        session.cluster.get_control_connection_host.return_value = control_host
+
+        rf = self.make_response_future(session)
+        assert rf.send_request()
+
+        result = Mock(spec=ResultMessage, kind=RESULT_KIND_SET_KEYSPACE, new_keyspace='newks')
+        connection.send_msg.call_args[1]['cb'](result)
+
+        assert connection.keyspace == 'newks'
+        assert session.keyspace == 'newks'
+        assert rf.result().current_rows == []
+
+    def test_control_connection_fallback_when_no_usable_pools(self):  # SkipPoolCreation: query goes over the control connection
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.SkipPoolCreation
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']
+        session._pools = {}
+        connection = self.make_control_connection()
+        session.cluster.control_connection._connection = connection
+        control_host = Mock(endpoint=connection.endpoint)
+        session.cluster.get_control_connection_host.return_value = control_host
+
+        rf = self.make_response_future(session)
+        assert rf.send_request()
+
+        connection.send_msg.assert_called_once_with(
+            rf.message, 7, cb=ANY, encoder=ProtocolHandler.encode_message,
+            decoder=ProtocolHandler.decode_message, result_metadata=[])
+        assert connection.in_flight == 1
+        assert rf.attempted_hosts == [control_host]
+
+        cb = connection.send_msg.call_args[1]['cb']
+        expected_result = (object(), object())
+        cb(self.make_mock_response(expected_result[0], expected_result[1]))
+
+        assert connection.in_flight == 0
+        assert rf.result()[0] == expected_result
+
+    def test_control_connection_fallback_retries_after_server_error(self):  # retry reuses the control connection with a fresh stream id
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        session._pools = {}
+        connection = self.make_control_connection()
+        connection.get_request_id.side_effect = [7, 8]  # first attempt, then retry
+        session.cluster.control_connection._connection = connection
+        control_host = Mock(endpoint=connection.endpoint)
+        session.cluster.get_control_connection_host.return_value = control_host
+
+        rf = self.make_response_future(session)
+        assert rf.send_request()
+
+        first_response = Mock(spec=ServerError, info={})
+        first_response.summary = 'boom'
+        first_response.to_exception.return_value = first_response
+        connection.send_msg.call_args[1]['cb'](first_response)
+
+        rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, False, control_host)
+
+        # The retry decision must come from the future state, not the live connection reference.
+        rf._connection = Mock(is_control_connection=False)
+
+        rf._retry_task(False, control_host)
+
+        assert connection.send_msg.call_count == 2
+        assert connection.send_msg.call_args_list[1][0][0] is rf.message
+        assert connection.send_msg.call_args_list[1][0][1] == 8
+        assert rf.attempted_hosts == [control_host, control_host]
+
+        expected_result = (object(), object())
+        connection.send_msg.call_args_list[1][1]['cb'](
+            self.make_mock_response(expected_result[0], expected_result[1]))
+
+        assert connection.in_flight == 0
+        assert rf.result()[0] == expected_result
+
+    def test_control_connection_fallback_fetches_next_page(self):  # paging continues over the control connection, reusing rf.message
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        session._pools = {}
+        connection = self.make_control_connection()
+        connection.get_request_id.side_effect = [7, 8]  # first page, then next page
+        session.cluster.control_connection._connection = connection
+        control_host = Mock(endpoint=connection.endpoint)
+        session.cluster.get_control_connection_host.return_value = control_host
+
+        rf = self.make_response_future(session)
+        assert rf.send_request()
+
+        first_response = self.make_mock_response(['col'], [(1,)])
+        first_response.paging_state = b'next-page'
+        connection.send_msg.call_args[1]['cb'](first_response)
+
+        assert rf.result().current_rows == [(['col'], [(1,)])]
+        assert rf.has_more_pages
+
+        rf.start_fetching_next_page()
+
+        assert connection.send_msg.call_count == 2
+        assert connection.send_msg.call_args_list[1][0][0] is rf.message
+        assert connection.send_msg.call_args_list[1][0][1] == 8
+        assert rf.message.paging_state == b'next-page'
+
+        second_response = self.make_mock_response(['col'], [(2,)])
+        connection.send_msg.call_args_list[1][1]['cb'](second_response)
+
+        assert connection.in_flight == 0
+        assert rf.result().current_rows == [(['col'], [(2,)])]
+
+    def test_control_connection_fallback_reprepares_prepared_statement(self):  # unprepared -> PREPARE -> re-execute, all on the control connection
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster.protocol_version = ProtocolVersion.V4
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        session._pools = {}
+        session.submit.side_effect = lambda fn, *args, **kwargs: fn(*args, **kwargs)  # run submitted work inline
+
+        query_id = b'a' * 16
+        prepared_statement = Mock(
+            query_id=query_id,
+            query_string="SELECT * FROM foobar",
+            keyspace="FooKeyspace",
+            result_metadata=[],
+            result_metadata_id=None)
+        session.cluster._prepared_statements = {query_id: prepared_statement}
+
+        connection = self.make_control_connection()
+        connection.keyspace = "FooKeyspace"
+        connection.get_request_id.side_effect = [7, 8, 9]  # execute, prepare, re-execute
+        session.cluster.control_connection._connection = connection
+        control_host = Mock(endpoint=connection.endpoint)
+        session.cluster.get_control_connection_host.return_value = control_host
+
+        rf = self.make_response_future(session)
+        rf.prepared_statement = prepared_statement
+        assert rf.send_request()
+
+        missing = Mock(spec=PreparedQueryNotFound, info=query_id)
+        connection.send_msg.call_args_list[0][1]['cb'](missing)
+
+        assert connection.send_msg.call_count == 2
+        prepare_message = connection.send_msg.call_args_list[1][0][0]
+        assert isinstance(prepare_message, PrepareMessage)
+        assert prepare_message.query == "SELECT * FROM foobar"
+        assert connection.send_msg.call_args_list[1][0][1] == 8
+
+        prepared_response = Mock(
+            spec=ResultMessage,
+            kind=RESULT_KIND_PREPARED,
+            query_id=query_id,
+            column_metadata=[],
+            result_metadata_id=None)
+        connection.send_msg.call_args_list[1][1]['cb'](prepared_response)
+
+        assert connection.send_msg.call_count == 3
+        assert connection.send_msg.call_args_list[2][0][0] is rf.message
+        assert connection.send_msg.call_args_list[2][0][1] == 9
+
+        expected_result = (['col'], [(1,)])
+        connection.send_msg.call_args_list[2][1]['cb'](
+            self.make_mock_response(expected_result[0], expected_result[1]))
+
+        assert connection.in_flight == 0
+        assert rf.result()[0] == expected_result
+
+    def test_control_connection_fallback_not_used_when_pool_can_serve(self):  # a present (even exhausted) pool suppresses the fallback
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        pool = Mock(is_shutdown=False)
+        pool.borrow_connection.side_effect = NoConnectionsAvailable()
+        session._pools = {'ip1': pool}
+        connection = self.make_control_connection()
+        session.cluster.control_connection._connection = connection
+
+        rf = self.make_response_future(session)
+        rf.send_request()
+
+        connection.send_msg.assert_not_called()
+        with pytest.raises(NoHostAvailable):
+            rf.result()
+
+    def test_control_connection_fallback_orphans_stream_on_timeout(self):  # timed-out stream is orphaned; in_flight stays held
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1']
+        session._pools = {}
+        connection = self.make_control_connection()
+        session.cluster.control_connection._connection = connection
+
+        def send_msg(message, request_id, cb, **kwargs):  # record the pending request so _on_timeout can orphan it
+            connection._requests[request_id] = (cb, kwargs.get('decoder'), kwargs.get('result_metadata'))
+            return 128
+
+        connection.send_msg.side_effect = send_msg
+
+        rf = self.make_response_future(session)
+        rf.send_request()
+        rf._on_timeout()
+
+        assert 7 in connection.orphaned_request_ids
+        assert connection.in_flight == 1
+        with pytest.raises(OperationTimedOut):
+            rf.result()
+
+    def test_control_connection_fallback_timeout_without_metadata_host_uses_connection_endpoint(self):  # no metadata host -> error keyed by connection endpoint
+        session = self.make_basic_session()
+        session.cluster.allow_control_connection_query_fallback = ControlConnectionQueryFallback.Fallback
+        session.cluster._default_load_balancing_policy.make_query_plan.return_value = []
+        session._pools = {}
+        session.cluster.get_control_connection_host.return_value = None
+        connection = self.make_control_connection()
+        session.cluster.control_connection._connection = connection
+
+        def send_msg(message, request_id, cb, **kwargs):  # record the pending request so _on_timeout can orphan it
+            connection._requests[request_id] = (cb, kwargs.get('decoder'), kwargs.get('result_metadata'))
+            return 128
+
+        connection.send_msg.side_effect = send_msg
+
+        rf = self.make_response_future(session)
+        assert rf.send_request()
+        rf._on_timeout()
+
+        with pytest.raises(OperationTimedOut) as exc_info:
+            rf.result()
+
+        assert exc_info.value.errors == {
+            'control-host': 'Client request timeout. See Session.execute[_async](timeout)'
+        }
+
     def test_first_pool_shutdown(self):
         session = self.make_basic_session()
         session.cluster._default_load_balancing_policy.make_query_plan.return_value = ['ip1', 'ip2']