diff --git a/clients/python/src/taskbroker_client/worker/workerchild.py b/clients/python/src/taskbroker_client/worker/workerchild.py
index 52b1941c..e8d7a8a3 100644
--- a/clients/python/src/taskbroker_client/worker/workerchild.py
+++ b/clients/python/src/taskbroker_client/worker/workerchild.py
@@ -8,6 +8,7 @@
 import time
 from collections.abc import Callable, Generator, Sequence
 from dataclasses import dataclass
+from functools import partial
 from multiprocessing.synchronize import Event
 from types import FrameType
 from typing import Any
@@ -294,43 +295,64 @@ def get_oldest_pending_activation() -> ActivationWithPendingFutures | None:
                     oldest = task
             return oldest
 
-        def check_task_future_completion() -> None:
-            if len(pending_task_futures) > 0:
-                # Records how many activations with pending producer futures
-                # the worker child has. Only records when there are pending activations.
-                metrics.gauge(
-                    "taskworker.worker.activations_with_pending_futures",
-                    len(pending_task_futures),
-                    tags={
-                        "processing_pool": processing_pool_name,
-                    },
-                )
-                for task in pending_task_futures.copy():
-                    future_status = [f.done() for fut in task.pending_futures.values() for f in fut]
-                    if all(future_status):
-                        await_task_futures(task)
-                    else:
-                        # How many futures are still pending in this task
+        def check_task_future_completion(
+            shutdown_event: Event, local_shutdown: threading.Event
+        ) -> None:
+            """
+            Poll pending activations until either shutdown event is set.
+
+            Activations whose futures are all done are passed to await_task_futures();
+            the rest are only reported through metrics.
+            """
+            while not shutdown_event.is_set() and not local_shutdown.is_set():
+                if len(pending_task_futures) > 0:
+                    # Records how many activations with pending producer futures
+                    # the worker child has. Only records when there are pending activations.
+                    metrics.gauge(
+                        "taskworker.worker.activations_with_pending_futures",
+                        len(pending_task_futures),
+                        tags={
+                            "processing_pool": processing_pool_name,
+                        },
+                    )
+                    for task in pending_task_futures.copy():
+                        future_status = [
+                            f.done() for fut in task.pending_futures.values() for f in fut
+                        ]
+                        if all(future_status):
+                            await_task_futures(task)
+                        else:
+                            # How many futures are still pending in this task
+                            metrics.distribution(
+                                "taskworker.task.incomplete_futures",
+                                len([f for f in future_status if not f]),
+                                tags={
+                                    "processing_pool": processing_pool_name,
+                                    "namespace": task.inflight.activation.namespace,
+                                    "taskname": task.inflight.activation.taskname,
+                                },
+                            )
+                    # How long has the oldest pending task been sitting in the queue
+                    if oldest := get_oldest_pending_activation():
                         metrics.distribution(
-                            "taskworker.task.incomplete_futures",
-                            len([f for f in future_status if not f]),
+                            "taskworker.worker.oldest_pending_activation_age",
+                            time.time() - oldest.futures_start_time,
                             tags={
                                 "processing_pool": processing_pool_name,
-                                "namespace": task.inflight.activation.namespace,
-                                "taskname": task.inflight.activation.taskname,
+                                "namespace": oldest.inflight.activation.namespace,
+                                "taskname": oldest.inflight.activation.taskname,
                             },
                         )
-                # How long has the oldest pending task been sitting in the queue
-                if oldest := get_oldest_pending_activation():
-                    metrics.distribution(
-                        "taskworker.worker.oldest_pending_activation_age",
-                        time.time() - oldest.futures_start_time,
-                        tags={
-                            "processing_pool": processing_pool_name,
-                            "namespace": oldest.inflight.activation.namespace,
-                            "taskname": oldest.inflight.activation.taskname,
-                        },
-                    )
+                # Pause after every pass, not only when idle: otherwise the loop busy-spins
+                # and floods metrics while futures are pending; wait() wakes on local shutdown.
+                local_shutdown.wait(0.1)
+
+        _future_completion_thread = threading.Thread(
+            name="check-future-completion",
+            target=partial(check_task_future_completion, shutdown_event, local_shutdown),
+            daemon=True,
+        )
+        _future_completion_thread.start()
 
         while not shutdown_event.is_set() and not local_shutdown.is_set():
             if max_task_count and processed_task_count >= max_task_count:
@@ -341,10 +363,11 @@ def check_task_future_completion() -> None:
                 logger.info(
                     "taskworker.max_task_count_reached", extra={"count": processed_task_count}
                 )
+                # Still set the shutdown signal to trigger shutdown of the future checker thread
+                local_shutdown.set()
                 break
 
             try:
-                check_task_future_completion()
                 inflight = child_tasks.get(timeout=1.0)
             except queue.Empty:
                 metrics.incr(
@@ -530,6 +553,7 @@ def check_task_future_completion() -> None:
                 pending_task_futures.append(pending_task)
 
         # Once we get the shutdown signal, drain any pending futures
+        _future_completion_thread.join()
         for task in pending_task_futures.copy():
             await_task_futures(task)
 