Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions router/pkg/pubsub/redis/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,22 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri
sub := p.conn.PSubscribe(ctx, subConf.Channels...)
msgChan := sub.Channel()

cleanup := func() {
err := sub.PUnsubscribe(ctx, subConf.Channels...)
if err != nil {
log.Error(fmt.Sprintf("error unsubscribing from redis for topics %v", subConf.Channels), zap.Error(err))
}
}

p.closeWg.Add(1)

go func() {
defer p.closeWg.Done()

// Always release the dedicated pub/sub connection on every exit path,
// otherwise each subscription leaks a connection. Closing the PubSub
// also unsubscribes from all channels server-side, so an explicit
// PUnsubscribe is unnecessary (and would fail anyway once the context
// driving teardown is cancelled).
defer func() {
if err := sub.Close(); err != nil {
log.Error(fmt.Sprintf("error closing redis subscription for topics %v", subConf.Channels), zap.Error(err))
}
}()

for {
select {
case msg, ok := <-msgChan:
Expand All @@ -137,12 +141,10 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri
case <-p.ctx.Done():
// When the application context is done, we stop the subscription if it is not already done
log.Debug("application context done, stopping subscription")
cleanup()
return
case <-ctx.Done():
// When the subscription context is done, we stop the subscription if it is not already done
log.Debug("subscription context done, stopping subscription")
cleanup()
return
}
}
Expand Down
70 changes: 70 additions & 0 deletions router/pkg/pubsub/redis/adapter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package redis

import (
"context"
"fmt"
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource"
"go.uber.org/zap/zaptest"
)

// noopUpdater satisfies datasource.SubscriptionEventUpdater for tests.
type noopUpdater struct{}

func (noopUpdater) Update(events []datasource.StreamEvent) {}
func (noopUpdater) Complete() {}
func (noopUpdater) Done() {}
func (noopUpdater) SetHooks(hooks datasource.Hooks) {}

func newTestAdapter(t *testing.T) (*ProviderAdapter, *miniredis.Miniredis) {
t.Helper()

mr := miniredis.RunT(t)

adapter := NewProviderAdapter(
context.Background(),
zaptest.NewLogger(t),
[]string{fmt.Sprintf("redis://%s", mr.Addr())},
false,
datasource.ProviderOpts{},
)
require.NoError(t, adapter.Startup(context.Background()))
t.Cleanup(func() { _ = adapter.Shutdown(context.Background()) })

return adapter.(*ProviderAdapter), mr
}

// TestProviderAdapter_Subscribe_ReleasesConnectionOnCancel guards against the
// connection leak where each subscription's dedicated pub/sub connection was
// never closed (only PUnsubscribe was called, with an already-cancelled
// context). After every subscription context is cancelled, the pool's total
// connection count must return to its pre-subscribe baseline.
func TestProviderAdapter_Subscribe_ReleasesConnectionOnCancel(t *testing.T) {
p, _ := newTestAdapter(t)

baseline := p.conn.PoolStats().TotalConns

const subscriptions = 10
for i := 0; i < subscriptions; i++ {
subCtx, cancel := context.WithCancel(context.Background())
err := p.Subscribe(subCtx, &SubscriptionEventConfiguration{
Provider: "test-provider",
Channels: []string{fmt.Sprintf("channel-%d", i)},
}, noopUpdater{})
require.NoError(t, err)

// Cancelling the subscription context must tear the subscription down
// and release its dedicated connection.
cancel()
}

require.Eventually(t, func() bool {
return p.conn.PoolStats().TotalConns <= baseline
}, 5*time.Second, 10*time.Millisecond,
"redis pub/sub connections leaked: have %d, baseline %d",
p.conn.PoolStats().TotalConns, baseline)
}
Loading