From d8530e8163d36872a1e45b1876232c2fe9955b2b Mon Sep 17 00:00:00 2001 From: Andrzej Jackowski Date: Wed, 29 Apr 2026 17:11:04 +0200 Subject: [PATCH 1/2] Prevent connection pool replacement race When pool creation races for the same host, a slower attempt can overwrite a pool that another thread already published and close connections with in-flight requests. Capture the previous pool before connection setup, then compare that state under the session lock before publishing the new pool. If another thread changed the pool, discard the stale pool instead of replacing the current one. Keep pool removals behind the same lock so the check observes all writers. Fixes: #317 --- cassandra/cluster.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..f48f36cfcf 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3241,6 +3241,9 @@ def add_or_renew_pool(self, host, is_host_addition): return None def run_add_or_renew_pool(): + with self._lock: + previous = self._pools.get(host) + try: new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: @@ -3256,7 +3259,6 @@ def run_add_or_renew_pool(): host, conn_exc, is_host_addition, expect_host_to_be_down=True) return False - previous = self._pools.get(host) with self._lock: while new_pool._keyspace != self.keyspace: self._lock.release() @@ -3276,7 +3278,20 @@ def callback(pool, errors): self._lock.acquire() return False self._lock.acquire() - self._pools[host] = new_pool + + pool_unchanged = self._pools.get(host) is previous + if not pool_unchanged: + # Another concurrent add_or_renew_pool changed this host + # while we were creating ours. Don't replace the existing + # pool because doing so would kill in-flight queries. + log.debug("Pool for host %s was already replaced by another " + "thread, discarding new pool", host) + else: + self._pools[host] = new_pool + + if not pool_unchanged: + new_pool.shutdown() + return True log.debug("Added pool for host %s to session", host) if previous: @@ -3287,7 +3302,8 @@ def callback(pool, errors): return self.submit(run_add_or_renew_pool) def remove_pool(self, host): - pool = self._pools.pop(host, None) + with self._lock: + pool = self._pools.pop(host, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) From b1ae845c27f77bf4702b0a93371e0e1b0844c153 Mon Sep 17 00:00:00 2001 From: Andrzej Jackowski Date: Wed, 29 Apr 2026 17:11:04 +0200 Subject: [PATCH 2/2] test: cover pool replacement race from #317 Add a deterministic unit test for the case where another thread publishes a pool while a slower add attempt is still constructing its pool. This guards against closing in-flight connections by replacing the pool that should remain current. Refs: #317 --- tests/unit/test_cluster.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index a4f0ebc4d3..a25606c188 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,6 +15,7 @@ import logging import socket +import threading from unittest.mock import patch, Mock import uuid @@ -23,7 +24,7 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.pool import Host +from cassandra.pool import Host, HostConnection from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools @@ -339,6 +340,39 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + +class SessionPoolRaceTest(unittest.TestCase): + def test_concurrent_add_or_renew_pool_no_double_replace(self): + """Reproduces https://github.com/scylladb/python-driver/issues/317.""" + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + session = Session.__new__(Session) + session.submit = lambda fn: Mock(result=lambda timeout=None: fn()) + session.keyspace = None + session._lock = threading.RLock() + session._pools = {} + session._profile_manager = Mock() + session._profile_manager.distance.return_value = HostDistance.LOCAL + + winner_pool = Mock() + created_pools = [] + + def fake_host_connection_init(pool, *_): + pool._keyspace = session.keyspace + pool.shutdown = Mock() + created_pools.append(pool) + log.info("Publishing competing pool while replacement pool is being created") + with session._lock: + session._pools[host] = winner_pool + + with patch.object(HostConnection, '__init__', fake_host_connection_init): + result = session.add_or_renew_pool(host, is_host_addition=True).result() + + assert result is True + assert session._pools[host] is winner_pool + created_pools[0].shutdown.assert_called_once() + winner_pool.shutdown.assert_not_called() + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self):