diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index f8a84ff3..9ee0ba7f 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" dependencies = [ "sentry-arroyo>=2.38.7", "sentry-sdk[http2]>=2.43.0", - "sentry-protos>=0.15.0", + "sentry-protos>=0.26.1", "confluent_kafka>=2.3.0", "cronsim>=2.6", "grpcio>=1.67.1", diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 4c9fdc60..35717a29 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -76,18 +76,37 @@ def scheduler() -> None: @click.option( "--push-mode", help="Whether to run in PUSH or PULL mode.", default=False, is_flag=True ) +@click.option( + "--batch-push-mode", help="Whether to run in BATCH PUSH mode.", default=False, is_flag=True +) @click.option( "--grpc-port", help="Port for the gRPC server to listen on.", default=50052, type=int, ) -def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None: - from taskbroker_client.worker import PushTaskWorker, TaskWorker +def worker( + rpc_host: str, concurrency: int, push_mode: bool, batch_push_mode: bool, grpc_port: int +) -> None: + from taskbroker_client.worker import BatchPushTaskWorker, PushTaskWorker, TaskWorker click.echo("Starting worker") - if push_mode: - worker: PushTaskWorker | TaskWorker = PushTaskWorker( + if batch_push_mode: + worker: PushTaskWorker | TaskWorker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service=rpc_host, + max_child_task_count=100, + concurrency=concurrency, + child_tasks_queue_maxsize=concurrency * 2, + result_queue_maxsize=concurrency * 2, + rebalance_after=32, + processing_pool_name="examples", + process_type="forkserver", + grpc_port=grpc_port, + update_in_batches=True, + ) + elif push_mode: + worker = PushTaskWorker( app_module="examples.app:app", broker_service=rpc_host, max_child_task_count=100, diff --git a/clients/python/src/taskbroker_client/worker/__init__.py b/clients/python/src/taskbroker_client/worker/__init__.py index d94f2b62..79f70886 100644 --- a/clients/python/src/taskbroker_client/worker/__init__.py +++ b/clients/python/src/taskbroker_client/worker/__init__.py @@ -1,3 +1,3 @@ -from .worker import PushTaskWorker, TaskWorker +from .worker import BatchPushTaskWorker, PushTaskWorker, TaskWorker -__all__ = ("TaskWorker", "PushTaskWorker") +__all__ = ("TaskWorker", "PushTaskWorker", "BatchPushTaskWorker") diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 356a4f75..baf609d6 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -493,111 +493,3 @@ def update_task( receive_timestamp=time.monotonic(), ) return None - - -class PushTaskbrokerClient: - """ - Taskworker RPC client wrapper - - Push brokers are a deployment so they don't need to be connected to individually. There is one service provided - that works for all the brokers. - """ - - def __init__( - self, - service: str, - application: str, - metrics: MetricsBackend, - health_check_settings: HealthCheckSettings | None = None, - rpc_secret: str | None = None, - grpc_config: str | None = None, - ) -> None: - self._application = application - self._service = service - self._rpc_secret = rpc_secret - self._metrics = metrics - - self._grpc_options: list[tuple[str, Any]] = [ - ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) - ] - if grpc_config: - self._grpc_options.append(("grpc.service_config", grpc_config)) - - logger.info( - "taskworker.push_client.start", - extra={"service": service, "options": self._grpc_options}, - ) - - self._stub = self._connect_to_host(service) - - self._health_check_settings = health_check_settings - self._timestamp_since_touch_lock = threading.Lock() - self._timestamp_since_touch = 0.0 - - def _emit_health_check(self) -> None: - if self._health_check_settings is None: - return - - with self._timestamp_since_touch_lock: - cur_time = time.time() - if ( - cur_time - self._timestamp_since_touch - < self._health_check_settings.touch_interval_sec - ): - return - - self._health_check_settings.file_path.touch() - self._metrics.incr( - "taskworker.client.health_check.touched", - ) - self._timestamp_since_touch = cur_time - - def _connect_to_host(self, host: str) -> ConsumerServiceStub: - logger.info("taskworker.push_client.connect", extra={"host": host}) - channel = grpc.insecure_channel(host, options=self._grpc_options) - secrets = parse_rpc_secret_list(self._rpc_secret) - if secrets: - channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) - return ConsumerServiceStub(channel) - - def emit_health_check(self) -> None: - self._emit_health_check() - - def update_task( - self, - processing_result: ProcessingResult, - ) -> None: - """ - Update the status for a given task activation. - """ - self._emit_health_check() - - request = SetTaskStatusRequest( - id=processing_result.task_id, - status=processing_result.status, - fetch_next_task=None, - max_attempts=processing_result.max_attempts, - delay_on_retry=processing_result.delay_on_retry, - ) - - retries = 0 - exception = None - while retries < 3: - try: - with self._metrics.timer( - "taskworker.update_task.rpc", tags={"service": self._service} - ): - self._stub.SetTaskStatus(request) - exception = None - break - except grpc.RpcError as err: - exception = err - self._metrics.incr( - "taskworker.client.rpc_error", - tags={"method": "SetTaskStatus", "status": err.code().name}, - ) - finally: - retries += 1 - - if exception: - raise exception diff --git a/clients/python/src/taskbroker_client/worker/push_clients.py b/clients/python/src/taskbroker_client/worker/push_clients.py new file mode 100644 index 00000000..5a13b6de --- /dev/null +++ b/clients/python/src/taskbroker_client/worker/push_clients.py @@ -0,0 +1,192 @@ +import logging +import threading +import time +from typing import TYPE_CHECKING, Any + +import grpc +from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( + SetBatchActivationStatusRequest, + SetTaskStatusRequest, +) +from sentry_protos.taskbroker.v1.taskbroker_pb2_grpc import ConsumerServiceStub + +from taskbroker_client.metrics import MetricsBackend +from taskbroker_client.types import ProcessingResult +from taskbroker_client.worker.client import ( + MAX_ACTIVATION_SIZE, + HealthCheckSettings, + RequestSignatureInterceptor, + parse_rpc_secret_list, +) + +if TYPE_CHECKING: + ServerInterceptor = grpc.ServerInterceptor[Any, Any] +else: + ServerInterceptor = grpc.ServerInterceptor + +logger = logging.getLogger(__name__) + + +class PushTaskbrokerClient: + """ + Taskworker RPC client wrapper + + Push brokers are a deployment so they don't need to be connected to individually. There is one service provided + that works for all the brokers. + """ + + def __init__( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> None: + self._application = application + self._service = service + self._rpc_secret = rpc_secret + self._metrics = metrics + + self._grpc_options: list[tuple[str, Any]] = [ + ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) + ] + if grpc_config: + self._grpc_options.append(("grpc.service_config", grpc_config)) + + logger.info( + "taskworker.push_client.start", + extra={"service": service, "options": self._grpc_options}, + ) + + self._stub = self._connect_to_host(service) + + self._health_check_settings = health_check_settings + self._timestamp_since_touch_lock = threading.Lock() + self._timestamp_since_touch = 0.0 + + def _emit_health_check(self) -> None: + if self._health_check_settings is None: + return + + with self._timestamp_since_touch_lock: + cur_time = time.time() + if ( + cur_time - self._timestamp_since_touch + < self._health_check_settings.touch_interval_sec + ): + return + + self._health_check_settings.file_path.touch() + self._metrics.incr( + "taskworker.client.health_check.touched", + ) + self._timestamp_since_touch = cur_time + + def _connect_to_host(self, host: str) -> ConsumerServiceStub: + logger.info("taskworker.push_client.connect", extra={"host": host}) + channel = grpc.insecure_channel(host, options=self._grpc_options) + secrets = parse_rpc_secret_list(self._rpc_secret) + if secrets: + channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) + return ConsumerServiceStub(channel) + + def emit_health_check(self) -> None: + self._emit_health_check() + + def update_tasks(self, processing_results: list[ProcessingResult]) -> None: + for processing_result in processing_results: + self._update_task_single(processing_result) + + def _update_task_single( + self, + processing_result: ProcessingResult, + ) -> None: + """ + Update the status for a given task activation. + """ + self._emit_health_check() + + request = SetTaskStatusRequest( + id=processing_result.task_id, + status=processing_result.status, + fetch_next_task=None, + max_attempts=processing_result.max_attempts, + delay_on_retry=processing_result.delay_on_retry, + ) + + retries = 0 + exception = None + while retries < 3: + try: + with self._metrics.timer( + "taskworker.update_task.rpc", tags={"service": self._service} + ): + self._stub.SetTaskStatus(request) + exception = None + break + except grpc.RpcError as err: + exception = err + self._metrics.incr( + "taskworker.client.rpc_error", + tags={"method": "SetTaskStatus", "status": err.code().name}, + ) + finally: + retries += 1 + + if exception: + raise exception + + +class BatchPushTaskbrokerClient(PushTaskbrokerClient): + """ + Taskworker RPC client wrapper + + Push brokers are a deployment so they don't need to be connected to individually. There is one service provided + that works for all the brokers. This client pushes batches of activation updates. + """ + + def update_tasks( + self, + processing_results: list[ProcessingResult], + ) -> None: + """ + Update the status for a given task activation. + """ + self._emit_health_check() + + request = SetBatchActivationStatusRequest( + updates=[ + SetTaskStatusRequest( + id=processing_result.task_id, + status=processing_result.status, + fetch_next_task=None, + max_attempts=processing_result.max_attempts, + delay_on_retry=processing_result.delay_on_retry, + ) + for processing_result in processing_results + ] + ) + + retries = 0 + exception = None + while retries < 3: + try: + with self._metrics.timer( + "taskworker.update_task_batch.rpc", tags={"service": self._service} + ): + self._stub.SetBatchActivationStatus(request) + exception = None + break + except grpc.RpcError as err: + exception = err + self._metrics.incr( + "taskworker.client.rpc_error", + tags={"method": "SetBatchActivationStatus", "status": err.code().name}, + ) + finally: + retries += 1 + + if exception: + raise exception diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index f69ecc4b..be69db2e 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -31,15 +31,16 @@ MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE, WORKER_CHILD_JOIN_TIMEOUT_SEC, ) +from taskbroker_client.metrics import MetricsBackend from taskbroker_client.types import InflightTaskActivation, ProcessingResult from taskbroker_client.worker.client import ( HealthCheckSettings, HostTemporarilyUnavailable, - PushTaskbrokerClient, RequestSignatureServerInterceptor, TaskbrokerClient, parse_rpc_secret_list, ) +from taskbroker_client.worker.push_clients import BatchPushTaskbrokerClient, PushTaskbrokerClient from taskbroker_client.worker.workerchild import child_process if TYPE_CHECKING: @@ -138,6 +139,7 @@ def __init__( health_check_file_path: str | None = None, health_check_sec_per_touch: float = DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH, grpc_port: int = 50052, + update_in_batches: bool = False, ) -> None: app = import_app(app_module) @@ -153,7 +155,7 @@ def __init__( self.worker_pool = TaskWorkerProcessingPool( app_module=app_module, mp_context=self._mp_context, - send_result_fn=self._send_result, + send_result_fn=self._send_results, max_child_task_count=max_child_task_count, concurrency=concurrency, child_tasks_queue_maxsize=child_tasks_queue_maxsize, @@ -161,11 +163,12 @@ def __init__( processing_pool_name=processing_pool_name, pod_name=pod_name, process_type=process_type, + update_in_batches=update_in_batches, ) logger.info("Running in PUSH mode") - self.client = PushTaskbrokerClient( + self.client = self._create_client( service=broker_service, application=app.name, metrics=app.metrics, @@ -193,12 +196,34 @@ def __init__( self._grpc_port = grpc_port self._grpc_secrets = parse_rpc_secret_list(app.config["rpc_secret"]) - def _send_result( - self, result: ProcessingResult, is_draining: bool = False + def _create_client( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> PushTaskbrokerClient: + return PushTaskbrokerClient( + service=service, + application=application, + metrics=metrics, + health_check_settings=health_check_settings, + rpc_secret=rpc_secret, + grpc_config=grpc_config, + ) + + def _send_results( + self, results: list[ProcessingResult], is_draining: bool = False ) -> InflightTaskActivation | None: """ Send a result to the broker. If the set has failed before, sleep briefly before retrying. """ + assert ( + len(results) == 1 + ), "Only one result can be sent at a time with the regular push client" + result = results[0] self._metrics.distribution( "taskworker.worker.complete_duration", time.monotonic() - result.receive_timestamp, @@ -217,7 +242,7 @@ def _send_result( self._grpc_sync_event.wait(self._setstatus_backoff_seconds) try: - self.client.update_task(result) + self.client.update_tasks([result]) self._setstatus_backoff_seconds = 0 return None except grpc.RpcError as e: @@ -365,6 +390,85 @@ def shutdown(self) -> None: self.worker_pool.shutdown() +class BatchPushTaskWorker(PushTaskWorker): + def __init__(self, *args: Any, **kwargs: Any) -> None: + assert ( + kwargs["update_in_batches"] is True + ), "BatchPushTaskWorker must be initialized with update_in_batches=True" + super().__init__(*args, **kwargs) + + def _create_client( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> PushTaskbrokerClient: + return BatchPushTaskbrokerClient( + service=service, + application=application, + metrics=metrics, + health_check_settings=health_check_settings, + rpc_secret=rpc_secret, + grpc_config=grpc_config, + ) + + def _send_results( + self, results: list[ProcessingResult], is_draining: bool = False + ) -> InflightTaskActivation | None: + """ + Send a result to the broker. If the set has failed before, sleep briefly before retrying. + """ + for result in results: + self._metrics.distribution( + "taskworker.worker.complete_duration", + time.monotonic() - result.receive_timestamp, + tags={"processing_pool": self._processing_pool_name}, + ) + self._metrics.distribution( + "taskworker.worker.update_status_batch_size", + len(results), + tags={"processing_pool": self._processing_pool_name}, + ) + + logger.debug( + "taskworker.send_update_task_batch.batch_sent", + extra={ + "results": [result.task_id for result in results], + "processing_pool": self._processing_pool_name, + }, + ) + # Use the shutdown_event as a sleep mechanism + self._grpc_sync_event.wait(self._setstatus_backoff_seconds) + + try: + self.client.update_tasks(results) + self._setstatus_backoff_seconds = 0 + return None + except grpc.RpcError as e: + self._setstatus_backoff_seconds = min(self._setstatus_backoff_seconds + 1, 10) + logger.warning( + "taskworker.send_update_task_batch.failed", + extra={"results": [result.task_id for result in results], "error": e}, + ) + if e.code() != grpc.StatusCode.NOT_FOUND: + # If the task was not found, we can't update it, so we should just return None + raise RequeueException(f"Failed to update task batch: {e}") + except HostTemporarilyUnavailable as e: + self._setstatus_backoff_seconds = min( + self._setstatus_backoff_seconds + 4, MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE + ) + logger.info( + "taskworker.send_update_task_batch.temporarily_unavailable", + extra={"task_ids": [result.task_id for result in results], "error": str(e)}, + ) + raise RequeueException(f"Failed to update task: {e}") + + return None + + class TaskWorker: """ A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations. @@ -406,7 +510,7 @@ def __init__( self.worker_pool = TaskWorkerProcessingPool( app_module=app_module, mp_context=self._mp_context, - send_result_fn=self._send_result, + send_result_fn=self._send_results, max_child_task_count=max_child_task_count, concurrency=concurrency, child_tasks_queue_maxsize=child_tasks_queue_maxsize, @@ -484,19 +588,22 @@ def _add_task(self) -> bool: return False - def _send_result( - self, result: ProcessingResult, is_draining: bool = False + def _send_results( + self, results: list[ProcessingResult], is_draining: bool = False ) -> InflightTaskActivation | None: """ Send a result to the broker and conditionally fetch an additional task. Return a boolean indicating whether the result was sent successfully. """ + assert ( + len(results) == 1 + ), "Only one result can be sent at a time with the regular pull client" self._metrics.distribution( "taskworker.worker.complete_duration", - time.monotonic() - result.receive_timestamp, + time.monotonic() - results[0].receive_timestamp, tags={"processing_pool": self._processing_pool_name}, ) fetch_next = None if is_draining else FetchNextTask(namespace=self._namespace) - next_task = self._send_update_task(result, fetch_next) + next_task = self._send_update_task(results[0], fetch_next) return next_task def _send_update_task( @@ -581,7 +688,7 @@ def __init__( app_module: str, # Here the bool is used to indicate whether this is a normal fetch or is being called # during shutdown. - send_result_fn: Callable[[ProcessingResult, bool], InflightTaskActivation | None], + send_result_fn: Callable[[list[ProcessingResult], bool], InflightTaskActivation | None], mp_context: ForkContext | SpawnContext | ForkServerContext, max_child_task_count: int | None = None, concurrency: int = 1, @@ -590,11 +697,15 @@ def __init__( processing_pool_name: str | None = None, pod_name: str | None = None, process_type: str = "spawn", + update_in_batches: bool = False, ) -> None: self._concurrency = concurrency self._processing_pool_name = processing_pool_name or "unknown" self._pod_name = pod_name or "unknown" - self._send_result = send_result_fn + self._update_in_batches = update_in_batches + + self._send_result_fn = send_result_fn + self._max_child_task_count = max_child_task_count self._app_module = app_module app = import_app(app_module) @@ -614,7 +725,7 @@ def __init__( self._result_thread: threading.Thread | None = None self._spawn_children_thread: threading.Thread | None = None - def send_result(self, result: ProcessingResult, is_draining: bool = False) -> None: + def send_results(self, results: list[ProcessingResult], is_draining: bool = False) -> None: """ Call the passed in function. If is_draining is True, the function should not fetch a new task. That function should return: @@ -624,14 +735,15 @@ def send_result(self, result: ProcessingResult, is_draining: bool = False) -> No """ try: worker_full = is_draining or self._child_tasks.full() - next_task = self._send_result(result, worker_full) + next_task = self._send_result_fn(results, worker_full) if next_task: self.push_task(next_task) except RequeueException: logger.warning("activation status couldn't be updated") # This can cause an infinite loop if we are draining and the result fails to send if not is_draining: - self.put_result(result) + for result in results: + self.put_result(result) def start_result_thread(self) -> None: """ @@ -673,15 +785,29 @@ def result_thread() -> None: extra={"error": e, "processing_pool": self._processing_pool_name}, ) - try: - result = self._processed_tasks.get(timeout=1.0) - executor.submit(self.send_result, result, False) - except queue.Empty: - self._metrics.incr( - "taskworker.worker.result_thread.queue_empty", - tags={"processing_pool": self._processing_pool_name}, - ) - continue + results = [] + while True: + try: + result = self._processed_tasks.get(timeout=1.0) + if not self._update_in_batches: + executor.submit(self.send_results, [result], False) + break + else: + results.append(result) + if len(results) >= self._concurrency: + executor.submit(self.send_results, results, False) + results = [] + except queue.Empty: + if not results: + # Only increment if there was nothing in the queue at all + self._metrics.incr( + "taskworker.worker.result_thread.queue_empty", + tags={"processing_pool": self._processing_pool_name}, + ) + elif self._update_in_batches: + executor.submit(self.send_results, results, False) + results = [] + break self._result_thread = threading.Thread( name="send-result", target=result_thread, daemon=True @@ -794,7 +920,7 @@ def shutdown(self) -> None: while True: try: result = self._processed_tasks.get_nowait() - self.send_result(result, True) + self.send_results([result], True) except queue.Empty: break diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index b2ca19dd..2ce3130a 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -41,6 +41,7 @@ from taskbroker_client.types import InflightTaskActivation, ProcessingResult from taskbroker_client.worker.producer import TaskProducer, _pending_futures from taskbroker_client.worker.worker import ( + BatchPushTaskWorker, PushTaskWorker, TaskWorker, TaskWorkerProcessingPool, @@ -296,6 +297,51 @@ ) +def _make_processing_result(task_id: str) -> ProcessingResult: + return ProcessingResult( + task_id=task_id, + status=TASK_ACTIVATION_STATUS_COMPLETE, + host="localhost:50051", + receive_timestamp=0, + ) + + +class _SendResultCapture: + def __init__(self) -> None: + self.send_calls: list[tuple[list[ProcessingResult], bool]] = [] + self._lock = threading.Lock() + + def __call__(self, results: list[ProcessingResult], is_draining: bool) -> None: + with self._lock: + self.send_calls.append((list(results), is_draining)) + return None + + def wait_for_calls(self, expected: int, timeout: float = 5) -> None: + start = time.time() + while len(self.send_calls) < expected and time.time() - start < timeout: + time.sleep(0.01) + if len(self.send_calls) < expected: + raise AssertionError(f"Expected {expected} send calls, got {len(self.send_calls)}") + + +def _make_result_thread_pool( + capture: _SendResultCapture, + *, + concurrency: int = 3, + update_in_batches: bool, +) -> TaskWorkerProcessingPool: + return TaskWorkerProcessingPool( + app_module="examples.app:app", + send_result_fn=capture, + mp_context=get_context("fork"), + max_child_task_count=100, + concurrency=concurrency, + processing_pool_name="test", + process_type="fork", + update_in_batches=update_in_batches, + ) + + class TestTaskWorker(TestCase): def test_fetch_task(self) -> None: taskworker = TaskWorker( @@ -474,6 +520,57 @@ def test_push_task_queue(self) -> None: result = taskworker.push_task(SIMPLE_TASK, timeout=1) self.assertFalse(result) + def test_result_thread_sends_full_batch(self) -> None: + capture = _SendResultCapture() + concurrency = 3 + pool = _make_result_thread_pool(capture, concurrency=concurrency, update_in_batches=True) + try: + pool.start_result_thread() + + for i in range(concurrency): + pool.put_result(_make_processing_result(str(i))) + + capture.wait_for_calls(1) + batch, is_draining = capture.send_calls[0] + self.assertEqual(len(batch), concurrency) + self.assertEqual({result.task_id for result in batch}, {"0", "1", "2"}) + self.assertFalse(is_draining) + finally: + pool.shutdown() + + def test_result_thread_flushes_partial_batch_on_queue_empty(self) -> None: + capture = _SendResultCapture() + pool = _make_result_thread_pool(capture, update_in_batches=True) + try: + pool.start_result_thread() + + pool.put_result(_make_processing_result("partial-1")) + pool.put_result(_make_processing_result("partial-2")) + + capture.wait_for_calls(1, timeout=3) + batch, is_draining = capture.send_calls[0] + self.assertEqual(len(batch), 2) + self.assertEqual({result.task_id for result in batch}, {"partial-1", "partial-2"}) + self.assertFalse(is_draining) + finally: + pool.shutdown() + + def test_result_thread_sends_results_individually_without_batching(self) -> None: + capture = _SendResultCapture() + pool = _make_result_thread_pool(capture, update_in_batches=False) + try: + pool.start_result_thread() + + pool.put_result(_make_processing_result("single")) + + capture.wait_for_calls(1) + batch, is_draining = capture.send_calls[0] + self.assertEqual(len(batch), 1) + self.assertEqual(batch[0].task_id, "single") + self.assertFalse(is_draining) + finally: + pool.shutdown() + def test_run_once_current_task_state(self) -> None: # Run a task that uses retry_task() helper # to raise and catch a NoRetriesRemainingError @@ -535,6 +632,19 @@ def test_constructor_push_mode(self) -> None: self.assertTrue(taskworker.client is not None) self.assertEqual(taskworker._grpc_port, 50099) + def test_constructor_batch_push_mode(self) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + grpc_port=50099, + update_in_batches=True, + ) + + self.assertTrue(taskworker.client is not None) + self.assertEqual(taskworker._grpc_port, 50099) + def test_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None: taskworker = PushTaskWorker( @@ -559,6 +669,30 @@ def test_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None: assert taskworker._health_check_thread is None +def test_batch_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + health_check_file_path=str(tmp_path / "health"), + health_check_sec_per_touch=0.01, + update_in_batches=True, + ) + + with mock.patch.object(taskworker.client, "emit_health_check") as mock_emit: + taskworker._start_health_check_thread() + try: + start = time.time() + while mock_emit.call_count < 2 and time.time() - start < 1: + time.sleep(0.01) + finally: + taskworker._stop_health_check_thread() + + assert mock_emit.call_count >= 2 + assert taskworker._health_check_thread is None + + class TestWorkerServicer(TestCase): def test_push_task_success(self) -> None: taskworker = PushTaskWorker( @@ -586,6 +720,33 @@ def test_push_task_success(self) -> None: self.assertEqual(inflight.activation.id, SIMPLE_TASK.activation.id) self.assertEqual(inflight.host, "broker-host:50051") + def test_batch_push_task_success(self) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + update_in_batches=True, + ) + with mock.patch.object( + taskworker.worker_pool, "push_task", return_value=True + ) as mock_push_task: + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker.worker_pool) + + response = servicer.PushTask(request, mock_context) + + self.assertIsInstance(response, PushTaskResponse) + mock_context.abort.assert_not_called() + mock_push_task.assert_called_once_with(mock.ANY, timeout=5) + (inflight,) = mock_push_task.call_args[0] + self.assertEqual(inflight.activation.id, SIMPLE_TASK.activation.id) + self.assertEqual(inflight.host, "broker-host:50051") + def test_push_task_worker_busy(self) -> None: taskworker = PushTaskWorker( app_module="examples.app:app", @@ -608,6 +769,29 @@ def test_push_task_worker_busy(self) -> None: grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy" ) + def test_batch_push_task_worker_busy(self) -> None: + taskworker = BatchPushTaskWorker( + app_module="examples.app:app", + broker_service="127.0.0.1:50051", + max_child_task_count=100, + process_type="fork", + child_tasks_queue_maxsize=1, + update_in_batches=True, + ) + with mock.patch.object(taskworker.worker_pool, "push_task", return_value=False): + request = PushTaskRequest( + task=SIMPLE_TASK.activation, + callback_url="broker-host:50051", + ) + mock_context = mock.MagicMock() + servicer = WorkerServicer(taskworker.worker_pool) + + servicer.PushTask(request, mock_context) + + mock_context.abort.assert_called_once_with( + grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy" + ) + @mock.patch("taskbroker_client.worker.workerchild.capture_checkin") def test_child_process_complete(mock_capture_checkin: mock.MagicMock) -> None: diff --git a/uv.lock b/uv.lock index 8b57a0b7..590a9e32 100644 --- a/uv.lock +++ b/uv.lock @@ -611,7 +611,7 @@ wheels = [ [[package]] name = "sentry-protos" -version = "0.16.1" +version = "0.26.1" source = { registry = "https://pypi.devinfra.sentry.io/simple" } dependencies = [ { name = "grpc-stubs", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -619,7 +619,7 @@ dependencies = [ { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] wheels = [ - { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.16.1-py3-none-any.whl", hash = "sha256:755a7cc71a0d8bef2a42a340cd1c35e2ee127e20dd71fed334d9fa88c0cb87a4" }, + { url = "https://pypi.devinfra.sentry.io/wheels/sentry_protos-0.26.1-py3-none-any.whl", hash = "sha256:66bc22be8ac3efeceeb09af30a3ca55943ad83a3d28905263255f74a02206822" }, ] [[package]] @@ -708,7 +708,7 @@ dev = [ { name = "pytest", specifier = ">=9.0.3" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "sentry-devenv", specifier = ">=1.22.2" }, - { name = "sentry-protos", specifier = ">=0.15.0" }, + { name = "sentry-protos", specifier = ">=0.22.1" }, { name = "types-protobuf", specifier = ">=5.27.0.20240626,<6.0.0" }, { name = "types-pyyaml", specifier = ">=6.0.12.20241230" }, ] @@ -767,7 +767,7 @@ requires-dist = [ { name = "redis", specifier = ">=3.4.1" }, { name = "redis-py-cluster", specifier = ">=2.1.0" }, { name = "sentry-arroyo", specifier = ">=2.38.7" }, - { name = "sentry-protos", specifier = ">=0.15.0" }, + { name = "sentry-protos", specifier = ">=0.26.1" }, { name = "sentry-sdk", extras = ["http2"], specifier = ">=2.43.0" }, { name = "setuptools", marker = "extra == 'examples'", specifier = ">=80.0" }, { name = "zstandard", specifier = ">=0.18.0" },