From 40c18bbde7d21367da1ffd58d5b101c805037e66 Mon Sep 17 00:00:00 2001 From: Ben McKerry <110857332+bmckerry@users.noreply.github.com> Date: Fri, 5 Jun 2026 11:38:26 -0400 Subject: [PATCH] fix(TaskProducer): bounded queue of pending futures --- clients/python/src/taskbroker_client/constants.py | 7 +++++++ .../src/taskbroker_client/worker/producer.py | 15 ++++++++++++--- clients/python/tests/worker/test_producer.py | 7 +++++++ clients/python/tests/worker/test_worker.py | 2 +- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/clients/python/src/taskbroker_client/constants.py b/clients/python/src/taskbroker_client/constants.py index ce79b934..66f113d5 100644 --- a/clients/python/src/taskbroker_client/constants.py +++ b/clients/python/src/taskbroker_client/constants.py @@ -70,6 +70,13 @@ to drain pending produce futures on shutdown before sending SIGKILL. """ +TASK_PRODUCER_MAX_PENDING_FUTURES = 10_000 +""" +Maximum number of pending futures that can be in the TaskProducer module's +`_pending_futures` list. This list is a global, so is shared between all instances +of TaskProducer. +""" + class CompressionType(Enum): """ diff --git a/clients/python/src/taskbroker_client/worker/producer.py b/clients/python/src/taskbroker_client/worker/producer.py index 2e741d9b..d1e0a3b3 100644 --- a/clients/python/src/taskbroker_client/worker/producer.py +++ b/clients/python/src/taskbroker_client/worker/producer.py @@ -1,3 +1,4 @@ +from collections import deque from collections.abc import Callable from concurrent.futures import Future from typing import Any, Sequence @@ -6,11 +7,16 @@ from arroyo.backends.kafka import KafkaPayload from arroyo.types import BrokerValue, Topic +from taskbroker_client.constants import TASK_PRODUCER_MAX_PENDING_FUTURES from taskbroker_client.types import ProducerProtocol # This is global as TaskWorker needs to be able to call TaskProducer.collect_futures() # without a reference to a task's specific instance of TaskProducer. -_pending_futures: set[ProducerFuture[BrokerValue[KafkaPayload]]] = set() +# Has a max_len to prevent unbounded future growth if TaskProducer.collect_futures() +# is never called. +_pending_futures: deque[ProducerFuture[BrokerValue[KafkaPayload]]] = deque( + maxlen=TASK_PRODUCER_MAX_PENDING_FUTURES +) class TaskProducer: @@ -21,6 +27,9 @@ class TaskProducer: producer futures tracked by TaskProducer, and will only register the task activation as a success if all producer futures from that activation were successful. Otherwise, the activation will be retried. + + Args: + producer_factory: Callable that returns a producer object. """ def __init__(self, producer_factory: Callable[[], ProducerProtocol]) -> None: @@ -33,13 +42,13 @@ def _get(self) -> ProducerProtocol: return self._inner_producer def track_future(self, future: ProducerFuture[BrokerValue[KafkaPayload]]) -> None: - _pending_futures.add(future) + _pending_futures.append(future) @staticmethod def collect_futures() -> set[ProducerFuture[BrokerValue[KafkaPayload]]]: futures = _pending_futures.copy() _pending_futures.clear() - return futures + return set(futures) def produce( self, diff --git a/clients/python/tests/worker/test_producer.py b/clients/python/tests/worker/test_producer.py index 10a16a1a..c3cb7ec6 100644 --- a/clients/python/tests/worker/test_producer.py +++ b/clients/python/tests/worker/test_producer.py @@ -80,3 +80,10 @@ def callback(future: Future[BrokerValue[KafkaPayload]]) -> None: with pytest.raises(RuntimeError, match="SimpleProducerFuture"): producer.produce(Topic("test"), make_kafka_payload(), callbacks=[callback]) + + +def test_pending_futures_max_len() -> None: + producer = TaskProducer(partial(get_dummy_producer, use_simple_futures=True)) + for _ in range(10001): + producer.produce(Topic("test"), make_kafka_payload()) + assert len(_pending_futures) == 10000 diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 44e4e20a..bf583bce 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -1452,7 +1452,7 @@ def test_child_process_clears_pending_futures_when_task_fails( ) -> None: leftover_future: Future[BrokerValue[KafkaPayload]] = Future() leftover_future.set_result(_make_broker_value()) - _pending_futures.add(leftover_future) + _pending_futures.append(leftover_future) assert len(_pending_futures) == 1 todo: queue.Queue[InflightTaskActivation] = queue.Queue()