From e7f9162585e786b82b86187e6a98e041638ae039 Mon Sep 17 00:00:00 2001 From: Robin Hahling Date: Wed, 1 Apr 2026 16:24:33 +0200 Subject: [PATCH 1/2] add WithResultCallback option to avoid unbounded result accumulation Task results have always accumulated in memory until Drain is called, requiring callers to drain periodically to avoid unbounded growth. This change adds WithResultCallback, an option that processes each result as it completes, eliminating accumulation entirely. This change is backwards compatible. New and NewWithContext accept variadic Option arguments, so existing callers require no changes. Signed-off-by: Robin Hahling --- README.md | 64 +++++++++++++++++++++++-- example_test.go | 30 ++++++++++++ task.go | 26 +++++++--- workerpool.go | 67 +++++++++++++++++++++----- workerpool_test.go | 116 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 280 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index b073964..aa11dab 100644 --- a/README.md +++ b/README.md @@ -17,13 +17,15 @@ behavior. One caveat is that while the number of concurrently running workers is limited, task results are not and they accumulate until they are collected. Therefore, if a large number of tasks can be expected, the workerpool should be -periodically drained (e.g. every 10k tasks). +periodically drained (e.g. every 10k tasks). Alternatively, +`WithResultCallback` can be used to process results as they complete, avoiding +accumulation entirely. This package is mostly useful when tasks are CPU bound and spawning too many routines would be detrimental to performance. It features a straightforward API -and no external dependencies. See the section below for a usage example. +and no external dependencies. See the sections below for usage examples. -## Example +## Example with Drain ```go package main @@ -63,8 +65,9 @@ func main() { } return nil }) - // Submit fails when the pool is closed (ErrClosed) or being drained - // (ErrDrained). 
Check for the error when appropriate. + // Submit fails when the pool is closed (ErrClosed), being drained + // (ErrDraining), or the parent context is done (context.Canceled). + // Check for the error when appropriate. if err != nil { fmt.Fprintln(os.Stderr, err) return @@ -93,3 +96,54 @@ func main() { } } ``` + +## Example with result callback + +Use `WithResultCallback` to process each result as it completes rather than +accumulating them for a later `Drain` call. The callback receives a `Result`, +which extends `Task` with a `Duration()` method reporting how long the task +took to execute. This is useful for logging, metrics, or long-running pools +where unbounded result accumulation is undesirable. + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "runtime" + + "github.com/cilium/workerpool" +) + +func main() { + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) { + if err := r.Err(); err != nil { + fmt.Fprintf(os.Stderr, "task %s failed after %s: %v\n", r, r.Duration(), err) + } else { + fmt.Printf("task %s completed in %s\n", r, r.Duration()) + } + })) + + for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + id := fmt.Sprintf("task #%d", i) + err := wp.Submit(id, func(_ context.Context) error { + if IsPrime(n) { + fmt.Println(n, "is prime!") + } + return nil + }) + if err != nil { + log.Fatal(err) + } + } + + // Close waits for all in-flight tasks to complete before returning, + // ensuring all callback invocations have finished. 
+ if err := wp.Close(); err != nil { + log.Fatal(err) + } +} +``` diff --git a/example_test.go b/example_test.go index 2504f34..e04c4ac 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,7 @@ package workerpool_test import ( "context" "fmt" + "log" "os" "runtime" @@ -67,3 +68,32 @@ func Example() { fmt.Fprintln(os.Stderr, err) } } + +func ExampleWithResultCallback() { + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) { + if err := r.Err(); err != nil { + fmt.Fprintf(os.Stderr, "task %s failed after %s: %v\n", r, r.Duration(), err) + } else { + fmt.Printf("task %s completed in %s\n", r, r.Duration()) + } + })) + + for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + id := fmt.Sprintf("task #%d", i) + err := wp.Submit(id, func(_ context.Context) error { + if IsPrime(n) { + fmt.Println(n, "is prime!") + } + return nil + }) + if err != nil { + log.Fatal(err) + } + } + + // Close waits for all in-flight tasks to complete before returning, + // ensuring all callback invocations have finished. + if err := wp.Close(); err != nil { + log.Fatal(err) + } +} diff --git a/task.go b/task.go index 4aa3eaf..c7a3207 100644 --- a/task.go +++ b/task.go @@ -6,6 +6,7 @@ package workerpool import ( "context" "fmt" + "time" ) // Task is a unit of work. @@ -17,26 +18,39 @@ type Task interface { Err() error } +// Result is a completed Task that also reports its execution duration. +// It is passed to the callback registered with WithResultCallback. +type Result interface { + Task + // Duration returns the time taken to execute the task. + Duration() time.Duration +} + type task struct { run func(context.Context) error id string } type taskResult struct { - err error - id string + err error + id string + duration time.Duration } -// Ensure that taskResult implements the Task interface. -var _ Task = &taskResult{} +// Ensure that taskResult implements the Result interface. 
+var _ Result = &taskResult{} // String implements fmt.Stringer for taskResult. func (t *taskResult) String() string { return t.id } -// Err returns the error resulting from processing the taskResult. It ensures -// that the taskResult struct implements the Task interface. +// Err returns the error resulting from processing the taskResult. func (t *taskResult) Err() error { return t.err } + +// Duration returns the time taken to execute the task. +func (t *taskResult) Duration() time.Duration { + return t.duration +} diff --git a/workerpool.go b/workerpool.go index 457fa9a..91080b1 100644 --- a/workerpool.go +++ b/workerpool.go @@ -15,6 +15,8 @@ // limited, task results are not and they accumulate until they are collected. // Therefore, if a large number of tasks can be expected, the workerpool should // be periodically drained (e.g. every 10k tasks). +// Alternatively, use WithResultCallback to process results as they complete +// without accumulation. package workerpool import ( @@ -22,6 +24,7 @@ import ( "errors" "fmt" "sync" + "time" ) var ( @@ -30,18 +33,41 @@ var ( ErrDraining = errors.New("drain operation in progress") // ErrClosed is returned when operations are attempted after a call to Close. ErrClosed = errors.New("worker pool is closed") + // ErrCallbackSet is returned by Drain when a result callback has been + // registered via WithResultCallback. + ErrCallbackSet = errors.New("a result callback is set") ) +// Option configures a WorkerPool. +type Option func(*WorkerPool) + +// WithResultCallback registers fn to be called each time a task completes. +// When a callback is set, results are not accumulated internally and Drain +// returns ErrCallbackSet. The callback may be invoked concurrently from +// multiple goroutines; fn must be safe for concurrent use. +// WithResultCallback panics if fn is nil. 
+func WithResultCallback(fn func(Result)) Option { + // TODO(v2): New/NewWithContext should return an error so that option + // validation can propagate errors instead of panicking. + if fn == nil { + panic("workerpool.WithResultCallback: fn must not be nil") + } + return func(wp *WorkerPool) { + wp.onResult = fn + } +} + // WorkerPool spawns, on demand, a number of worker routines to process // submitted tasks concurrently. The number of concurrent routines never // exceeds the specified limit. type WorkerPool struct { - workers chan struct{} - tasks chan *task - done <-chan struct{} - cancel context.CancelFunc - results []Task - wg sync.WaitGroup + workers chan struct{} + tasks chan *task + done <-chan struct{} + cancel context.CancelFunc + onResult func(Result) + results []Task + wg sync.WaitGroup mu sync.Mutex draining bool @@ -50,13 +76,13 @@ type WorkerPool struct { // New creates a new pool of workers where at most n workers process submitted // tasks concurrently. New panics if n ≤ 0. -func New(n int) *WorkerPool { - return NewWithContext(context.Background(), n) +func New(n int, opts ...Option) *WorkerPool { + return NewWithContext(context.Background(), n, opts...) } // NewWithContext creates a new pool of workers where at most n workers process submitted // tasks concurrently. New panics if n ≤ 0. The context is used as the parent context to the context of the task func passed to Submit. 
+func NewWithContext(ctx context.Context, n int, opts ...Option) *WorkerPool { if n <= 0 { panic(fmt.Sprintf("workerpool.New: n must be > 0, got %d", n)) } @@ -67,6 +93,9 @@ func NewWithContext(ctx context.Context, n int) *WorkerPool { ctx, cancel := context.WithCancel(ctx) wp.cancel = cancel wp.done = ctx.Done() + for _, opt := range opts { + opt(wp) + } go wp.run(ctx) return wp } @@ -124,6 +153,7 @@ func (wp *WorkerPool) Submit(id string, f func(ctx context.Context) error) error // tasks that have been processed. // If a drain operation is already in progress, ErrDraining is returned. // If the worker pool is closed, ErrClosed is returned. +// If a result callback is set via WithResultCallback, ErrCallbackSet is returned. func (wp *WorkerPool) Drain() ([]Task, error) { wp.mu.Lock() if wp.closed { @@ -134,6 +164,11 @@ func (wp *WorkerPool) Drain() ([]Task, error) { wp.mu.Unlock() return nil, ErrDraining } + // TODO(v2): remove ErrCallbackSet — a pool configured with WithResultCallback should not expose Drain. + if wp.onResult != nil { + wp.mu.Unlock() + return nil, ErrCallbackSet + } wp.draining = true wp.mu.Unlock() @@ -154,7 +189,9 @@ func (wp *WorkerPool) Drain() ([]Task, error) { // Close closes the worker pool, rendering it unable to process new tasks. // Close sends the cancellation signal to any running task and waits for all -// workers, if any, to return. +// workers, if any, to return. When a result callback is set via +// WithResultCallback, all callback invocations are guaranteed to have completed +// before Close returns. // Close will return ErrClosed if it has already been called. func (wp *WorkerPool) Close() error { wp.mu.Lock() @@ -181,15 +218,22 @@ func (wp *WorkerPool) Close() error { // only be called once during the lifetime of a WorkerPool. 
func (wp *WorkerPool) run(ctx context.Context) { for t := range wp.tasks { result := taskResult{id: t.id} - wp.results = append(wp.results, &result) + if wp.onResult == nil { + wp.results = append(wp.results, &result) + } wp.workers <- struct{}{} go func() { defer wp.wg.Done() + start := time.Now() if t.run != nil { result.err = t.run(ctx) } + result.duration = time.Since(start) + if wp.onResult != nil { + wp.onResult(&result) + } <-wp.workers }() } diff --git a/workerpool_test.go b/workerpool_test.go index 66cbe5b..085f32c 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -15,6 +15,8 @@ import ( "github.com/cilium/workerpool" ) +var errTask = errors.New("task error") + func TestWorkerPoolNewPanics(t *testing.T) { // helper expecting New(n) to panic. testWorkerPoolNewPanics := func(n int) { @@ -30,6 +32,15 @@ func TestWorkerPoolNewPanics(t *testing.T) { testWorkerPoolNewPanics(-1) } +func TestWithResultCallbackNilPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("WithResultCallback(nil) should panic()") + } + }() + workerpool.WithResultCallback(nil) +} + func TestWorkerPoolTasksCapacity(t *testing.T) { wp := workerpool.New(runtime.NumCPU()) defer func() { @@ -339,6 +350,21 @@ func TestWorkerPoolDrainAfterClose(t *testing.T) { } } +func TestWorkerPoolDrainAfterCloseWithCallback(t *testing.T) { + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(workerpool.Result) {})) + if err := wp.Close(); err != nil { + t.Fatalf("close: got '%v', want no error", err) + } + // ErrClosed must take precedence over ErrCallbackSet. 
+ tasks, err := wp.Drain() + if !errors.Is(err, workerpool.ErrClosed) { + t.Errorf("got %v; want %v", err, workerpool.ErrClosed) + } + if tasks != nil { + t.Errorf("got %v as tasks; want %v", tasks, nil) + } +} + func TestWorkerPoolSubmitNil(t *testing.T) { wp := workerpool.New(runtime.NumCPU()) defer func() { @@ -367,6 +393,32 @@ func TestWorkerPoolSubmitNil(t *testing.T) { } +func TestWorkerPoolSubmitNilWithCallback(t *testing.T) { + id := "nothing" + var got workerpool.Result + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) { + got = r + })) + if err := wp.Submit(id, nil); err != nil { + t.Fatalf("got %v; want no error", err) + } + if err := wp.Close(); err != nil { + t.Fatalf("close: got '%v', want no error", err) + } + if got == nil { + t.Fatal("callback was not invoked") + } + if s := got.String(); s != id { + t.Errorf("String: got '%s', want '%s'", s, id) + } + if err := got.Err(); err != nil { + t.Errorf("Err: got '%v', want no error", err) + } + if got.Duration() < 0 { + t.Errorf("Duration: got %v, want >= 0", got.Duration()) + } +} + func TestWorkerPoolSubmitAfterClose(t *testing.T) { wp := workerpool.New(runtime.NumCPU()) if err := wp.Close(); err != nil { @@ -512,3 +564,67 @@ func TestWorkerPoolNewWithCancelledContext(t *testing.T) { t.Errorf("drain: got %d results, want 0", len(results)) } } + +func TestWorkerPoolWithResultCallback(t *testing.T) { + n := runtime.NumCPU() + + var mu sync.Mutex + var got []workerpool.Result + + wp := workerpool.New(n, workerpool.WithResultCallback(func(r workerpool.Result) { + mu.Lock() + defer mu.Unlock() + got = append(got, r) + })) + + numTasks := n + 2 + wantErr := errTask + for i := range numTasks { + id := fmt.Sprintf("task #%2d", i) + var f func(context.Context) error + if i == 0 { + f = func(_ context.Context) error { return wantErr } + } else { + f = func(_ context.Context) error { return nil } + } + if err := wp.Submit(id, f); err != nil { + t.Fatalf("failed to 
submit task '%s': %v", id, err) + } + } + + // Drain must return ErrCallbackSet. + tasks, err := wp.Drain() + if !errors.Is(err, workerpool.ErrCallbackSet) { + t.Errorf("drain: got %v, want %v", err, workerpool.ErrCallbackSet) + } + if tasks != nil { + t.Errorf("drain: got %v, want nil", tasks) + } + + // Close waits for all in-flight tasks, so after it returns all callbacks + // have been invoked. + if err := wp.Close(); err != nil { + t.Fatalf("close: got '%v', want no error", err) + } + + mu.Lock() + defer mu.Unlock() + + if len(got) != numTasks { + t.Fatalf("callback: got %d results, want %d", len(got), numTasks) + } + for _, r := range got { + if r.Duration() < 0 { + t.Errorf("%s: Duration: got %v, want >= 0", r, r.Duration()) + } + if r.String() == "task # 0" { + if !errors.Is(r.Err(), wantErr) { + t.Errorf("%s: Err: got %v, want %v", r, r.Err(), wantErr) + } + } else { + if r.Err() != nil { + t.Errorf("%s: Err: got %v, want nil", r, r.Err()) + } + } + } +} From f7b8f74604087b1e5b402c02af940060c079bb4f Mon Sep 17 00:00:00 2001 From: Robin Hahling Date: Wed, 1 Apr 2026 19:37:47 +0200 Subject: [PATCH 2/2] doc: make the for loop condition in the examples more readable Suggested-by: Alexandre Perrin Signed-off-by: Robin Hahling --- README.md | 4 ++-- example_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index aa11dab..58b1694 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ func IsPrime(n int64) bool { func main() { wp := workerpool.New(runtime.NumCPU()) - for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { id := fmt.Sprintf("task #%d", i) // Use Submit to submit tasks for processing. Submit blocks when no // worker is available to pick up the task. 
@@ -127,7 +127,7 @@ func main() { } })) - for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { id := fmt.Sprintf("task #%d", i) err := wp.Submit(id, func(_ context.Context) error { if IsPrime(n) { diff --git a/example_test.go b/example_test.go index e04c4ac..d33b26c 100644 --- a/example_test.go +++ b/example_test.go @@ -28,7 +28,7 @@ func IsPrime(n int64) bool { func Example() { wp := workerpool.New(runtime.NumCPU()) - for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { id := fmt.Sprintf("task #%d", i) // Use Submit to submit tasks for processing. Submit blocks when no // worker is available to pick up the task. @@ -78,7 +78,7 @@ func ExampleWithResultCallback() { } })) - for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { id := fmt.Sprintf("task #%d", i) err := wp.Submit(id, func(_ context.Context) error { if IsPrime(n) {