From a2cecd6f8e50ae8028bdc7db8217f38ad37e6a4d Mon Sep 17 00:00:00 2001 From: Alessandro Pagnin Date: Thu, 4 Jun 2026 13:44:20 +0200 Subject: [PATCH 1/5] feat: allow dynamically changing subscription configuration from custom hook --- .../modules/start_subscription_change_test.go | 119 ++++++++++++++++++ router/core/subscriptions_modules.go | 23 +++- router/core/subscriptions_modules_test.go | 58 +++++++++ router/pkg/pubsub/datasource/hooks.go | 2 +- .../datasource/subscription_datasource.go | 73 ++++++++++- .../subscription_datasource_test.go | 104 ++++++++++++--- 6 files changed, 356 insertions(+), 23 deletions(-) create mode 100644 router-tests/modules/start_subscription_change_test.go create mode 100644 router/core/subscriptions_modules_test.go diff --git a/router-tests/modules/start_subscription_change_test.go b/router-tests/modules/start_subscription_change_test.go new file mode 100644 index 0000000000..3d64fbd56a --- /dev/null +++ b/router-tests/modules/start_subscription_change_test.go @@ -0,0 +1,119 @@ +package module_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + start_subscription "github.com/wundergraph/cosmo/router-tests/modules/start-subscription" + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/redis" +) + +type subscriptionArgs struct { + dataValue []byte + errValue error +} + +func TestStartSubscriptionChangeHook(t *testing.T) { + t.Run("Test StartSubscription hook can change channel", func(t *testing.T) { + t.Parallel() + logicalChannel := "customRedisChannel" + newChannel := "" + + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + redisCfg, ok := ctx.SubscriptionEventConfiguration().(*redis.SubscriptionEventConfiguration) + if ok { + redisCfg.Channels = []string{newChannel} + ctx.SetSubscriptionEventConfiguration(redisCfg) + } + return nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate, + EnableRedis: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + newChannel = xEnv.GetPubSubName(logicalChannel) + + var subscriptionOne struct { + employeeUpdatedMyRedis struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyRedis(id: 2)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan subscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- subscriptionArgs{dataValue, errValue} + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForTriggerCount(1, time.Second*10) + + // produce a message (retry until subscription pipeline is confirmed active) + xEnv.RedisPublishUntilReceived(newChannel, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`, 10*time.Second) + + // The SubscriptionOnStart hook may be called asynchronously after + // WaitForSubscriptionCount returns, so poll until it fires. + require.Eventually(t, func() bool { + return customModule.HookCallCount.Load() >= 1 + }, time.Second*10, time.Millisecond*50) + + // process the message + select { + case subscriptionArgs := <-subscriptionArgsCh: + require.NoError(t, subscriptionArgs.errValue) + require.JSONEq(t, `{"employeeUpdatedMyRedis":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(subscriptionArgs.dataValue)) + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for first message error") + } + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + }) + }) +} diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index c01d1fa348..c0915427a5 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -23,6 +23,9 @@ type SubscriptionOnStartHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // SetSubscriptionEventConfiguration replaces the event configuration used to start + // a pub/sub subscription. It returns false for engine subscriptions or nil configs. + SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool // EmitEvent sends an event directly to the subscription stream of the // currently connected client. // @@ -115,6 +118,14 @@ func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() return c.subscriptionEventConfiguration } +func (c *pubSubSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool { + if config == nil { + return false + } + c.subscriptionEventConfiguration = config + return true +} + func (c *pubSubSubscriptionOnStartHookContext) EmitEvent(event datasource.StreamEvent) bool { c.emitEventFn(event.GetData()) @@ -199,6 +210,10 @@ func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() return nil } +func (c *engineSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool { + return false +} + type SubscriptionOnStartHandler interface { // SubscriptionOnStart is called once at subscription start // The error is propagated to the client. @@ -211,7 +226,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont return nil } - return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) error { + return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) (datasource.SubscriptionEventConfiguration, error) { requestContext := getRequestContext(resolveCtx.Context) logger := requestContext.Logger() @@ -236,7 +251,11 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont eventBuilder: eventBuilder, } - return fn(hookCtx) + if err := fn(hookCtx); err != nil { + return nil, err + } + + return hookCtx.subscriptionEventConfiguration, nil } } diff --git a/router/core/subscriptions_modules_test.go b/router/core/subscriptions_modules_test.go new file mode 100644 index 0000000000..da3014c2f3 --- /dev/null +++ b/router/core/subscriptions_modules_test.go @@ -0,0 +1,58 @@ +package core + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + rcontext "github.com/wundergraph/cosmo/router/internal/context" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type testSubscriptionEventConfig struct { + providerID string + fieldName string +} + +func (c *testSubscriptionEventConfig) ProviderID() string { + return c.providerID +} + +func (c *testSubscriptionEventConfig) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeRedis +} + +func (c *testSubscriptionEventConfig) RootFieldName() string { + return c.fieldName +} + +func TestNewPubSubSubscriptionOnStartHookReturnsUpdatedSubscriptionEventConfiguration(t *testing.T) { + originalConfig := &testSubscriptionEventConfig{ + providerID: "provider", + fieldName: "original", + } + updatedConfig := &testSubscriptionEventConfig{ + providerID: "provider", + fieldName: "updated", + } + + hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error { + require.Same(t, originalConfig, ctx.SubscriptionEventConfiguration()) + require.True(t, ctx.SetSubscriptionEventConfiguration(updatedConfig)) + return nil + }) + + req := httptest.NewRequest("GET", "/graphql", nil) + reqCtx := buildRequestContext(requestContextOptions{r: req}) + ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, reqCtx) + + actualConfig, err := hook(resolve.StartupHookContext{ + Context: ctx, + Updater: func(data []byte) {}, + }, originalConfig, nil) + + require.NoError(t, err) + require.Same(t, updatedConfig, actualConfig) +} diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go index a262058463..545e97e8fe 100644 --- a/router/pkg/pubsub/datasource/hooks.go +++ b/router/pkg/pubsub/datasource/hooks.go @@ -7,7 +7,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index 939c03f94d..3a108aef4e 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -1,12 +1,15 @@ package datasource import ( + "context" "encoding/json" "errors" "fmt" "net/http" + "strconv" "github.com/cespare/xxhash/v2" + rcontext "github.com/wundergraph/cosmo/router/internal/context" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -16,6 +19,13 @@ type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Dige type EventBuilderFn func(data []byte) MutableStreamEvent +const subscriptionEventConfigurationContextKeyPrefix = "wg.cosmo.pubsub.subscription_event_configuration." + +type subscriptionEventConfigurationStore interface { + Set(key string, value any) + Get(key string) (value any, exists bool) +} + // PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism. // It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { @@ -32,8 +42,15 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input [ return subscriptionConfiguration, err } +func (s *PubSubSubscriptionDataSource[C]) subscriptionEventConfiguration(ctx context.Context, input []byte) (SubscriptionEventConfiguration, error) { + if conf, ok := subscriptionEventConfigurationFromContext(ctx, input); ok { + return conf, nil + } + return s.SubscriptionEventConfiguration(input) +} + func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, header http.Header, input []byte, updater resolve.SubscriptionUpdater) error { - subConf, err := s.SubscriptionEventConfiguration(input) + subConf, err := s.subscriptionEventConfiguration(ctx.Context(), input) if err != nil { return err } @@ -70,17 +87,26 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu } }() + if len(s.hooks.SubscriptionOnStart.Handlers) == 0 { + return nil + } + + conf, err := s.SubscriptionEventConfiguration(input) + if err != nil { + return err + } + for _, fn := range s.hooks.SubscriptionOnStart.Handlers { - conf, err := s.SubscriptionEventConfiguration(input) + conf, err = fn(ctx, conf, s.eventBuilder) if err != nil { return err } - err = fn(ctx, conf, s.eventBuilder) - if err != nil { - return err + if conf == nil { + return errors.New("invalid subscription configuration") } } + setSubscriptionEventConfiguration(ctx.Context, input, conf) return nil } @@ -102,3 +128,40 @@ func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Ad eventBuilder: eventBuilder, } } + +func subscriptionEventConfigurationContextKey(input []byte) string { + return subscriptionEventConfigurationContextKeyPrefix + + strconv.Itoa(len(input)) + ":" + + strconv.FormatUint(xxhash.Sum64(input), 16) +} + +func requestContextStore(ctx context.Context) subscriptionEventConfigurationStore { + if ctx == nil { + return nil + } + store, _ := ctx.Value(rcontext.RequestContextKey).(subscriptionEventConfigurationStore) + return store +} + +func setSubscriptionEventConfiguration(ctx context.Context, input []byte, conf SubscriptionEventConfiguration) { + store := requestContextStore(ctx) + if store == nil { + return + } + store.Set(subscriptionEventConfigurationContextKey(input), conf) +} + +func subscriptionEventConfigurationFromContext(ctx context.Context, input []byte) (SubscriptionEventConfiguration, bool) { + store := requestContextStore(ctx) + if store == nil { + return nil, false + } + + value, ok := store.Get(subscriptionEventConfigurationContextKey(input)) + if !ok || value == nil { + return nil, false + } + + conf, ok := value.(SubscriptionEventConfiguration) + return conf, ok +} diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index 2662a2a2ca..141d611276 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "errors" + "sync" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + rcontext "github.com/wundergraph/cosmo/router/internal/context" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -32,6 +34,27 @@ func (t testSubscriptionEventConfiguration) RootFieldName() string { return "testSubscription" } +type testRequestContextStore struct { + mu sync.RWMutex + values map[string]any +} + +func (s *testRequestContextStore) Set(key string, value any) { + s.mu.Lock() + defer s.mu.Unlock() + if s.values == nil { + s.values = make(map[string]any) + } + s.values[key] = value +} + +func (s *testRequestContextStore) Get(key string) (value any, exists bool) { + s.mu.RLock() + defer s.mu.RUnlock() + value, exists = s.values[key] + return value, exists +} + // testSubscriptionDataSourceEventBuilder is a reusable event builder for tests func testSubscriptionDataSourceEventBuilder(data []byte) MutableStreamEvent { return mutableTestEvent(data) @@ -101,6 +124,57 @@ func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { mockAdapter.AssertExpectations(t) } +func TestPubSubSubscriptionDataSource_Start_UsesSubscriptionOnStartConfiguration(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[*testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) + + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + typedConfig, ok := config.(*testSubscriptionEventConfiguration) + require.True(t, ok) + typedConfig.Topic = "changed-topic" + return typedConfig, nil + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook}, + }, + }) + + testConfig := &testSubscriptionEventConfiguration{ + Topic: "original-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + require.NoError(t, err) + + requestStore := &testRequestContextStore{} + requestCtx := context.WithValue(context.Background(), rcontext.RequestContextKey, requestStore) + + err = dataSource.SubscriptionOnStart(resolve.StartupHookContext{ + Context: requestCtx, + Updater: func(data []byte) {}, + }, input) + require.NoError(t, err) + + ctx := resolve.NewContext(requestCtx) + mockUpdater := NewMockSubscriptionUpdater(t) + expectedConfig := &testSubscriptionEventConfiguration{ + Topic: "changed-topic", + Subject: "test-subject", + } + + mockAdapter.On("Subscribe", ctx.Context(), expectedConfig, mock.AnythingOfType("*datasource.subscriptionEventUpdater")).Return(nil) + + err = dataSource.Start(ctx, nil, input, mockUpdater) + require.NoError(t, err) + mockAdapter.AssertExpectations(t) +} + func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { mockAdapter := NewMockProvider(t) uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { @@ -183,20 +257,20 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T hook1EventBuilderExists := false hook2EventBuilderExists := false - hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { hook1Called = true if eventBuilder != nil { hook1EventBuilderExists = true } - return nil + return config, nil } - hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { hook2Called = true if eventBuilder != nil { hook2EventBuilderExists = true } - return nil + return config, nil } dataSource.SetHooks(Hooks{ @@ -234,8 +308,8 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *te dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Add hook that returns close=true - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return nil + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return config, nil } dataSource.SetHooks(Hooks{ @@ -270,8 +344,8 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *te expectedError := errors.New("hook error") // Add hook that returns an error - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return expectedError + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return nil, expectedError } dataSource.SetHooks(Hooks{ @@ -309,11 +383,11 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { assert.Len(t, dataSource.hooks.SubscriptionOnStart.Handlers, 0) // Add hooks - hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return nil + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return config, nil } - hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return nil + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return config, nil } dataSource.SetHooks(Hooks{ @@ -369,9 +443,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_InvalidEventConfigInpu dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) hookCalled := false - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { hookCalled = true - return nil + return config, nil } dataSource.SetHooks(Hooks{ @@ -430,7 +504,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testi dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Add hook that panics - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { panic(tt.panicValue) } From c6033b698aad30a20777ae90eeb1182e39f47091 Mon Sep 17 00:00:00 2001 From: Alessandro Pagnin Date: Thu, 4 Jun 2026 13:47:54 +0200 Subject: [PATCH 2/5] fix: redis publish was double prefixing --- router-tests/modules/start_subscription_change_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router-tests/modules/start_subscription_change_test.go b/router-tests/modules/start_subscription_change_test.go index 3d64fbd56a..e5acf220eb 100644 --- a/router-tests/modules/start_subscription_change_test.go +++ b/router-tests/modules/start_subscription_change_test.go @@ -91,7 +91,7 @@ func TestStartSubscriptionChangeHook(t *testing.T) { xEnv.WaitForTriggerCount(1, time.Second*10) // produce a message (retry until subscription pipeline is confirmed active) - xEnv.RedisPublishUntilReceived(newChannel, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`, 10*time.Second) + xEnv.RedisPublishUntilReceived(logicalChannel, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`, 10*time.Second) // The SubscriptionOnStart hook may be called asynchronously after // WaitForSubscriptionCount returns, so poll until it fires. From eef97292b042a7fd27e2e9f78fa9202eef10a08f Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 4 Jun 2026 17:05:17 +0200 Subject: [PATCH 3/5] fix: Return event config clone --- router/cmd/custom/module/module.go | 121 ++++-------------- router/core/subscriptions_modules.go | 11 +- router/core/subscriptions_modules_test.go | 45 ++++++- router/pkg/pubsub/datasource/provider.go | 4 + .../pubsub/datasource/pubsubprovider_test.go | 5 + .../subscription_datasource_test.go | 4 + .../subscription_event_updater_test.go | 5 + router/pkg/pubsub/kafka/engine_datasource.go | 7 + router/pkg/pubsub/nats/engine_datasource.go | 11 ++ router/pkg/pubsub/redis/engine_datasource.go | 7 + 10 files changed, 121 insertions(+), 99 deletions(-) diff --git a/router/cmd/custom/module/module.go b/router/cmd/custom/module/module.go index b7ee8e6d45..788ddc7449 100644 --- a/router/cmd/custom/module/module.go +++ b/router/cmd/custom/module/module.go @@ -1,119 +1,46 @@ package module import ( - "fmt" - "net/http" - "github.com/wundergraph/cosmo/router/core" - "go.uber.org/zap" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/redis" ) -const myModuleID = "myModule" - -// MyModule is a simple module that has access to the GraphQL operation and add a header to the response -// It demonstrates how to use the different handlers to customize the router. -// It also shows how to use the config file to configure and validate your module config. -// By default, the config file is located at `config.yaml` in the working directory of the router. -type MyModule struct { - // Properties that are set by the config file are automatically populated based on the `mapstructure` tag - // Create a new section under `modules.` in the config file with the same name as your module. - // Don't forget in Go the first letter of a property must be uppercase to be exported - - Value uint64 `mapstructure:"value"` - - Logger *zap.Logger +func init() { + core.RegisterModule(&CosmoStreamsModule{}) } -func (m *MyModule) Provision(ctx *core.ModuleContext) error { - // Provision your module here, validate config etc. +type CosmoStreamsModule struct{} - if m.Value == 0 { - ctx.Logger.Error("Value must be greater than 0") - return fmt.Errorf("value must be greater than 0") +func (m *CosmoStreamsModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: "cosmoStreamsModule", + Priority: 1, + New: func() core.Module { + return &CosmoStreamsModule{} + }, } - - // Assign the logger to the module for non-request related logging - m.Logger = ctx.Logger - - return nil -} - -func (m *MyModule) Cleanup() error { - // Shutdown your module here, close connections etc. - - return nil -} - -func (m *MyModule) OnOriginResponse(response *http.Response, ctx core.RequestContext) *http.Response { - // Return a new response or nil if you want to pass it to the next handler - // If you want to modify the response, return a new response - - // Access the custom value set in OnOriginRequest - value := ctx.GetString("myValue") - - fmt.Println("SharedValue", value) - - fmt.Println("OnOriginResponse", response.Request.URL, response.StatusCode) - - return nil } -func (m *MyModule) OnOriginRequest(request *http.Request, ctx core.RequestContext) (*http.Request, *http.Response) { - // Return the modified request or nil if you want to pass it to the next handler - // Return a new response if you want to abort the request and return a custom response - - // Set a header on all origin requests - request.Header.Set("myHeader", ctx.GetString("myValue")) - - // Set a custom value on the request context. See OnOriginResponse - ctx.Set("myValue", "myValue") - - return request, nil +func (m *CosmoStreamsModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, nil } -func (m *MyModule) RouterOnRequest(ctx core.RequestContext, next http.Handler) { - logger := ctx.Logger() - logger.Info("Test RouterOnRequest custom module logs") - - next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) +func (m *CosmoStreamsModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, nil } -func (m *MyModule) Middleware(ctx core.RequestContext, next http.Handler) { - - operation := ctx.Operation() +func (m *CosmoStreamsModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContext) error { + cfg := ctx.SubscriptionEventConfiguration().(*redis.SubscriptionEventConfiguration) + cfg.Channels = []string{"test123"} + //ctx.SetSubscriptionEventConfiguration(cfg) - logger := ctx.Logger() - logger.Info("Test custom module logs") - // Access the GraphQL operation context - fmt.Println( - operation.Name(), - operation.Type(), - operation.Hash(), - operation.Content(), - ) - - // Call the next handler in the chain or return early by calling w.Write() - next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) -} - -func (m *MyModule) Module() core.ModuleInfo { - return core.ModuleInfo{ - // This is the ID of your module, it must be unique - ID: myModuleID, - // The priority of your module, lower numbers are executed first - Priority: 1, - New: func() core.Module { - return &MyModule{} - }, - } + return nil } // Interface guard var ( - _ core.RouterMiddlewareHandler = (*MyModule)(nil) - _ core.RouterOnRequestHandler = (*MyModule)(nil) - _ core.EnginePreOriginHandler = (*MyModule)(nil) - _ core.EnginePostOriginHandler = (*MyModule)(nil) - _ core.Provisioner = (*MyModule)(nil) - _ core.Cleaner = (*MyModule)(nil) + _ core.StreamPublishEventHandler = (*CosmoStreamsModule)(nil) + _ core.StreamReceiveEventHandler = (*CosmoStreamsModule)(nil) + _ core.SubscriptionOnStartHandler = (*CosmoStreamsModule)(nil) ) diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index c0915427a5..6916853f29 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -115,6 +115,11 @@ func (c *pubSubSubscriptionOnStartHookContext) Authentication() authentication.A } func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + if c.subscriptionEventConfiguration == nil { + return nil + } + // Return a deep copy so callers cannot mutate the live configuration in + // place. Changes are only applied when passed back via SetSubscriptionEventConfiguration. return c.subscriptionEventConfiguration } @@ -420,7 +425,11 @@ func (c *pubSubStreamReceiveEventHookContext) Authentication() authentication.Au } func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { - return c.subscriptionEventConfiguration + if c.subscriptionEventConfiguration == nil { + return nil + } + // Return a deep copy so callers cannot mutate the live configuration in place. + return c.subscriptionEventConfiguration.Clone() } func (c *pubSubStreamReceiveEventHookContext) NewEvent(data []byte) datasource.MutableStreamEvent { diff --git a/router/core/subscriptions_modules_test.go b/router/core/subscriptions_modules_test.go index da3014c2f3..56aeec55d3 100644 --- a/router/core/subscriptions_modules_test.go +++ b/router/core/subscriptions_modules_test.go @@ -3,6 +3,7 @@ package core import ( "context" "net/http/httptest" + "slices" "testing" "github.com/stretchr/testify/require" @@ -14,6 +15,7 @@ import ( type testSubscriptionEventConfig struct { providerID string fieldName string + channels []string } func (c *testSubscriptionEventConfig) ProviderID() string { @@ -28,6 +30,12 @@ func (c *testSubscriptionEventConfig) RootFieldName() string { return c.fieldName } +func (c *testSubscriptionEventConfig) Clone() datasource.SubscriptionEventConfiguration { + c2 := *c + c2.channels = slices.Clone(c.channels) + return &c2 +} + func TestNewPubSubSubscriptionOnStartHookReturnsUpdatedSubscriptionEventConfiguration(t *testing.T) { originalConfig := &testSubscriptionEventConfig{ providerID: "provider", @@ -39,7 +47,10 @@ func TestNewPubSubSubscriptionOnStartHookReturnsUpdatedSubscriptionEventConfigur } hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error { - require.Same(t, originalConfig, ctx.SubscriptionEventConfiguration()) + got := ctx.SubscriptionEventConfiguration() + // The getter returns a defensive copy: equal by value but not the same pointer. + require.NotSame(t, originalConfig, got) + require.Equal(t, originalConfig, got) require.True(t, ctx.SetSubscriptionEventConfiguration(updatedConfig)) return nil }) @@ -56,3 +67,35 @@ func TestNewPubSubSubscriptionOnStartHookReturnsUpdatedSubscriptionEventConfigur require.NoError(t, err) require.Same(t, updatedConfig, actualConfig) } + +func TestNewPubSubSubscriptionOnStartHookInPlaceMutationIsNoOp(t *testing.T) { + originalConfig := &testSubscriptionEventConfig{ + providerID: "provider", + fieldName: "original", + channels: []string{"original-channel"}, + } + + hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error { + // Mutating the returned config in place must not affect the live + // configuration: it is a defensive copy. Only SetSubscriptionEventConfiguration applies changes. + got := ctx.SubscriptionEventConfiguration().(*testSubscriptionEventConfig) + got.fieldName = "mutated" + got.channels[0] = "mutated-channel" + return nil + }) + + req := httptest.NewRequest("GET", "/graphql", nil) + reqCtx := buildRequestContext(requestContextOptions{r: req}) + ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, reqCtx) + + actualConfig, err := hook(resolve.StartupHookContext{ + Context: ctx, + Updater: func(data []byte) {}, + }, originalConfig, nil) + + require.NoError(t, err) + // Without SetSubscriptionEventConfiguration the original, unmodified config is returned. + require.Same(t, originalConfig, actualConfig) + require.Equal(t, "original", originalConfig.fieldName) + require.Equal(t, []string{"original-channel"}, originalConfig.channels) +} diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index fd02ffccf6..fd59847965 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -102,6 +102,10 @@ type SubscriptionEventConfiguration interface { ProviderID() string ProviderType() ProviderType RootFieldName() string // the root field name of the subscription in the schema + // Clone returns a deep copy of the configuration. It is used to hand out + // copies to module hooks so callers cannot mutate the live configuration in + // place; changes are only applied when passed back via SetSubscriptionEventConfiguration. + Clone() SubscriptionEventConfiguration } // PublishEventConfiguration is the interface that all publish event configurations must implement diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index b956ab38f0..ad0020fa17 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -58,6 +58,11 @@ func (c *testSubscriptionConfig) RootFieldName() string { return c.fieldName } +func (c *testSubscriptionConfig) Clone() SubscriptionEventConfiguration { + c2 := *c + return &c2 +} + type testPublishConfig struct { providerID string providerType ProviderType diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index 141d611276..8f80b486f2 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -34,6 +34,10 @@ func (t testSubscriptionEventConfiguration) RootFieldName() string { return "testSubscription" } +func (t testSubscriptionEventConfiguration) Clone() SubscriptionEventConfiguration { + return t +} + type testRequestContextStore struct { mu sync.RWMutex values map[string]any diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 411ab2bd98..92dce40538 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -34,6 +34,11 @@ func (c *testSubscriptionEventConfig) RootFieldName() string { return c.fieldName } +func (c *testSubscriptionEventConfig) Clone() SubscriptionEventConfiguration { + c2 := *c + return &c2 +} + type receivedHooksArgs struct { events []StreamEvent cfg SubscriptionEventConfiguration diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 6606b49015..b5646d9159 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -106,6 +106,13 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// Clone returns a deep copy of the configuration. +func (s *SubscriptionEventConfiguration) Clone() datasource.SubscriptionEventConfiguration { + c := *s + c.Topics = slices.Clone(s.Topics) + return &c +} + // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { Provider string `json:"providerId"` diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 6739c418ba..b03b3d6f2a 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -108,6 +108,17 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// Clone returns a deep copy of the configuration. +func (s *SubscriptionEventConfiguration) Clone() datasource.SubscriptionEventConfiguration { + c := *s + c.Subjects = slices.Clone(s.Subjects) + if s.StreamConfiguration != nil { + sc := *s.StreamConfiguration + c.StreamConfiguration = &sc + } + return &c +} + // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { Provider string `json:"providerId"` diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index 55eee83e20..bb171515f2 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -79,6 +79,13 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// Clone returns a deep copy of the configuration. +func (s *SubscriptionEventConfiguration) Clone() datasource.SubscriptionEventConfiguration { + c := *s + c.Channels = slices.Clone(s.Channels) + return &c +} + // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { From 195769752ccd5be42946f43c816eba2289d3c2e7 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 4 Jun 2026 17:14:47 +0200 Subject: [PATCH 4/5] chore: revert module.go --- router/cmd/custom/module/module.go | 121 +++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/router/cmd/custom/module/module.go b/router/cmd/custom/module/module.go index 788ddc7449..b7ee8e6d45 100644 --- a/router/cmd/custom/module/module.go +++ b/router/cmd/custom/module/module.go @@ -1,46 +1,119 @@ package module import ( + "fmt" + "net/http" + "github.com/wundergraph/cosmo/router/core" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/cosmo/router/pkg/pubsub/redis" + "go.uber.org/zap" ) -func init() { - core.RegisterModule(&CosmoStreamsModule{}) +const myModuleID = "myModule" + +// MyModule is a simple module that has access to the GraphQL operation and add a header to the response +// It demonstrates how to use the different handlers to customize the router. +// It also shows how to use the config file to configure and validate your module config. +// By default, the config file is located at `config.yaml` in the working directory of the router. +type MyModule struct { + // Properties that are set by the config file are automatically populated based on the `mapstructure` tag + // Create a new section under `modules.` in the config file with the same name as your module. + // Don't forget in Go the first letter of a property must be uppercase to be exported + + Value uint64 `mapstructure:"value"` + + Logger *zap.Logger } -type CosmoStreamsModule struct{} +func (m *MyModule) Provision(ctx *core.ModuleContext) error { + // Provision your module here, validate config etc. -func (m *CosmoStreamsModule) Module() core.ModuleInfo { - return core.ModuleInfo{ - ID: "cosmoStreamsModule", - Priority: 1, - New: func() core.Module { - return &CosmoStreamsModule{} - }, + if m.Value == 0 { + ctx.Logger.Error("Value must be greater than 0") + return fmt.Errorf("value must be greater than 0") } -} -func (m *CosmoStreamsModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, nil + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil } -func (m *CosmoStreamsModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, nil +func (m *MyModule) Cleanup() error { + // Shutdown your module here, close connections etc. + + return nil } -func (m *CosmoStreamsModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContext) error { - cfg := ctx.SubscriptionEventConfiguration().(*redis.SubscriptionEventConfiguration) - cfg.Channels = []string{"test123"} - //ctx.SetSubscriptionEventConfiguration(cfg) +func (m *MyModule) OnOriginResponse(response *http.Response, ctx core.RequestContext) *http.Response { + // Return a new response or nil if you want to pass it to the next handler + // If you want to modify the response, return a new response + + // Access the custom value set in OnOriginRequest + value := ctx.GetString("myValue") + + fmt.Println("SharedValue", value) + + fmt.Println("OnOriginResponse", response.Request.URL, response.StatusCode) return nil } +func (m *MyModule) OnOriginRequest(request *http.Request, ctx core.RequestContext) (*http.Request, *http.Response) { + // Return the modified request or nil if you want to pass it to the next handler + // Return a new response if you want to abort the request and return a custom response + + // Set a header on all origin requests + request.Header.Set("myHeader", ctx.GetString("myValue")) + + // Set a custom value on the request context. See OnOriginResponse + ctx.Set("myValue", "myValue") + + return request, nil +} + +func (m *MyModule) RouterOnRequest(ctx core.RequestContext, next http.Handler) { + logger := ctx.Logger() + logger.Info("Test RouterOnRequest custom module logs") + + next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) +} + +func (m *MyModule) Middleware(ctx core.RequestContext, next http.Handler) { + + operation := ctx.Operation() + + logger := ctx.Logger() + logger.Info("Test custom module logs") + // Access the GraphQL operation context + fmt.Println( + operation.Name(), + operation.Type(), + operation.Hash(), + operation.Content(), + ) + + // Call the next handler in the chain or return early by calling w.Write() + next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + // Interface guard var ( - _ core.StreamPublishEventHandler = (*CosmoStreamsModule)(nil) - _ core.StreamReceiveEventHandler = (*CosmoStreamsModule)(nil) - _ core.SubscriptionOnStartHandler = (*CosmoStreamsModule)(nil) + _ core.RouterMiddlewareHandler = (*MyModule)(nil) + _ core.RouterOnRequestHandler = (*MyModule)(nil) + _ core.EnginePreOriginHandler = (*MyModule)(nil) + _ core.EnginePostOriginHandler = (*MyModule)(nil) + _ core.Provisioner = (*MyModule)(nil) + _ core.Cleaner = (*MyModule)(nil) ) From df429b984f9a26f5e097fc9e47217a64973f9008 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 4 Jun 2026 17:37:53 +0200 Subject: [PATCH 5/5] fix: clone the config object before returning --- router/core/subscriptions_modules.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index 6916853f29..d088ea97ce 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -120,7 +120,7 @@ func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() } // Return a deep copy so callers cannot mutate the live configuration in // place. Changes are only applied when passed back via SetSubscriptionEventConfiguration. - return c.subscriptionEventConfiguration + return c.subscriptionEventConfiguration.Clone() } func (c *pubSubSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool {