Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions router-tests/modules/start_subscription_change_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package module_test

import (
"sync/atomic"
"testing"
"time"

"github.com/hasura/go-graphql-client"
start_subscription "github.com/wundergraph/cosmo/router-tests/modules/start-subscription"
"go.uber.org/zap/zapcore"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
"github.com/wundergraph/cosmo/router/pkg/pubsub/redis"
)

type subscriptionArgs struct {
dataValue []byte
errValue error
}

func TestStartSubscriptionChangeHook(t *testing.T) {
t.Run("Test StartSubscription hook can change channel", func(t *testing.T) {
t.Parallel()
logicalChannel := "customRedisChannel"
newChannel := ""

customModule := &start_subscription.StartSubscriptionModule{
HookCallCount: &atomic.Int32{},
Callback: func(ctx core.SubscriptionOnStartHandlerContext) error {
redisCfg, ok := ctx.SubscriptionEventConfiguration().(*redis.SubscriptionEventConfiguration)
if ok {
redisCfg.Channels = []string{newChannel}
ctx.SetSubscriptionEventConfiguration(redisCfg)
}
return nil
},
}

cfg := config.Config{
Graph: config.Graph{},
Modules: map[string]interface{}{
"startSubscriptionModule": customModule,
},
}

testenv.Run(t, &testenv.Config{
RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate,
EnableRedis: true,
RouterOptions: []core.Option{
core.WithModulesConfig(cfg.Modules),
core.WithCustomModules(&start_subscription.StartSubscriptionModule{}),
},
LogObservation: testenv.LogObservationConfig{
Enabled: true,
LogLevel: zapcore.InfoLevel,
},
}, func(t *testing.T, xEnv *testenv.Environment) {
newChannel = xEnv.GetPubSubName(logicalChannel)

var subscriptionOne struct {
employeeUpdatedMyRedis struct {
ID float64 `graphql:"id"`
Details struct {
Forename string `graphql:"forename"`
Surname string `graphql:"surname"`
} `graphql:"details"`
} `graphql:"employeeUpdatedMyRedis(id: 2)"`
}

surl := xEnv.GraphQLWebSocketSubscriptionURL()
client := graphql.NewSubscriptionClient(surl)

subscriptionArgsCh := make(chan subscriptionArgs)
subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error {
subscriptionArgsCh <- subscriptionArgs{dataValue, errValue}
return nil
})
require.NoError(t, err)
require.NotEmpty(t, subscriptionOneID)

clientRunCh := make(chan error)
go func() {
clientRunCh <- client.Run()
}()

xEnv.WaitForSubscriptionCount(1, time.Second*10)
xEnv.WaitForTriggerCount(1, time.Second*10)

// produce a message (retry until subscription pipeline is confirmed active)
xEnv.RedisPublishUntilReceived(logicalChannel, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`, 10*time.Second)

// The SubscriptionOnStart hook may be called asynchronously after
// WaitForSubscriptionCount returns, so poll until it fires.
require.Eventually(t, func() bool {
return customModule.HookCallCount.Load() >= 1
}, time.Second*10, time.Millisecond*50)

// process the message
select {
case subscriptionArgs := <-subscriptionArgsCh:
require.NoError(t, subscriptionArgs.errValue)
require.JSONEq(t, `{"employeeUpdatedMyRedis":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(subscriptionArgs.dataValue))
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for first message error")
}

require.NoError(t, client.Close())
testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) {
require.NoError(t, err)
}, "unable to close client before timeout")

assert.Equal(t, int32(1), customModule.HookCallCount.Load())
})
})
}
36 changes: 32 additions & 4 deletions router/core/subscriptions_modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ type SubscriptionOnStartHandlerContext interface {
Authentication() authentication.Authentication
// SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription)
SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration
// SetSubscriptionEventConfiguration replaces the event configuration used to start
// a pub/sub subscription. It returns false for engine subscriptions or nil configs.
SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool
// EmitEvent sends an event directly to the subscription stream of the
// currently connected client.
//
Expand Down Expand Up @@ -112,7 +115,20 @@ func (c *pubSubSubscriptionOnStartHookContext) Authentication() authentication.A
}

func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration {
return c.subscriptionEventConfiguration
if c.subscriptionEventConfiguration == nil {
return nil
}
// Return a deep copy so callers cannot mutate the live configuration in
// place. Changes are only applied when passed back via SetSubscriptionEventConfiguration.
return c.subscriptionEventConfiguration.Clone()
}

func (c *pubSubSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool {
if config == nil {
return false
}
c.subscriptionEventConfiguration = config
return true
}

func (c *pubSubSubscriptionOnStartHookContext) EmitEvent(event datasource.StreamEvent) bool {
Expand Down Expand Up @@ -199,6 +215,10 @@ func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration()
return nil
}

func (c *engineSubscriptionOnStartHookContext) SetSubscriptionEventConfiguration(config datasource.SubscriptionEventConfiguration) bool {
return false
}

type SubscriptionOnStartHandler interface {
// SubscriptionOnStart is called once at subscription start
// The error is propagated to the client.
Expand All @@ -211,7 +231,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont
return nil
}

return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) error {
return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) (datasource.SubscriptionEventConfiguration, error) {
requestContext := getRequestContext(resolveCtx.Context)

logger := requestContext.Logger()
Expand All @@ -236,7 +256,11 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont
eventBuilder: eventBuilder,
}

return fn(hookCtx)
if err := fn(hookCtx); err != nil {
return nil, err
}

return hookCtx.subscriptionEventConfiguration, nil
}
}

Expand Down Expand Up @@ -401,7 +425,11 @@ func (c *pubSubStreamReceiveEventHookContext) Authentication() authentication.Au
}

func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration {
return c.subscriptionEventConfiguration
if c.subscriptionEventConfiguration == nil {
return nil
}
// Return a deep copy so callers cannot mutate the live configuration in place.
return c.subscriptionEventConfiguration.Clone()
}

func (c *pubSubStreamReceiveEventHookContext) NewEvent(data []byte) datasource.MutableStreamEvent {
Expand Down
101 changes: 101 additions & 0 deletions router/core/subscriptions_modules_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package core

import (
"context"
"net/http/httptest"
"slices"
"testing"

"github.com/stretchr/testify/require"
rcontext "github.com/wundergraph/cosmo/router/internal/context"
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

type testSubscriptionEventConfig struct {
providerID string
fieldName string
channels []string
}

func (c *testSubscriptionEventConfig) ProviderID() string {
return c.providerID
}

func (c *testSubscriptionEventConfig) ProviderType() datasource.ProviderType {
return datasource.ProviderTypeRedis
}

func (c *testSubscriptionEventConfig) RootFieldName() string {
return c.fieldName
}

func (c *testSubscriptionEventConfig) Clone() datasource.SubscriptionEventConfiguration {
c2 := *c
c2.channels = slices.Clone(c.channels)
return &c2
}

func TestNewPubSubSubscriptionOnStartHookReturnsUpdatedSubscriptionEventConfiguration(t *testing.T) {
originalConfig := &testSubscriptionEventConfig{
providerID: "provider",
fieldName: "original",
}
updatedConfig := &testSubscriptionEventConfig{
providerID: "provider",
fieldName: "updated",
}

hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error {
got := ctx.SubscriptionEventConfiguration()
// The getter returns a defensive copy: equal by value but not the same pointer.
require.NotSame(t, originalConfig, got)
require.Equal(t, originalConfig, got)
require.True(t, ctx.SetSubscriptionEventConfiguration(updatedConfig))
return nil
})

req := httptest.NewRequest("GET", "/graphql", nil)
reqCtx := buildRequestContext(requestContextOptions{r: req})
ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, reqCtx)

actualConfig, err := hook(resolve.StartupHookContext{
Context: ctx,
Updater: func(data []byte) {},
}, originalConfig, nil)

require.NoError(t, err)
require.Same(t, updatedConfig, actualConfig)
}

func TestNewPubSubSubscriptionOnStartHookInPlaceMutationIsNoOp(t *testing.T) {
originalConfig := &testSubscriptionEventConfig{
providerID: "provider",
fieldName: "original",
channels: []string{"original-channel"},
}

hook := NewPubSubSubscriptionOnStartHook(func(ctx SubscriptionOnStartHandlerContext) error {
// Mutating the returned config in place must not affect the live
// configuration: it is a defensive copy. Only SetSubscriptionEventConfiguration applies changes.
got := ctx.SubscriptionEventConfiguration().(*testSubscriptionEventConfig)
got.fieldName = "mutated"
got.channels[0] = "mutated-channel"
return nil
})

req := httptest.NewRequest("GET", "/graphql", nil)
reqCtx := buildRequestContext(requestContextOptions{r: req})
ctx := context.WithValue(req.Context(), rcontext.RequestContextKey, reqCtx)

actualConfig, err := hook(resolve.StartupHookContext{
Context: ctx,
Updater: func(data []byte) {},
}, originalConfig, nil)

require.NoError(t, err)
// Without SetSubscriptionEventConfiguration the original, unmodified config is returned.
require.Same(t, originalConfig, actualConfig)
require.Equal(t, "original", originalConfig.fieldName)
require.Equal(t, []string{"original-channel"}, originalConfig.channels)
}
2 changes: 1 addition & 1 deletion router/pkg/pubsub/datasource/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error
type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) (SubscriptionEventConfiguration, error)

type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error)

Expand Down
4 changes: 4 additions & 0 deletions router/pkg/pubsub/datasource/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ type SubscriptionEventConfiguration interface {
ProviderID() string
ProviderType() ProviderType
RootFieldName() string // the root field name of the subscription in the schema
// Clone returns a deep copy of the configuration. It is used to hand out
// copies to module hooks so callers cannot mutate the live configuration in
// place; changes are only applied when passed back via SetSubscriptionEventConfiguration.
Clone() SubscriptionEventConfiguration
}

// PublishEventConfiguration is the interface that all publish event configurations must implement
Expand Down
5 changes: 5 additions & 0 deletions router/pkg/pubsub/datasource/pubsubprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ func (c *testSubscriptionConfig) RootFieldName() string {
return c.fieldName
}

func (c *testSubscriptionConfig) Clone() SubscriptionEventConfiguration {
c2 := *c
return &c2
}

type testPublishConfig struct {
providerID string
providerType ProviderType
Expand Down
Loading
Loading