diff --git a/router-tests/modules/start_subscription_change_test.go b/router-tests/modules/start_subscription_change_test.go new file mode 100644 index 0000000000..e5acf220eb --- /dev/null +++ b/router-tests/modules/start_subscription_change_test.go @@ -0,0 +1,119 @@ +package module_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + start_subscription "github.com/wundergraph/cosmo/router-tests/modules/start-subscription" + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/redis" +) + +type subscriptionArgs struct { + dataValue []byte + errValue error +} + +func TestStartSubscriptionChangeHook(t *testing.T) { + t.Run("Test StartSubscription hook can change channel", func(t *testing.T) { + t.Parallel() + logicalChannel := "customRedisChannel" + newChannel := "" + + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + redisCfg, ok := ctx.SubscriptionEventConfiguration().(*redis.SubscriptionEventConfiguration) + if ok { + redisCfg.Channels = []string{newChannel} + ctx.SetSubscriptionEventConfiguration(redisCfg) + } + return nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate, + EnableRedis: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + newChannel = xEnv.GetPubSubName(logicalChannel) + + var subscriptionOne struct { + employeeUpdatedMyRedis struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyRedis(id: 2)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan subscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- subscriptionArgs{dataValue, errValue} + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForTriggerCount(1, time.Second*10) + + // produce a message (retry until subscription pipeline is confirmed active) + xEnv.RedisPublishUntilReceived(logicalChannel, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`, 10*time.Second) + + // The SubscriptionOnStart hook may be called asynchronously after + // WaitForSubscriptionCount returns, so poll until it fires. + require.Eventually(t, func() bool { + return customModule.HookCallCount.Load() >= 1 + }, time.Second*10, time.Millisecond*50) + + // process the message + select { + case subscriptionArgs := <-subscriptionArgsCh: + require.NoError(t, subscriptionArgs.errValue) + require.JSONEq(t, `{"employeeUpdatedMyRedis":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(subscriptionArgs.dataValue)) + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for first message error") + } + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + }) + }) +} diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index c01d1fa348..d088ea97ce 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -23,6 +23,9 @@ type SubscriptionOnStartHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // SetSubscriptionEventConfiguration replaces the event configuration used to start + // a pub/sub subscription. It returns false for engine subscriptions or nil configs. + SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool // EmitEvent sends an event directly to the subscription stream of the // currently connected client. // @@ -112,7 +115,20 @@ func (c *pubSubSubscriptionOnStartHookContext) Authentication() authentication.A } func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { - return c.subscriptionEventConfiguration + if c.subscriptionEventConfiguration == nil { + return nil + } + // Return a deep copy so callers cannot mutate the live configuration in + // place. Changes are only applied when passed back via SetSubscriptionEventConfiguration. + return c.subscriptionEventConfiguration.Clone() +} + +func (c *pubSubSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool { + if config == nil { + return false + } + c.subscriptionEventConfiguration = config + return true } func (c *pubSubSubscriptionOnStartHookContext) EmitEvent(event datasource.StreamEvent) bool { @@ -199,6 +215,10 @@ func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() return nil } +func (c *engineSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool { + return false +} + type SubscriptionOnStartHandler interface { // SubscriptionOnStart is called once at subscription start // The error is propagated to the client. @@ -211,7 +231,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont return nil } - return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) error { + return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) (datasource.SubscriptionEventConfiguration, error) { requestContext := getRequestContext(resolveCtx.Context) logger := requestContext.Logger() @@ -236,7 +256,11 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont eventBuilder: eventBuilder, } - return fn(hookCtx) + if err := fn(hookCtx); err != nil { + return nil, err + } + + return hookCtx.subscriptionEventConfiguration, nil } } @@ -401,7 +425,11 @@ func (c *pubSubStreamReceiveEventHookContext) Authentication() authentication.Au } func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { - return c.subscriptionEventConfiguration + if c.subscriptionEventConfiguration == nil { + return nil + } + // Return a deep copy so callers cannot mutate the live configuration in place. + return c.subscriptionEventConfiguration.Clone() } func (c *pubSubStreamReceiveEventHookContext) NewEvent(data []byte) datasource.MutableStreamEvent { diff --git a/router/core/subscriptions_modules_test.go b/router/core/subscriptions_modules_test.go new file mode 100644 index 0000000000..56aeec55d3 --- /dev/null +++ b/router/core/subscriptions_modules_test.go @@ -0,0 +1,101 @@ +package core + +import ( + "context" + "net/http/httptest" + "slices" + "testing" + + "github.com/stretchr/testify/require" + rcontext "github.com/wundergraph/cosmo/router/internal/context" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type testSubscriptionEventConfig struct { + providerID string + fieldName string + channels []string +} + +func (c *testSubscriptionEventConfig) ProviderID() string { + return c.providerID +} + +func (c *testSubscriptionEventConfig) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeRedis +} + +func (c *testSubscriptionEventConfig) RootFieldName() string { + return c.fieldName +} + +func (c *testSubscriptionEventConfig) Clone() datasource.SubscriptionEventConfiguration { + c2 := *c + c2.channels = slices.Clone(c.channels) + return &c2 +} + +func TestNewPubSubSubscriptionOnStartHookReturnsUpdatedSubscriptionEventConfiguration(t *testing.T) { + originalConfig := &testSubscriptionEventConfig{ + providerID: "provider", + fieldName: "original", + } + updatedConfig := &testSubscriptionEventConfig{ + providerID: "provider", + fieldName: "updated", + } + + hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error { + got := ctx.SubscriptionEventConfiguration() + // The getter returns a defensive copy: equal by value but not the same pointer. + require.NotSame(t, originalConfig, got) + require.Equal(t, originalConfig, got) + require.True(t, ctx.SetSubscriptionEventConfiguration(updatedConfig)) + return nil + }) + + req := httptest.NewRequest("GET", "/graphql", nil) + reqCtx := buildRequestContext(requestContextOptions{r: req}) + ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, reqCtx) + + actualConfig, err := hook(resolve.StartupHookContext{ + Context: ctx, + Updater: func(data []byte) {}, + }, originalConfig, nil) + + require.NoError(t, err) + require.Same(t, updatedConfig, actualConfig) +} + +func TestNewPubSubSubscriptionOnStartHookInPlaceMutationIsNoOp(t *testing.T) { + originalConfig := &testSubscriptionEventConfig{ + providerID: "provider", + fieldName: "original", + channels: []string{"original-channel"}, + } + + hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error { + // Mutating the returned config in place must not affect the live + // configuration: it is a defensive copy. Only SetSubscriptionEventConfiguration applies changes. + got := ctx.SubscriptionEventConfiguration().(*testSubscriptionEventConfig) + got.fieldName = "mutated" + got.channels[0] = "mutated-channel" + return nil + }) + + req := httptest.NewRequest("GET", "/graphql", nil) + reqCtx := buildRequestContext(requestContextOptions{r: req}) + ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, reqCtx) + + actualConfig, err := hook(resolve.StartupHookContext{ + Context: ctx, + Updater: func(data []byte) {}, + }, originalConfig, nil) + + require.NoError(t, err) + // Without SetSubscriptionEventConfiguration the original, unmodified config is returned. + require.Same(t, originalConfig, actualConfig) + require.Equal(t, "original", originalConfig.fieldName) + require.Equal(t, []string{"original-channel"}, originalConfig.channels) +} diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go index a262058463..545e97e8fe 100644 --- a/router/pkg/pubsub/datasource/hooks.go +++ b/router/pkg/pubsub/datasource/hooks.go @@ -7,7 +7,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index fd02ffccf6..fd59847965 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -102,6 +102,10 @@ type SubscriptionEventConfiguration interface { ProviderID() string ProviderType() ProviderType RootFieldName() string // the root field name of the subscription in the schema + // Clone returns a deep copy of the configuration. It is used to hand out + // copies to module hooks so callers cannot mutate the live configuration in + // place; changes are only applied when passed back via SetSubscriptionEventConfiguration. + Clone() SubscriptionEventConfiguration } // PublishEventConfiguration is the interface that all publish event configurations must implement diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index b956ab38f0..ad0020fa17 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -58,6 +58,11 @@ func (c *testSubscriptionConfig) RootFieldName() string { return c.fieldName } +func (c *testSubscriptionConfig) Clone() SubscriptionEventConfiguration { + c2 := *c + return &c2 +} + type testPublishConfig struct { providerID string providerType ProviderType diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index 939c03f94d..3a108aef4e 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -1,12 +1,15 @@ package datasource import ( + "context" "encoding/json" "errors" "fmt" "net/http" + "strconv" "github.com/cespare/xxhash/v2" + rcontext "github.com/wundergraph/cosmo/router/internal/context" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -16,6 +19,13 @@ type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Dige type EventBuilderFn func(data []byte) MutableStreamEvent +const subscriptionEventConfigurationContextKeyPrefix = "wg.cosmo.pubsub.subscription_event_configuration." + +type subscriptionEventConfigurationStore interface { + Set(key string, value any) + Get(key string) (value any, exists bool) +} + // PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism. // It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { @@ -32,8 +42,15 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input [ return subscriptionConfiguration, err } +func (s *PubSubSubscriptionDataSource[C]) subscriptionEventConfiguration(ctx context.Context, input []byte) (SubscriptionEventConfiguration, error) { + if conf, ok := subscriptionEventConfigurationFromContext(ctx, input); ok { + return conf, nil + } + return s.SubscriptionEventConfiguration(input) +} + func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, header http.Header, input []byte, updater resolve.SubscriptionUpdater) error { - subConf, err := s.SubscriptionEventConfiguration(input) + subConf, err := s.subscriptionEventConfiguration(ctx.Context(), input) if err != nil { return err } @@ -70,17 +87,26 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu } }() + if len(s.hooks.SubscriptionOnStart.Handlers) == 0 { + return nil + } + + conf, err := s.SubscriptionEventConfiguration(input) + if err != nil { + return err + } + for _, fn := range s.hooks.SubscriptionOnStart.Handlers { - conf, err := s.SubscriptionEventConfiguration(input) + conf, err = fn(ctx, conf, s.eventBuilder) if err != nil { return err } - err = fn(ctx, conf, s.eventBuilder) - if err != nil { - return err + if conf == nil { + return errors.New("invalid subscription configuration") } } + setSubscriptionEventConfiguration(ctx.Context, input, conf) return nil } @@ -102,3 +128,40 @@ func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Ad eventBuilder: eventBuilder, } } + +func subscriptionEventConfigurationContextKey(input []byte) string { + return subscriptionEventConfigurationContextKeyPrefix + + strconv.Itoa(len(input)) + ":" + + strconv.FormatUint(xxhash.Sum64(input), 16) +} + +func requestContextStore(ctx context.Context) subscriptionEventConfigurationStore { + if ctx == nil { + return nil + } + store, _ := ctx.Value(rcontext.RequestContextKey).(subscriptionEventConfigurationStore) + return store +} + +func setSubscriptionEventConfiguration(ctx context.Context, input []byte, conf SubscriptionEventConfiguration) { + store := requestContextStore(ctx) + if store == nil { + return + } + store.Set(subscriptionEventConfigurationContextKey(input), conf) +} + +func subscriptionEventConfigurationFromContext(ctx context.Context, input []byte) (SubscriptionEventConfiguration, bool) { + store := requestContextStore(ctx) + if store == nil { + return nil, false + } + + value, ok := store.Get(subscriptionEventConfigurationContextKey(input)) + if !ok || value == nil { + return nil, false + } + + conf, ok := value.(SubscriptionEventConfiguration) + return conf, ok +} diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index 2662a2a2ca..8f80b486f2 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "errors" + "sync" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + rcontext "github.com/wundergraph/cosmo/router/internal/context" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -32,6 +34,31 @@ func (t testSubscriptionEventConfiguration) RootFieldName() string { return "testSubscription" } +func (t testSubscriptionEventConfiguration) Clone() SubscriptionEventConfiguration { + return t +} + +type testRequestContextStore struct { + mu sync.RWMutex + values map[string]any +} + +func (s *testRequestContextStore) Set(key string, value any) { + s.mu.Lock() + defer s.mu.Unlock() + if s.values == nil { + s.values = make(map[string]any) + } + s.values[key] = value +} + +func (s *testRequestContextStore) Get(key string) (value any, exists bool) { + s.mu.RLock() + defer s.mu.RUnlock() + value, exists = s.values[key] + return value, exists +} + // testSubscriptionDataSourceEventBuilder is a reusable event builder for tests func testSubscriptionDataSourceEventBuilder(data []byte) MutableStreamEvent { return mutableTestEvent(data) @@ -101,6 +128,57 @@ func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { mockAdapter.AssertExpectations(t) } +func TestPubSubSubscriptionDataSource_Start_UsesSubscriptionOnStartConfiguration(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[*testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) + + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + typedConfig, ok := config.(*testSubscriptionEventConfiguration) + require.True(t, ok) + typedConfig.Topic = "changed-topic" + return typedConfig, nil + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook}, + }, + }) + + testConfig := &testSubscriptionEventConfiguration{ + Topic: "original-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + require.NoError(t, err) + + requestStore := &testRequestContextStore{} + requestCtx := context.WithValue(context.Background(), rcontext.RequestContextKey, requestStore) + + err = dataSource.SubscriptionOnStart(resolve.StartupHookContext{ + Context: requestCtx, + Updater: func(data []byte) {}, + }, input) + require.NoError(t, err) + + ctx := resolve.NewContext(requestCtx) + mockUpdater := NewMockSubscriptionUpdater(t) + expectedConfig := &testSubscriptionEventConfiguration{ + Topic: "changed-topic", + Subject: "test-subject", + } + + mockAdapter.On("Subscribe", ctx.Context(), expectedConfig, mock.AnythingOfType("*datasource.subscriptionEventUpdater")).Return(nil) + + err = dataSource.Start(ctx, nil, input, mockUpdater) + require.NoError(t, err) + mockAdapter.AssertExpectations(t) +} + func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { mockAdapter := NewMockProvider(t) uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { @@ -183,20 +261,20 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T hook1EventBuilderExists := false hook2EventBuilderExists := false - hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { hook1Called = true if eventBuilder != nil { hook1EventBuilderExists = true } - return nil + return config, nil } - hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { hook2Called = true if eventBuilder != nil { hook2EventBuilderExists = true } - return nil + return config, nil } dataSource.SetHooks(Hooks{ @@ -234,8 +312,8 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *te dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Add hook that returns close=true - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return nil + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return config, nil } dataSource.SetHooks(Hooks{ @@ -270,8 +348,8 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *te expectedError := errors.New("hook error") // Add hook that returns an error - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return expectedError + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return nil, expectedError } dataSource.SetHooks(Hooks{ @@ -309,11 +387,11 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { assert.Len(t, dataSource.hooks.SubscriptionOnStart.Handlers, 0) // Add hooks - hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return nil + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return config, nil } - hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { - return nil + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { + return config, nil } dataSource.SetHooks(Hooks{ @@ -369,9 +447,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_InvalidEventConfigInpu dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) hookCalled := false - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { hookCalled = true - return nil + return config, nil } dataSource.SetHooks(Hooks{ @@ -430,7 +508,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testi dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Add hook that panics - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error) { panic(tt.panicValue) } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 411ab2bd98..92dce40538 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -34,6 +34,11 @@ func (c *testSubscriptionEventConfig) RootFieldName() string { return c.fieldName } +func (c *testSubscriptionEventConfig) Clone() SubscriptionEventConfiguration { + c2 := *c + return &c2 +} + type receivedHooksArgs struct { events []StreamEvent cfg SubscriptionEventConfiguration diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 6606b49015..b5646d9159 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -106,6 +106,13 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// Clone returns a deep copy of the configuration. +func (s *SubscriptionEventConfiguration) Clone() datasource.SubscriptionEventConfiguration { + c := *s + c.Topics = slices.Clone(s.Topics) + return &c +} + // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { Provider string `json:"providerId"` diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 6739c418ba..b03b3d6f2a 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -108,6 +108,17 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// Clone returns a deep copy of the configuration. +func (s *SubscriptionEventConfiguration) Clone() datasource.SubscriptionEventConfiguration { + c := *s + c.Subjects = slices.Clone(s.Subjects) + if s.StreamConfiguration != nil { + sc := *s.StreamConfiguration + c.StreamConfiguration = &sc + } + return &c +} + // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { Provider string `json:"providerId"` diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index 55eee83e20..bb171515f2 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -79,6 +79,13 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// Clone returns a deep copy of the configuration. +func (s *SubscriptionEventConfiguration) Clone() datasource.SubscriptionEventConfiguration { + c := *s + c.Channels = slices.Clone(s.Channels) + return &c +} + // publishData is a private type that is used to pass data from the engine to the provider type publishData struct {