From 1e3d0c44a728f07ecfd2f954910d4d3ce15bf1ab Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Wed, 10 Jun 2026 15:52:38 -0400 Subject: [PATCH 1/4] feat: Add a BatchPushTaskWorker for batched updates to the broker This worker will read out a batch of updates, and send them to the server all at once using the batched activations service. The default for the TaskWorker is now to handle lists of updates, but the previous workers will still only update one activation at a time. --- clients/python/pyproject.toml | 2 +- clients/python/src/examples/cli.py | 26 ++- .../src/taskbroker_client/worker/__init__.py | 4 +- .../src/taskbroker_client/worker/client.py | 108 ---------- .../taskbroker_client/worker/push_clients.py | 192 ++++++++++++++++++ .../src/taskbroker_client/worker/worker.py | 175 +++++++++++++--- clients/python/tests/worker/test_worker.py | 88 ++++++++ uv.lock | 8 +- 8 files changed, 458 insertions(+), 145 deletions(-) create mode 100644 clients/python/src/taskbroker_client/worker/push_clients.py diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index f8a84ff3..9ee0ba7f 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" dependencies = [ "sentry-arroyo>=2.38.7", "sentry-sdk[http2]>=2.43.0", - "sentry-protos>=0.15.0", + "sentry-protos>=0.26.1", "confluent_kafka>=2.3.0", "cronsim>=2.6", "grpcio>=1.67.1", diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 4c9fdc60..687cba12 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -76,18 +76,36 @@ def scheduler() -> None: @click.option( "--push-mode", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True ) +@click.option( + "--batch-push-mode", help="Whether to run in BATCH PUSH mode.", default=False, is_flag=True +) @click.option( "--grpc-port", help="Port for the gRPC server to listen on.", default=50052, type=int, ) -def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None: - from taskbroker_client.worker import PushTaskWorker, TaskWorker +def worker( + rpc_host: str, concurrency: int, push_mode: bool, batch_push_mode: bool, grpc_port: int +) -> None: + from taskbroker_client.worker import BatchPushTaskWorker, PushTaskWorker, TaskWorker click.echo("Starting worker") - if push_mode: - worker: PushTaskWorker | TaskWorker = PushTaskWorker( + if batch_push_mode: + worker: PushTaskWorker | TaskWorker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service=rpc_host, + max_child_task_count=100, + concurrency=concurrency, + child_tasks_queue_maxsize=concurrency * 2, + result_queue_maxsize=concurrency * 2, + rebalance_after=32, + processing_pool_name="examples", + process_type="forkserver", + grpc_port=grpc_port, + ) + elif push_mode: + worker = PushTaskWorker( app_module="examples.app:app", broker_service=rpc_host, max_child_task_count=100, diff --git a/clients/python/src/taskbroker_client/worker/__init__.py b/clients/python/src/taskbroker_client/worker/__init__.py index d94f2b62..79f70886 100644 --- a/clients/python/src/taskbroker_client/worker/__init__.py +++ b/clients/python/src/taskbroker_client/worker/__init__.py @@ -1,3 +1,3 @@ -from .worker import PushTaskWorker, TaskWorker +from .worker import BatchPushTaskWorker, PushTaskWorker, TaskWorker -__all__ = ("TaskWorker", "PushTaskWorker") +__all__ = ("TaskWorker", "PushTaskWorker", "BatchPushTaskWorker") diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 356a4f75..baf609d6 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -493,111 +493,3 @@ def update_task( receive_timestamp=time.monotonic(), ) return None - - -class PushTaskbrokerClient: - """ - Taskworker RPC client wrapper - - Push brokers are a deployment so they don't need to be connected to individually. There is one service provided - that works for all the brokers. - """ - - def __init__( - self, - service: str, - application: str, - metrics: MetricsBackend, - health_check_settings: HealthCheckSettings | None = None, - rpc_secret: str | None = None, - grpc_config: str | None = None, - ) -> None: - self._application = application - self._service = service - self._rpc_secret = rpc_secret - self._metrics = metrics - - self._grpc_options: list[tuple[str, Any]] = [ - ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) - ] - if grpc_config: - self._grpc_options.append(("grpc.service_config", grpc_config)) - - logger.info( - "taskworker.push_client.start", - extra={"service": service, "options": self._grpc_options}, - ) - - self._stub = self._connect_to_host(service) - - self._health_check_settings = health_check_settings - self._timestamp_since_touch_lock = threading.Lock() - self._timestamp_since_touch = 0.0 - - def _emit_health_check(self) -> None: - if self._health_check_settings is None: - return - - with self._timestamp_since_touch_lock: - cur_time = time.time() - if ( - cur_time - self._timestamp_since_touch - < self._health_check_settings.touch_interval_sec - ): - return - - self._health_check_settings.file_path.touch() - self._metrics.incr( - "taskworker.client.health_check.touched", - ) - self._timestamp_since_touch = cur_time - - def _connect_to_host(self, host: str) -> ConsumerServiceStub: - logger.info("taskworker.push_client.connect", extra={"host": host}) - channel = grpc.insecure_channel(host, options=self._grpc_options) - secrets = parse_rpc_secret_list(self._rpc_secret) - if secrets: - channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) - return ConsumerServiceStub(channel) - - def emit_health_check(self) -> None: - self._emit_health_check() - - def update_task( - self, - processing_result: ProcessingResult, - ) -> None: - """ - Update the status for a given task activation. - """ - self._emit_health_check() - - request = SetTaskStatusRequest( - id=processing_result.task_id, - status=processing_result.status, - fetch_next_task=None, - max_attempts=processing_result.max_attempts, - delay_on_retry=processing_result.delay_on_retry, - ) - - retries = 0 - exception = None - while retries < 3: - try: - with self._metrics.timer( - "taskworker.update_task.rpc", tags={"service": self._service} - ): - self._stub.SetTaskStatus(request) - exception = None - break - except grpc.RpcError as err: - exception = err - self._metrics.incr( - "taskworker.client.rpc_error", - tags={"method": "SetTaskStatus", "status": err.code().name}, - ) - finally: - retries += 1 - - if exception: - raise exception diff --git a/clients/python/src/taskbroker_client/worker/push_clients.py b/clients/python/src/taskbroker_client/worker/push_clients.py new file mode 100644 index 00000000..5a13b6de --- /dev/null +++ b/clients/python/src/taskbroker_client/worker/push_clients.py @@ -0,0 +1,192 @@ +import logging +import threading +import time +from typing import TYPE_CHECKING, Any + +import grpc +from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( + SetBatchActivationStatusRequest, + SetTaskStatusRequest, +) +from sentry_protos.taskbroker.v1.taskbroker_pb2_grpc import ConsumerServiceStub + +from taskbroker_client.metrics import MetricsBackend +from taskbroker_client.types import ProcessingResult +from taskbroker_client.worker.client import ( + MAX_ACTIVATION_SIZE, + HealthCheckSettings, + RequestSignatureInterceptor, + parse_rpc_secret_list, +) + +if TYPE_CHECKING: + ServerInterceptor = grpc.ServerInterceptor[Any, Any] +else: + ServerInterceptor = grpc.ServerInterceptor + +logger = logging.getLogger(__name__) + + +class PushTaskbrokerClient: + """ + Taskworker RPC client wrapper + + Push brokers are a deployment so they don't need to be connected to individually. There is one service provided + that works for all the brokers. + """ + + def __init__( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> None: + self._application = application + self._service = service + self._rpc_secret = rpc_secret + self._metrics = metrics + + self._grpc_options: list[tuple[str, Any]] = [ + ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) + ] + if grpc_config: + self._grpc_options.append(("grpc.service_config", grpc_config)) + + logger.info( + "taskworker.push_client.start", + extra={"service": service, "options": self._grpc_options}, + ) + + self._stub = self._connect_to_host(service) + + self._health_check_settings = health_check_settings + self._timestamp_since_touch_lock = threading.Lock() + self._timestamp_since_touch = 0.0 + + def _emit_health_check(self) -> None: + if self._health_check_settings is None: + return + + with self._timestamp_since_touch_lock: + cur_time = time.time() + if ( + cur_time - self._timestamp_since_touch + < self._health_check_settings.touch_interval_sec + ): + return + + self._health_check_settings.file_path.touch() + self._metrics.incr( + "taskworker.client.health_check.touched", + ) + self._timestamp_since_touch = cur_time + + def _connect_to_host(self, host: str) -> ConsumerServiceStub: + logger.info("taskworker.push_client.connect", extra={"host": host}) + channel = grpc.insecure_channel(host, options=self._grpc_options) + secrets = parse_rpc_secret_list(self._rpc_secret) + if secrets: + channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) + return ConsumerServiceStub(channel) + + def emit_health_check(self) -> None: + self._emit_health_check() + + def update_tasks(self, processing_results: list[ProcessingResult]) -> None: + for processing_result in processing_results: + self._update_task_single(processing_result) + + def _update_task_single( + self, + processing_result: ProcessingResult, + ) -> None: + """ + Update the status for a given task activation. + """ + self._emit_health_check() + + request = SetTaskStatusRequest( + id=processing_result.task_id, + status=processing_result.status, + fetch_next_task=None, + max_attempts=processing_result.max_attempts, + delay_on_retry=processing_result.delay_on_retry, + ) + + retries = 0 + exception = None + while retries < 3: + try: + with self._metrics.timer( + "taskworker.update_task.rpc", tags={"service": self._service} + ): + self._stub.SetTaskStatus(request) + exception = None + break + except grpc.RpcError as err: + exception = err + self._metrics.incr( + "taskworker.client.rpc_error", + tags={"method": "SetTaskStatus", "status": err.code().name}, + ) + finally: + retries += 1 + + if exception: + raise exception + + +class BatchPushTaskbrokerClient(PushTaskbrokerClient): + """ + Taskworker RPC client wrapper + + Push brokers are a deployment so they don't need to be connected to individually. There is one service provided + that works for all the brokers. This client pushes batches of activation updates. + """ + + def update_tasks( + self, + processing_results: list[ProcessingResult], + ) -> None: + """ + Update the status for a given task activation. + """ + self._emit_health_check() + + request = SetBatchActivationStatusRequest( + updates=[ + SetTaskStatusRequest( + id=processing_result.task_id, + status=processing_result.status, + fetch_next_task=None, + max_attempts=processing_result.max_attempts, + delay_on_retry=processing_result.delay_on_retry, + ) + for processing_result in processing_results + ] + ) + + retries = 0 + exception = None + while retries < 3: + try: + with self._metrics.timer( + "taskworker.update_task_batch.rpc", tags={"service": self._service} + ): + self._stub.SetBatchActivationStatus(request) + exception = None + break + except grpc.RpcError as err: + exception = err + self._metrics.incr( + "taskworker.client.rpc_error", + tags={"method": "SetBatchActivationStatus", "status": err.code().name}, + ) + finally: + retries += 1 + + if exception: + raise exception diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index f69ecc4b..45c84d4c 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -31,15 +31,16 @@ MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE, WORKER_CHILD_JOIN_TIMEOUT_SEC, ) +from taskbroker_client.metrics import MetricsBackend from taskbroker_client.types import InflightTaskActivation, ProcessingResult from taskbroker_client.worker.client import ( HealthCheckSettings, HostTemporarilyUnavailable, - PushTaskbrokerClient, RequestSignatureServerInterceptor, TaskbrokerClient, parse_rpc_secret_list, ) +from taskbroker_client.worker.push_clients import BatchPushTaskbrokerClient, PushTaskbrokerClient from taskbroker_client.worker.workerchild import child_process if TYPE_CHECKING: @@ -138,6 +139,7 @@ def __init__( health_check_file_path: str | None = None, health_check_sec_per_touch: float = DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH, grpc_port: int = 50052, + update_in_batches: bool = False, ) -> None: app = import_app(app_module) @@ -153,7 +155,7 @@ def __init__( self.worker_pool = TaskWorkerProcessingPool( app_module=app_module, mp_context=self._mp_context, - send_result_fn=self._send_result, + send_result_fn=self._send_results, max_child_task_count=max_child_task_count, concurrency=concurrency, child_tasks_queue_maxsize=child_tasks_queue_maxsize, @@ -161,11 +163,12 @@ def __init__( processing_pool_name=processing_pool_name, pod_name=pod_name, process_type=process_type, + update_in_batches=update_in_batches, ) logger.info("Running in PUSH mode") - self.client = PushTaskbrokerClient( + self.client = self._create_client( service=broker_service, application=app.name, metrics=app.metrics, @@ -193,12 +196,34 @@ def __init__( self._grpc_port = grpc_port self._grpc_secrets = parse_rpc_secret_list(app.config["rpc_secret"]) - def _send_result( - self, result: ProcessingResult, is_draining: bool = False + def _create_client( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> PushTaskbrokerClient: + return PushTaskbrokerClient( + service=service, + application=application, + metrics=metrics, + health_check_settings=health_check_settings, + rpc_secret=rpc_secret, + grpc_config=grpc_config, + ) + + def _send_results( + self, results: list[ProcessingResult], is_draining: bool = False ) -> InflightTaskActivation | None: """ Send a result to the broker. If the set has failed before, sleep briefly before retrying. """ + assert ( + len(results) == 1 + ), "Only one result can be sent at a time with the regular push client" + result = results[0] self._metrics.distribution( "taskworker.worker.complete_duration", time.monotonic() - result.receive_timestamp, @@ -217,7 +242,7 @@ def _send_result( self._grpc_sync_event.wait(self._setstatus_backoff_seconds) try: - self.client.update_task(result) + self.client.update_tasks([result]) self._setstatus_backoff_seconds = 0 return None except grpc.RpcError as e: @@ -365,6 +390,85 @@ def shutdown(self) -> None: self.worker_pool.shutdown() +class BatchPushTaskWorker(PushTaskWorker): + def __init__(self, *args: Any, **kwargs: Any) -> None: + assert ( + kwargs["update_in_batches"] is True + ), "BatchPushTaskWorker must be initialized with update_in_batches=True" + super().__init__(*args, **kwargs) + + def _create_client( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> PushTaskbrokerClient: + return BatchPushTaskbrokerClient( + service=service, + application=application, + metrics=metrics, + health_check_settings=health_check_settings, + rpc_secret=rpc_secret, + grpc_config=grpc_config, + ) + + def _send_results( + self, results: list[ProcessingResult], is_draining: bool = False + ) -> InflightTaskActivation | None: + """ + Send a result to the broker. If the set has failed before, sleep briefly before retrying. + """ + for result in results: + self._metrics.distribution( + "taskworker.worker.complete_duration", + time.monotonic() - result.receive_timestamp, + tags={"processing_pool": self._processing_pool_name}, + ) + self._metrics.distribution( + "taskworker.worker.update_status_batch_size", + len(results), + tags={"processing_pool": self._processing_pool_name}, + ) + + logger.debug( + "taskworker.send_update_task_batch.batch_sent", + extra={ + "results": [result.task_id for result in results], + "processing_pool": self._processing_pool_name, + }, + ) + # Use the shutdown_event as a sleep mechanism + self._grpc_sync_event.wait(self._setstatus_backoff_seconds) + + try: + self.client.update_tasks(results) + self._setstatus_backoff_seconds = 0 + return None + except grpc.RpcError as e: + self._setstatus_backoff_seconds = min(self._setstatus_backoff_seconds + 1, 10) + logger.warning( + "taskworker.send_update_task_batch.failed", + extra={"results": [result.task_id for result in results], "error": e}, + ) + if e.code() != grpc.StatusCode.NOT_FOUND: + # If the task was not found, we can't update it, so we should just return None + raise RequeueException(f"Failed to update task batch: {e}") + except HostTemporarilyUnavailable as e: + self._setstatus_backoff_seconds = min( + self._setstatus_backoff_seconds + 4, MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE + ) + logger.info( + "taskworker.send_update_task_batch.temporarily_unavailable", + extra={"task_id": result.task_id, "error": str(e)}, + ) + raise RequeueException(f"Failed to update task: {e}") + + return None + + class TaskWorker: """ A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations. @@ -406,7 +510,7 @@ def __init__( self.worker_pool = TaskWorkerProcessingPool( app_module=app_module, mp_context=self._mp_context, - send_result_fn=self._send_result, + send_result_fn=self._send_results, max_child_task_count=max_child_task_count, concurrency=concurrency, child_tasks_queue_maxsize=child_tasks_queue_maxsize, @@ -484,19 +588,22 @@ def _add_task(self) -> bool: return False - def _send_result( - self, result: ProcessingResult, is_draining: bool = False + def _send_results( + self, results: list[ProcessingResult], is_draining: bool = False ) -> InflightTaskActivation | None: """ Send a result to the broker and conditionally fetch an additional task. Return a boolean indicating whether the result was sent successfully. """ + assert ( + len(results) == 1 + ), "Only one result can be sent at a time with the regular pull client" self._metrics.distribution( "taskworker.worker.complete_duration", - time.monotonic() - result.receive_timestamp, + time.monotonic() - results[0].receive_timestamp, tags={"processing_pool": self._processing_pool_name}, ) fetch_next = None if is_draining else FetchNextTask(namespace=self._namespace) - next_task = self._send_update_task(result, fetch_next) + next_task = self._send_update_task(results[0], fetch_next) return next_task def _send_update_task( @@ -581,7 +688,7 @@ def __init__( app_module: str, # Here the bool is used to indicate whether this is a normal fetch or is being called # during shutdown. - send_result_fn: Callable[[ProcessingResult, bool], InflightTaskActivation | None], + send_result_fn: Callable[[list[ProcessingResult], bool], InflightTaskActivation | None], mp_context: ForkContext | SpawnContext | ForkServerContext, max_child_task_count: int | None = None, concurrency: int = 1, @@ -590,11 +697,15 @@ def __init__( processing_pool_name: str | None = None, pod_name: str | None = None, process_type: str = "spawn", + update_in_batches: bool = False, ) -> None: self._concurrency = concurrency self._processing_pool_name = processing_pool_name or "unknown" self._pod_name = pod_name or "unknown" - self._send_result = send_result_fn + self._update_in_batches = update_in_batches + + self._send_result_fn = send_result_fn + self._max_child_task_count = max_child_task_count self._app_module = app_module app = import_app(app_module) @@ -614,7 +725,7 @@ def __init__( self._result_thread: threading.Thread | None = None self._spawn_children_thread: threading.Thread | None = None - def send_result(self, result: ProcessingResult, is_draining: bool = False) -> None: + def send_results(self, results: list[ProcessingResult], is_draining: bool = False) -> None: """ Call the passed in function. If is_draining is True, the function should not fetch a new task. That function should return: @@ -624,14 +735,15 @@ def send_result(self, result: ProcessingResult, is_draining: bool = False) -> No """ try: worker_full = is_draining or self._child_tasks.full() - next_task = self._send_result(result, worker_full) + next_task = self._send_result_fn(results, worker_full) if next_task: self.push_task(next_task) except RequeueException: logger.warning("activation status couldn't be updated") # This can cause an infinite loop if we are draining and the result fails to send if not is_draining: - self.put_result(result) + for result in results: + self.put_result(result) def start_result_thread(self) -> None: """ @@ -673,15 +785,26 @@ def result_thread() -> None: extra={"error": e, "processing_pool": self._processing_pool_name}, ) - try: - result = self._processed_tasks.get(timeout=1.0) - executor.submit(self.send_result, result, False) - except queue.Empty: - self._metrics.incr( - "taskworker.worker.result_thread.queue_empty", - tags={"processing_pool": self._processing_pool_name}, - ) - continue + results = [] + while True: + try: + result = self._processed_tasks.get(timeout=1.0) + if not self._update_in_batches: + executor.submit(self.send_results, [result], False) + break + else: + results.append(result) + if len(results) >= self._concurrency: + executor.submit(self.send_results, results, False) + results = [] + except queue.Empty: + if not results: + # Only increment if there was nothing in the queue at all + self._metrics.incr( + "taskworker.worker.result_thread.queue_empty", + tags={"processing_pool": self._processing_pool_name}, + ) + break self._result_thread = threading.Thread( name="send-result", target=result_thread, daemon=True @@ -794,7 +917,7 @@ def shutdown(self) -> None: while True: try: result = self._processed_tasks.get_nowait() - self.send_result(result, True) + self.send_results([result], True) except queue.Empty: break diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index b2ca19dd..ed9f79a8 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -41,6 +41,7 @@ from taskbroker_client.types import InflightTaskActivation, ProcessingResult from taskbroker_client.worker.producer import TaskProducer, _pending_futures from taskbroker_client.worker.worker import ( + BatchPushTaskWorker, PushTaskWorker, TaskWorker, TaskWorkerProcessingPool, @@ -535,6 +536,19 @@ def test_constructor_push_mode(self) -> None: self.assertTrue(taskworker.client is not None) self.assertEqual(taskworker._grpc_port, 50099) + def test_constructor_batch_push_mode(self) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + grpc_port=50099, + update_in_batches=True, + ) + + self.assertTrue(taskworker.client is not None) + self.assertEqual(taskworker._grpc_port, 50099) + def test_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None: taskworker = PushTaskWorker( @@ -559,6 +573,30 @@ def test_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None: assert taskworker._health_check_thread is None +def test_batch_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + health_check_file_path=str(tmp_path / "health"), + health_check_sec_per_touch=0.01, + update_in_batches=True, + ) + + with mock.patch.object(taskworker.client, "emit_health_check") as mock_emit: + taskworker._start_health_check_thread() + try: + start = time.time() + while mock_emit.call_count < 2 and time.time() - start < 1: + time.sleep(0.01) + finally: + taskworker._stop_health_check_thread() + + assert mock_emit.call_count >= 2 + assert taskworker._health_check_thread is None + + class TestWorkerServicer(TestCase): def test_push_task_success(self) -> None: taskworker = PushTaskWorker( @@ -586,6 +624,33 @@ def test_push_task_success(self) -> None: self.assertEqual(inflight.activation.id, SIMPLE_TASK.activation.id) self.assertEqual(inflight.host, "broker-host:50051") + def test_batch_push_task_success(self) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + update_in_batches=True, + ) + with mock.patch.object( + taskworker.worker_pool, "push_task", return_value=True + ) as mock_push_task: + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker.worker_pool) + + response = servicer.PushTask(request, mock_context) + + self.assertIsInstance(response, PushTaskResponse) + mock_context.abort.assert_not_called() + mock_push_task.assert_called_once_with(mock.ANY, timeout=5) + (inflight,) = mock_push_task.call_args[0] + self.assertEqual(inflight.activation.id, SIMPLE_TASK.activation.id) + self.assertEqual(inflight.host, "broker-host:50051") + def test_push_task_worker_busy(self) -> None: taskworker = PushTaskWorker( app_module="examples.app:app", @@ -608,6 +673,29 @@ def test_push_task_worker_busy(self) -> None: grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy" ) + def test_batch_push_task_worker_busy(self) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=1, + update_in_batches=True, + ) + with mock.patch.object(taskworker.worker_pool, "push_task", return_value=False): + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker.worker_pool) + + servicer.PushTask(request, mock_context) + + mock_context.abort.assert_called_once_with( + grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy" + ) + @mock.patch("taskbroker_client.worker.workerchild.capture_checkin") def test_child_process_complete(mock_capture_checkin: mock.MagicMock) -> None: diff --git a/uv.lock b/uv.lock index 8b57a0b7..590a9e32 100644 --- a/uv.lock +++ b/uv.lock @@ -611,7 +611,7 @@ wheels = [ [[package]] name = "sentry-protos" -version = "0.16.1" +version = "0.26.1" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ { name = "grpc-stubs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -619,7 +619,7 @@ dependencies = [ { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] wheels = [ - { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.16.1-py3-none-any.whl", hash = "sha256:755a7cc71a0d8bef2a42a340cd1c35e2ee127e20dd71fed334d9fa88c0cb87a4" }, + { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.26.1-py3-none-any.whl", hash = "sha256:66bc22be8ac3efeceeb09af30a3ca55943ad83a3d28905263255f74a02206822" }, ] [[package]] @@ -708,7 +708,7 @@ dev = [ { name = "pytest", specifier = ">=9.0.3" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "sentry-devenv", specifier = ">=1.22.2" }, - { name = "sentry-protos", specifier = ">=0.15.0" }, + { name = "sentry-protos", specifier = ">=0.22.1" }, { name = "types-protobuf", specifier = ">=5.27.0.20240626,<6.0.0" }, { name = "types-pyyaml", specifier = ">=6.0.12.20241230" }, ] @@ -767,7 +767,7 @@ requires-dist = [ { name = "redis", specifier = ">=3.4.1" }, { name = "redis-py-cluster", specifier = ">=2.1.0" }, { name = "sentry-arroyo", specifier = ">=2.38.7" }, - { name = "sentry-protos", specifier = ">=0.15.0" }, + { name = "sentry-protos", specifier = ">=0.26.1" }, { name = "sentry-sdk", extras = ["http2"], specifier = ">=2.43.0" }, { name = "setuptools", marker = "extra == 'examples'", specifier = ">=80.0" }, { name = "zstandard", specifier = ">=0.18.0" }, From 71fb4c2a6a930474be20c872ec2de49f31751a4d Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 11 Jun 2026 13:26:26 -0400 Subject: [PATCH 2/4] fixes --- clients/python/src/examples/cli.py | 1 + clients/python/src/taskbroker_client/worker/worker.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 687cba12..35717a29 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -103,6 +103,7 @@ def worker( processing_pool_name="examples", process_type="forkserver", grpc_port=grpc_port, + update_in_batches=True, ) elif push_mode: worker = PushTaskWorker( diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 45c84d4c..7df94b29 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -804,6 +804,9 @@ def result_thread() -> None: "taskworker.worker.result_thread.queue_empty", tags={"processing_pool": self._processing_pool_name}, ) + elif self._update_in_batches: + executor.submit(self.send_results, results, False) + results = [] break self._result_thread = threading.Thread( From 1e17e7f280d7c386002989249ba3b587c4b65c46 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 11 Jun 2026 13:43:44 -0400 Subject: [PATCH 3/4] fix and add tests --- clients/python/tests/worker/test_worker.py | 96 ++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index ed9f79a8..2ce3130a 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -297,6 +297,51 @@ ) +def _make_processing_result(task_id: str) -> ProcessingResult: + return ProcessingResult( + task_id=task_id, + status=TASK_ACTIVATION_STATUS_COMPLETE, + host="localhost:50051", + receive_timestamp=0, + ) + + +class _SendResultCapture: + def __init__(self) -> None: + self.send_calls: list[tuple[list[ProcessingResult], bool]] = [] + self._lock = threading.Lock() + + def __call__(self, results: list[ProcessingResult], is_draining: bool) -> None: + with self._lock: + self.send_calls.append((list(results), is_draining)) + return None + + def wait_for_calls(self, expected: int, timeout: float = 5) -> None: + start = time.time() + while len(self.send_calls) < expected and time.time() - start < timeout: + time.sleep(0.01) + if len(self.send_calls) < expected: + raise AssertionError(f"Expected {expected} send calls, got {len(self.send_calls)}") + + +def _make_result_thread_pool( + capture: _SendResultCapture, + *, + concurrency: int = 3, + update_in_batches: bool, +) -> TaskWorkerProcessingPool: + return TaskWorkerProcessingPool( + app_module="examples.app:app", + send_result_fn=capture, + mp_context=get_context("fork"), + max_child_task_count=100, + concurrency=concurrency, + processing_pool_name="test", + process_type="fork", + update_in_batches=update_in_batches, + ) + + class TestTaskWorker(TestCase): def test_fetch_task(self) -> None: taskworker = TaskWorker( @@ -475,6 +520,57 @@ def test_push_task_queue(self) -> None: result = taskworker.push_task(SIMPLE_TASK, timeout=1) self.assertFalse(result) + def test_result_thread_sends_full_batch(self) -> None: + capture = _SendResultCapture() + concurrency = 3 + pool = _make_result_thread_pool(capture, concurrency=concurrency, update_in_batches=True) + try: + pool.start_result_thread() + + for i in range(concurrency): + pool.put_result(_make_processing_result(str(i))) + + capture.wait_for_calls(1) + batch, is_draining = capture.send_calls[0] + self.assertEqual(len(batch), concurrency) + self.assertEqual({result.task_id for result in batch}, {"0", "1", "2"}) + self.assertFalse(is_draining) + finally: + pool.shutdown() + + def test_result_thread_flushes_partial_batch_on_queue_empty(self) -> None: + capture = _SendResultCapture() + pool = _make_result_thread_pool(capture, update_in_batches=True) + try: + pool.start_result_thread() + + pool.put_result(_make_processing_result("partial-1")) + pool.put_result(_make_processing_result("partial-2")) + + capture.wait_for_calls(1, timeout=3) + batch, is_draining = capture.send_calls[0] + self.assertEqual(len(batch), 2) + self.assertEqual({result.task_id for result in batch}, {"partial-1", "partial-2"}) + self.assertFalse(is_draining) + finally: + pool.shutdown() + + def test_result_thread_sends_results_individually_without_batching(self) -> None: + capture = _SendResultCapture() + pool = _make_result_thread_pool(capture, update_in_batches=False) + try: + pool.start_result_thread() + + pool.put_result(_make_processing_result("single")) + + capture.wait_for_calls(1) + batch, is_draining = capture.send_calls[0] + self.assertEqual(len(batch), 1) + self.assertEqual(batch[0].task_id, "single") + self.assertFalse(is_draining) + finally: + pool.shutdown() + def test_run_once_current_task_state(self) -> None: # Run a task that uses retry_task() helper # to raise and catch a NoRetriesRemainingError From 4b03b1c9c2ea0681f63bdc79cff59666c0c23925 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 11 Jun 2026 13:48:04 -0400 Subject: [PATCH 4/4] logging --- clients/python/src/taskbroker_client/worker/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 7df94b29..be69db2e 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -462,7 +462,7 @@ def _send_results( ) logger.info( "taskworker.send_update_task_batch.temporarily_unavailable", - extra={"task_id": result.task_id, "error": str(e)}, + extra={"task_ids": [result.task_id for result in results], "error": str(e)}, ) raise RequeueException(f"Failed to update task: {e}")